【问题标题】:batched tensor slice, slice B x N x M with B x 1批量张量切片,切片 B x N x M 和 B x 1
【发布时间】:2021-09-25 12:50:12
【问题描述】:

我有一个 B x M x N 张量 X,我有一个 B x 1 张量 Y,它对应于我想要保留的维度 = 1 处的张量 X 的索引。为了避免循环,这个切片的简写是什么?

基本上我想这样做:

Z = torch.zeros(B,N)

for i in range(B):
    Z[i] = X[i][Y[i]]

【问题讨论】:

    标签: numpy pytorch slice tensor numpy-slicing


    【解决方案1】:

    以下代码与循环中的代码类似。不同之处在于,我们不是按顺序索引数组ZXY,而是使用数组i 并行索引它们

    B, M, N = 13, 7, 19
    
    X = np.random.randint(100, size= [B,M,N])
    Y = np.random.randint(M  , size= [B,1])
    Z = np.random.randint(100, size= [B,N])
    
    i = np.arange(B)
    Y = Y.ravel()    # reducing array to rank-1, for easy indexing
    
    Z[i] = X[i,Y[i],:]
    

    这段代码可以进一步简化为

    -> Z[i] = X[i,Y[i],:]
    -> Z[i] = X[i,Y[i]]
    -> Z[i] = X[i,Y]
    -> Z    = X[i,Y]
    

    pytorch 等效代码

    B, M, N = 5, 7, 3
    
    X = torch.randint(100, size= [B,M,N])
    Y = torch.randint(M  , size= [B,1])
    Z = torch.randint(100, size= [B,N])
    
    i = torch.arange(B)
    Y = Y.ravel()
    
    Z = X[i,Y]
    

    【讨论】:

    • 他们说Y 的形状是(B, 1),所以你可能想在最后一个表达式中更改为Y.view(-1) 或类似的东西。
    【解决方案2】:

    @Hammad 提供的答案简短而完美。如果您有兴趣使用一些鲜为人知的 Pytorch 内置插件,这是一个替代解决方案。我们将使用torch.gather(类似地,您可以使用numpy.take 实现此目的)。

    torch.gather 背后的想法是基于两个相同形状的张量构建一个新的张量,其中包含索引(这里 ~ Y)和值(这里 ~ X)。

    执行的操作是Z[i][j][k] = X[i][Y[i][j][k]][k]

    由于X 的形状是(B, M, N)Y 的形状是(B, 1),我们正在寻找填充Y 内部的空白,这样Y 的形状就变成了(B, 1, N)

    这可以通过一些轴操作来实现:

    >>> Y.expand(-1, N)[:, None] # expand to dim=1 to N and unsqueeze dim=1
    

    torch.gather 的实际调用将是:

    >>> X.gather(dim=1, index=Y.expand(-1, N)[:, None])
    

    您可以通过添加[:, 0] 将其改造成(B, N)


    此功能在棘手的场景中非常有效...

    【讨论】:

    • 谢谢!虽然之前的解决方案有效,但它与批量大小无关。但是这个是!
    • 我很高兴我确实发布了答案,即使另一个答案已经被接受了!
    猜你喜欢
    • 2020-11-18
    • 1970-01-01
    • 1970-01-01
    • 2018-03-15
    • 2018-01-31
    • 2017-12-01
    • 1970-01-01
    • 1970-01-01
    • 1970-01-01
    相关资源
    最近更新 更多