【问题标题】:Pytorch - how do I index a 2-D matrix row wise?Pytorch - 如何按行索引二维矩阵?
【发布时间】:2019-11-17 02:29:14
【问题描述】:

我想按行索引二维矩阵并重新分配值。

例如,首先考虑一个一维向量情况,其中我们有三个形状相同的一维张量t1, indexes, t2。我们可以按如下方式进行索引和重新分配:

indexes = torch.tensor([0, 2, 1, 3])
t1 = torch.tensor([0.0, 0.0, 0.0, 0.0])
t2 = torch.tensor([0.1, 0.2, 0.3, 0.4])

t1[indexes] = t2

现在,假设t1, indexes, t2 是二维矩阵而不是一维向量,并且具有相同的形状(R X C)。我想对这些矩阵中的每一行进行与上述类似的索引,其中:

for i in range(R):
    t1[i][indexes[i]] = t2[i]

我想将此操作向量化,而不是使用 for 循环。我该怎么做?

【问题讨论】:

标签: pytorch


【解决方案1】:

因此,为了进行multi-index 选择,您可以使用torch.gather 函数,该函数沿由dim(第二个参数)指定的轴收集值。

示例 1:

t2 = torch.tensor([[0.1, 0.2, 0.3, 0.4], 
                   [0.8, 1.8, 0.2, 0.3], 
                   [0.5, 0.1, 0.2, 0.4]])
indexes1 = torch.tensor([[0, 2, 0, 2], 
                         [0, 1, 1, 0], 
                         [0, 0, 1, 2]])
t1 = torch.gather(t2, 0, indexes1) # dim is 0
print(t1)

输出:

tensor([[0.1000, 0.1000, 0.3000, 0.4000],
        [0.1000, 1.8000, 0.2000, 0.4000],
        [0.1000, 0.2000, 0.2000, 0.4000]])

示例 2:

t2 = torch.tensor([[0.1, 0.2, 0.3, 0.4], 
                   [0.8, 1.8, 0.2, 0.3], 
                   [0.5, 0.1, 0.2, 0.4]])
indexes2 = torch.tensor([[0, 3, 2, 0], 
                         [0, 1, 1, 3], 
                         [0, 0, 3, 2]])  
t1 = torch.gather(t2, 1, indexes2) # dim is 1
print(t1)

输出:

tensor([[0.1000, 0.4000, 0.3000, 0.1000],
        [0.8000, 1.8000, 1.8000, 0.3000],
        [0.5000, 0.5000, 0.4000, 0.2000]])

要了解更多关于torch.gather 功能的信息,请查看this SO 讨论。

您也可以使用torch.Tensor.scatter_ 来做同样的事情。

t1.scatter_(0, indexes, t2) 基本上说将t2 张量的元素发送到t1 张量中的以下索引(在indexes 张量中指定),逐行(dim 0)。

示例:

t1 = torch.zeros((3, 4))
t2 = torch.tensor([[0.1, 0.2, 0.3, 0.4], 
                   [0.8, 1.8, 0.2, 0.3], 
                   [0.5, 0.1, 0.2, 0.4]])
indexes = torch.tensor([[1, 2, 0, 2], 
                        [0, 1, 2, 1], 
                        [2, 0, 1, 0]])
t1 = t1.scatter_(0, indexes, t2)
print(t1)

输出:

tensor([[0.8000, 0.1000, 0.3000, 0.4000],
        [0.1000, 1.8000, 0.2000, 0.3000],
        [0.5000, 0.2000, 0.2000, 0.4000]])

您可以从here 了解更多信息。

【讨论】:

  • 分配给收集的条目怎么样?这就是问题的重点。
【解决方案2】:

与@Anubhav 的答案类似,scatter_ 的尺寸略有变化,这完成了工作。来源:PyTorch Discussion

indexes = torch.tensor([[0, 2, 1, 3],
                        [1, 0, 3, 2]])
t1 = torch.zeros_like(indexes).float()
t2 = torch.tensor([[0.1, 0.2, 0.3, 0.4],
                   [0.5, 0.6, 0.7, 0.8]])
t1.scatter_(1, indexes, t2)

【讨论】:

  • 我在上面的解决方案中也讨论了相同的技术。因此,您可以评论说第二种解决方案有效,而不是使用相同的技术添加另一个答案。谢谢。
  • 感谢@Anubhav 回答我的问题,但我发布另一个解决方案但不接受您的解决方案的原因是scatter_ 的维度不同,以达到预期的结果。我已经修改了我的答案以参考您的答案:) 再次感谢您的回答!
猜你喜欢
  • 1970-01-01
  • 2019-09-15
  • 2021-10-30
  • 1970-01-01
  • 1970-01-01
  • 1970-01-01
  • 2018-11-23
  • 1970-01-01
  • 1970-01-01
相关资源
最近更新 更多