因此,为了进行multi-index 选择,您可以使用torch.gather 函数,该函数沿由dim(第二个参数)指定的轴收集值。
示例 1:
t2 = torch.tensor([[0.1, 0.2, 0.3, 0.4],
[0.8, 1.8, 0.2, 0.3],
[0.5, 0.1, 0.2, 0.4]])
indexes1 = torch.tensor([[0, 2, 0, 2],
[0, 1, 1, 0],
[0, 0, 1, 2]])
t1 = torch.gather(t2, 0, indexes1) # dim is 0
print(t1)
输出:
tensor([[0.1000, 0.1000, 0.3000, 0.4000],
[0.1000, 1.8000, 0.2000, 0.4000],
[0.1000, 0.2000, 0.2000, 0.4000]])
示例 2:
t2 = torch.tensor([[0.1, 0.2, 0.3, 0.4],
[0.8, 1.8, 0.2, 0.3],
[0.5, 0.1, 0.2, 0.4]])
indexes2 = torch.tensor([[0, 3, 2, 0],
[0, 1, 1, 3],
[0, 0, 3, 2]])
t1 = torch.gather(t2, 1, indexes2) # dim is 1
print(t1)
输出:
tensor([[0.1000, 0.4000, 0.3000, 0.1000],
[0.8000, 1.8000, 1.8000, 0.3000],
[0.5000, 0.5000, 0.4000, 0.2000]])
要了解更多关于torch.gather 功能的信息,请查看this SO 讨论。
您也可以使用torch.Tensor.scatter_ 来做同样的事情。
t1.scatter_(0, indexes, t2) 基本上说将t2 张量的元素发送到t1 张量中的以下索引(在indexes 张量中指定),逐行(dim 0)。
示例:
t1 = torch.zeros((3, 4))
t2 = torch.tensor([[0.1, 0.2, 0.3, 0.4],
[0.8, 1.8, 0.2, 0.3],
[0.5, 0.1, 0.2, 0.4]])
indexes = torch.tensor([[1, 2, 0, 2],
[0, 1, 2, 1],
[2, 0, 1, 0]])
t1 = t1.scatter_(0, indexes, t2)
print(t1)
输出:
tensor([[0.8000, 0.1000, 0.3000, 0.4000],
[0.1000, 1.8000, 0.2000, 0.3000],
[0.5000, 0.2000, 0.2000, 0.4000]])
您可以从here 了解更多信息。