【问题标题】:Extract elements from .npy file, convert them to PyTorch tensors从 .npy 文件中提取元素,将它们转换为 PyTorch 张量
【发布时间】:2019-04-18 22:24:39
【问题描述】:

我阅读了一个仅包含图像标签的 .npy 文件。标签以字典格式存储。我需要将其转换为张量数组。但是我无法从文件返回的对象中提取另一个元素,该对象是 numpy.ndarray 类型。


import numpy as np
data = np.load('/content/drive/My Drive/targets.npy')
print(data.item())


{0: array(5), 1: array(0), 2: array(4), 3: array(1), 4: array(9), 5: array(2), 6: array(1), 7: array(3)}

print(data[()].values())

dict_values([array(5), array(0), array(4), array(1), array(9), array(2), array(1), array(3)])

我想创建一个张量数组。

提前致谢。

【问题讨论】:

    标签: python numpy multidimensional-array deserialization pytorch


    【解决方案1】:

    假设您的 data 是一个字典:

    In [59]: dct = {0: np.array([5]), 1: np.array([0]), 2: np.array([4]), 
                    3: np.array([1]), 4: np.array([9]), 5: np.array([2]), 
                    6: np.array([1]), 7: np.array([3])}
    

    您可以使用 numpy.concatenate() 包裹在 torch.tensor() 中来获取张量:

    In [63]: torch.tensor(np.concatenate(list(dct.values())))
    Out[63]: tensor([5, 0, 4, 1, 9, 2, 1, 3])
    

    此外,如果您希望键和值都堆叠在单个 2D 张量中,请使用 torch.cat()

    # tensor with just keys
    In [86]: tk = torch.tensor(list(dct.keys()))
    In [87]: tk
    Out[87]: tensor([0, 1, 2, 3, 4, 5, 6, 7])
    
    # tensor with just values
    In [88]: tv = torch.tensor(np.concatenate(list(dct.values())))
    In [89]: tv
    Out[89]: tensor([5, 0, 4, 1, 9, 2, 1, 3])
    
    # horizontally stack them into a single 2D tensor
    In [85]: torch.cat((tk[:, None], tv[:, None]), dim=1)
    Out[85]: 
    tensor([[0, 5],
            [1, 0],
            [2, 4],
            [3, 1],
            [4, 9],
            [5, 2],
            [6, 1],
            [7, 3]])
    

    经过一系列的cmets,我现在明白了你的问题,这里是解决它的方法:

    In [48]: data_item = {0: np.array(5), 1: np.array(0), 2: np.array(4), 
                          3: np.array(1), 4: np.array(9), 5: np.array(2),
                          6: np.array(1), 7: np.array(3)}
    
    # convert keys to an 1D tensor
    In [53]: tk = torch.tensor(list(data_item.keys()))
    
    In [54]: tk
    Out[54]: tensor([0, 1, 2, 3, 4, 5, 6, 7])
    

    由于您将值作为 0D 数组(即标量),我们需要从中提取元素。为此,我们可以在 map 旁边使用 lambda 函数,它的工作是将 lambda 函数应用于可迭代对象(这里:data_item.values())并给我们元素。这些可以传递给torch.tensor 以获得所需的一维张量。

    # convert values to an 1D tensor
    In [57]: tv = torch.tensor(list(map(lambda a: a.item(), data_item.values())))
    
    In [58]: tv
    Out[58]: tensor([5, 0, 4, 1, 9, 2, 1, 3])
    
    # horizontally stack them into a single 2D tensor, if needed
    In [85]: torch.cat((tk[:, None], tv[:, None]), dim=1)
    Out[85]: 
    tensor([[0, 5],
            [1, 0],
            [2, 4],
            [3, 1],
            [4, 9],
            [5, 2],
            [6, 1],
            [7, 3]])
    

    【讨论】:

    • 问题是,它不是字典。 data.values() 给出了这个: AttributeError: 'numpy.ndarray' object has no attribute 'values'
    • 它是一个字典,作为单个元素存储在 ndarray 中
    • @aneeshaasc 你能举个简单的例子说明data 的样子吗?
    • {0:数组(5),1:数组(0),2:数组(4),3:数组(1),4:数组(9),5:数组(2) , 6: 数组(1), 7: 数组(3)}
    • 别担心,我想我找到了办法。
    【解决方案2】:

    在@kmario23 的指导下,以下内容对我有用

    import numpy as np
    data = np.load('/content/drive/My Drive/targets.npy')
    print(data.item())
    
    {0: array(5), 1: array(0), 2: array(4), 3: array(1), 4: array(9), 5: array(2), 6: array(1), 7: array(3)}
    # data is a 0-d numpy.ndarray that contains a dictionary. 
    
    print(list(data[()].values()))
    
    [array(5),
     array(0),
     array(4),
     array(1),
     array(9),
     array(2),
     array(1),
     array(3),
     array(1),
     array(4),
     array(3)]
    
    # torch.Tensor(5) gives tensor([2.0581e-35, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00])
    # torch.tensor(5) gives 5
    # unsure of why the difference exists..
    
    Labels = torch.stack([torch.tensor(i) for i in list_of_labels_array_form])
    
    print(Labels)
    
    tensor([5, 0, 4,  ..., 2, 5, 0])
    

    【讨论】:

      猜你喜欢
      • 2022-10-17
      • 1970-01-01
      • 1970-01-01
      • 2019-07-29
      • 1970-01-01
      • 1970-01-01
      • 2020-08-05
      • 2021-11-25
      相关资源
      最近更新 更多