首先,让我快速了解一下使用 numpy 数组和另一个张量对张量进行索引的想法。
示例:这是我们要索引的目标张量
numpy_indices = torch.tensor([[0, 1, 2, 7],
[0, 1, 2, 3]]) # numpy array
tensor_indices = torch.tensor([[0, 1, 2, 7],
[0, 1, 2, 3]]) # 2D tensor
t = torch.tensor([[1, 2, 3, 4], # targeted tensor
[5, 6, 7, 8],
[9, 10, 11, 12],
[13, 14, 15, 16],
[17, 18, 19, 20],
[21, 22, 23, 24],
[25, 26, 27, 28],
[29, 30, 31, 32]])
numpy_result = t[numpy_indices]
tensor_result = t[tensor_indices]
-
使用 2D numpy 数组进行索引:索引的读取方式类似于 (x,y) tensor[row,column] 对,例如t[0,0], t[1,1], t[2,2], and t[7,3].
print(numpy_result) # tensor([ 1, 6, 11, 32])
-
使用 2D 张量进行索引:以逐行方式遍历索引张量,每个值都是目标张量中一行的索引。
例如[ [t[0],t[1],t[2],[7]] , [[0],[1],[2],[3]] ]见下例,tensor_result索引后的新形状为(tensor_indices.shape[0],tensor_indices.shape[1],t.shape[1])=(2,4,4)。
print(tensor_result) # tensor([[[ 1, 2, 3, 4],
# [ 5, 6, 7, 8],
# [ 9, 10, 11, 12],
# [29, 30, 31, 32]],
# [[ 1, 2, 3, 4],
# [ 5, 6, 7, 8],
# [ 9, 10, 11, 12],
# [ 13, 14, 15, 16]]])
如果您尝试在numpy_indices 中添加第三行,您将遇到同样的错误,因为索引将由 3D 表示,例如 (0,0,0)...(7,3,3 )。
indices = np.array([[0, 1, 2, 7],
[0, 1, 2, 3],
[0, 1, 2, 3]])
print(numpy_result) # IndexError: too many indices for tensor of dimension 2
但是,张量索引不是这种情况,形状会更大(3,4,4)。
最后,如您所见,两种索引的输出完全不同。要解决您的问题,您可以使用
xx = torch.tensor(xx).long() # convert a numpy array to a tensor
在高级索引(numpy_indices > 3 的行)的情况下会发生什么,因为您的情况仍然模棱两可且未解决,您可以检查1、2、3。