【问题标题】:Advance indexing in Pytorch to get rid of nested for-loops在 Pytorch 中推进索引以摆脱嵌套的 for 循环
【发布时间】:2020-05-28 13:11:19
【问题描述】:

我有一种使用嵌套 for 循环的情况,但我想知道是否有更快的方法使用 Pytorch 中的一些高级索引来执行此操作。

我有一个名为 t 的张量:

t = torch.randn(3,8)
print(t)
tensor([[-1.1258, -1.1524, -0.2506, -0.4339,  0.8487,  0.6920, -0.3160, -2.1152],
        [ 0.4681, -0.1577,  1.4437,  0.2660,  0.1665,  0.8744, -0.1435, -0.1116],
        [ 0.9318,  1.2590,  2.0050,  0.0537,  0.6181, -0.4128, -0.8411, -2.3160]])

我想创建一个新张量,它索引来自t 的值。 假设这些索引存储在变量indexes

indexes = [[(0, 1, 4, 5), (0, 1, 6, 7), (4, 5, 6, 7)],
           [(2, 3, 4, 5)],
           [(4, 5, 6, 7), (2, 3, 6, 7)]]

indexes 中的每个内部元组代表要从一行中获取的四个索引。

例如,基于这些索引,我的输出将是一个 6x4 维度的张量(6 是 indexes 中的元组总数,4 对应于一个元组中的一个值)

例如,这是我想做的:

#counting the number of tuples in indexes
count_instances = sum([1 for lst in indexes for tupl in lst])

#creating a zero output matrix 
final_tensor = torch.zeros(count_instances,4)

final_tensor[0] = t[0,indexes[0][0]]
final_tensor[1] = t[0,indexes[0][1]]
final_tensor[2] = t[0,indexes[0][2]]
final_tensor[3] = t[1,indexes[1][0]]
final_tensor[4] = t[2,indexes[2][0]]
final_tensor[5] = t[2,indexes[2][1]]

最终输出如下所示: 打印(final_tensor)

tensor([[-1.1258, -1.1524,  0.8487,  0.6920],
        [-1.1258, -1.1524, -0.3160, -2.1152],
        [ 0.8487,  0.6920, -0.3160, -2.1152],
        [ 1.4437,  0.2660,  0.1665,  0.8744],
        [ 0.6181, -0.4128, -0.8411, -2.3160],
        [ 2.0050,  0.0537, -0.8411, -2.3160]])

我创建了一个函数build_tensor(如下所示)以通过嵌套的 for 循环实现此目的,但我想知道是否有更快的方法通过 Pytorch 中的简单索引来实现。我想要一种更快的方法,因为我正在使用更大的索引和 t 大小进行数百次此操作。

有什么帮助吗?

def build_tensor(indexes, t):
    #count tuples
    count_instances = sum([1 for lst in indexes for tupl in lst])
    #create a zero tensor
    final_tensor = torch.zeros(count_instances,4)
    final_tensor_idx = 0

    for curr_idx, lst in enumerate(indexes):
        for tupl in lst:
            final_tensor[final_tensor_idx] = t[curr_idx,tupl]
            final_tensor_idx+=1
    return final_tensor

【问题讨论】:

    标签: python indexing pytorch numpy-ndarray


    【解决方案1】:

    您可以将索引排列到二维数组中,然后像这样一次性完成索引:

    rows = [(row,)*len(index_tuple) for row, row_indices in enumerate(indexes) for index_tuple in row_indices]
    columns = [index_tuple for row_indices in indexes for index_tuple in row_indices]
    final_tensor = t[rows, columns]
    

    【讨论】:

    • 这使它平均快了 5 倍。谢谢!
    猜你喜欢
    • 2020-08-04
    • 2014-02-12
    • 2020-02-21
    • 2018-11-30
    • 1970-01-01
    • 1970-01-01
    • 1970-01-01
    • 1970-01-01
    相关资源
    最近更新 更多