【发布时间】:2019-07-07 19:08:41
【问题描述】:
我使用生成器创建随机批次,例如:
import torch
n = 10
batch_size = 2
x = torch.zeros((batch_size, n), dtype=torch.float)
in_flags = torch.randint(n, (batch_size,), dtype=torch.long)
for idx, row in enumerate(x):
row[in_flags[idx]] = 1.0
但这样做的缺点是循环在 Python 中运行。 这就是嵌入的本义(不要与 PyTorch nn.embedding 混淆)。是否可以使用一个 PyTorch 运算符使其在本机或 GPU 中执行?
【问题讨论】:
标签: python-3.x pytorch embedding