【问题标题】:How to index intermediate dimension with an index tensor in pytorch?如何在pytorch中使用索引张量索引中间维度?
【发布时间】:2021-02-11 21:07:21
【问题描述】:

如何索引具有 n 维的张量 t 和 m index 张量,以便保留 t 的最后一个维度?对于维度 m 之前的所有维度,index 张量的形状等于张量 t。或者换句话说,我想索引张量的中间维度,同时保留所选索引的所有以下维度。

例如,假设我们有两个张量:

t = torch.randn([3, 5, 2]) * 10
index = torch.tensor([[1, 3],[0,4],[3,2]]).long()

与 t:

tensor([[[ 15.2165,  -7.9702],
         [  0.6646,   5.2844],
         [-22.0657,  -5.9876],
         [ -9.7319,  11.7384],
         [  4.3985,  -6.7058]],

        [[-15.6854, -11.9362],
         [ 11.3054,   3.3068],
         [ -4.7756,  -7.4524],
         [  5.0977, -17.3831],
         [  3.9152, -11.5047]],

        [[ -5.4265, -22.6456],
         [  1.6639,  10.1483],
         [ 13.2129,   3.7850],
         [  3.8543,  -4.3496],
         [ -8.7577, -12.9722]]])

然后我想要的输出将具有形状 (3, 2, 2) 并且是:

tensor([[[  0.6646,   5.2844],
         [ -9.7319,  11.7384]],
        [[-15.6854, -11.9362],
         [  3.9152, -11.5047]],
        [[  3.8543,  -4.3496],
         [ 13.2129,   3.7850]]])

另一个例子是我有一个形状为(40, 10, 6, 2) 的张量t 和一个形状为(40, 10, 3) 的索引张量。这应该查询张量 t 的维度 3,并且预期的输出形状将是 (40, 10, 3, 2)

如何在不使用循环的情况下以通用方式实现这一点?

【问题讨论】:

    标签: python indexing pytorch tensor


    【解决方案1】:

    在这种情况下,您可以这样做:

    t[torch.arange(t.shape[0]).unsqueeze(1), index, ...]
    

    完整代码:

    import torch
    
    t = torch.tensor([[[ 15.2165,  -7.9702],
                       [  0.6646,   5.2844],
                       [-22.0657,  -5.9876],
                       [ -9.7319,  11.7384],
                       [  4.3985,  -6.7058]],
                      [[-15.6854, -11.9362],
                       [ 11.3054,   3.3068],
                       [ -4.7756,  -7.4524],
                       [  5.0977, -17.3831],
                       [  3.9152, -11.5047]],
                      [[ -5.4265, -22.6456],
                       [  1.6639,  10.1483],
                       [ 13.2129,   3.7850],
                       [  3.8543,  -4.3496],
                       [ -8.7577, -12.9722]]])
    
    index = torch.tensor([[1, 3],[0,4],[3,2]]).long()
    
    output = t[torch.arange(t.shape[0]).unsqueeze(1), index, ...]
    
    # tensor([[[  0.6646,   5.2844],
    #          [ -9.7319,  11.7384]],
    # 
    #         [[-15.6854, -11.9362],
    #          [  3.9152, -11.5047]],
    # 
    #         [[  3.8543,  -4.3496],
    #          [ 13.2129,   3.7850]]])
    

    【讨论】:

    • 但是这个解决方案不是通用的,对吧?如果我的 t 为 5 维而我的索引为 3 或 4 怎么办?
    • @Chris 那么您的索引定义不明确...请用另一个示例更新您的问题。确保您的索引有定义。例如,我必须在回答中引入一个假设,因为不清楚应该如何使用您的索引,除非您查看输出,然后会注意到索引不完整。
    • 我的问题很笼统,这就是为什么我使用 n 和 m 作为维度。这个例子当然是具体的。您有 n 个维度,并希望在中间维度 m(小于 n)中查询该张量。索引张量定义为维度 m,并且 m 之前的所有维度都等于目标张量。我可以添加这些信息,我认为它很明显。
    • 我扩展了我的描述,希望这能让我的意思更清楚!
    • @Chris 您的索引仍然不明确。例如,对于 m=3 和 n=2,如何定义 output[i,j,k]?也许我错过了什么。
    猜你喜欢
    • 2019-02-05
    • 2019-09-15
    • 2020-09-26
    • 2019-05-31
    • 2021-11-04
    • 2021-12-16
    • 2020-03-28
    • 2019-11-26
    • 2018-06-08
    相关资源
    最近更新 更多