以下代码与循环中的代码类似。不同之处在于,我们不是按顺序索引数组Z、X 和Y,而是使用数组i 并行索引它们
B, M, N = 13, 7, 19
X = np.random.randint(100, size= [B,M,N])
Y = np.random.randint(M , size= [B,1])
Z = np.random.randint(100, size= [B,N])
i = np.arange(B)
Y = Y.ravel() # reducing array to rank-1, for easy indexing
Z[i] = X[i,Y[i],:]
这段代码可以进一步简化为
-> Z[i] = X[i,Y[i],:]
-> Z[i] = X[i,Y[i]]
-> Z[i] = X[i,Y]
-> Z = X[i,Y]
pytorch 等效代码
B, M, N = 5, 7, 3
X = torch.randint(100, size= [B,M,N])
Y = torch.randint(M , size= [B,1])
Z = torch.randint(100, size= [B,N])
i = torch.arange(B)
Y = Y.ravel()
Z = X[i,Y]