【问题标题】:How to do one line flag embed in PyTorch (not nn.Embedding)?如何在 PyTorch 中嵌入一行标志(不是 nn.Embedding)?
【发布时间】:2019-07-07 19:08:41
【问题描述】:

我使用生成器创建随机批次,例如:

import torch

n = 10
batch_size = 2

x = torch.zeros((batch_size, n), dtype=torch.float)
in_flags = torch.randint(n, (batch_size,), dtype=torch.long)

for idx, row in enumerate(x):
    row[in_flags[idx]] = 1.0

但这样做的缺点是循环在 Python 中运行。 这就是嵌入的本义(不要与 PyTorch nn.embedding 混淆)。是否可以使用一个 PyTorch 运算符使其在本机或 GPU 中执行?

【问题讨论】:

    标签: python-3.x pytorch embedding


    【解决方案1】:

    你可以这样做:

    import torch
    
    n = 10
    batch_size = 2
    
    in_flags = torch.randint(n, (batch_size,), dtype=torch.long)
    x = torch.zeros((batch_size, n), dtype=torch.float)
    
    # this is how you can do this
    x[torch.arange(batch_size), in_flags] = 1.0
    
    print(in_flags)
    print(x)
    

    输出:

    tensor([8, 0])
    tensor([[0., 0., 0., 0., 0., 0., 0., 0., 1., 0.],
            [1., 0., 0., 0., 0., 0., 0., 0., 0., 0.]])
    

    【讨论】:

      猜你喜欢
      • 2018-10-24
      • 1970-01-01
      • 2019-11-23
      • 2021-02-07
      • 1970-01-01
      • 2021-03-18
      • 2018-11-17
      • 2021-03-19
      • 2019-08-20
      相关资源
      最近更新 更多