【问题标题】:PyTorch: How to get the shape of a Tensor as a list of intPyTorch:如何将张量的形状作为 int 列表获取
【发布时间】:2018-03-31 06:35:20
【问题描述】:

在 numpy 中,V.shape 给出了一个维数为 V 的整数元组。

在 tensorflow 中V.get_shape().as_list() 给出了 V 维数的整数列表。

在 pytorch 中,V.size() 给出了一个 size 对象,但我如何将其转换为整数?

【问题讨论】:

    标签: python pytorch tensor


    【解决方案1】:

    对于 PyTorch v1.0 及以上版本:

    >>> import torch
    >>> var = torch.tensor([[1,0], [0,1]])
    
    # Using .size function, returns a torch.Size object.
    >>> var.size()
    torch.Size([2, 2])
    >>> type(var.size())
    <class 'torch.Size'>
    
    # Similarly, using .shape
    >>> var.shape
    torch.Size([2, 2])
    >>> type(var.shape)
    <class 'torch.Size'>
    

    您可以将任何 torch.Size 对象转换为原生 Python 列表:

    >>> list(var.size())
    [2, 2]
    >>> type(list(var.size()))
    <class 'list'>
    

    在 PyTorch v0.3 和 0.4 中:

    只需list(var.size()),例如:

    >>> import torch
    >>> from torch.autograd import Variable
    >>> from torch import IntTensor
    >>> var = Variable(IntTensor([[1,0],[0,1]]))
    
    >>> var
    Variable containing:
     1  0
     0  1
    [torch.IntTensor of size 2x2]
    
    >>> var.size()
    torch.Size([2, 2])
    
    >>> list(var.size())
    [2, 2]
    

    【讨论】:

    • 不是调用list,Size 类是否有某种属性我可以直接访问以获取元组或列表形式的形状?
    • 您可以直接访问大小,而无需将其投射到列表中。 OP 要求提供一个列表,因此需要类型转换。
    • 有用的输入。谢谢。我可以加一分。结果仍然是列表形式。我们可以简单地编写以将输出作为整数:list(var.size())[0]list(var.size())[1]
    • list(var.size()) 的元素类型为tensor,而不是int。如何处理?我需要int
    【解决方案2】:

    如果您喜欢NumPyish 语法,那么tensor.shape

    In [3]: ar = torch.rand(3, 3)
    
    In [4]: ar.shape
    Out[4]: torch.Size([3, 3])
    
    # method-1
    In [7]: list(ar.shape)
    Out[7]: [3, 3]
    
    # method-2
    In [8]: [*ar.shape]
    Out[8]: [3, 3]
    
    # method-3
    In [9]: [*ar.size()]
    Out[9]: [3, 3]
    

    P.S.:请注意tensor.shapetensor.size() 的别名,尽管tensor.shape 是相关张量的一个属性,而tensor.size() 是一个函数。

    【讨论】:

    • 哪部分代码只能在GPU机器上运行? tensor.shape?
    【解决方案3】:

    以前的答案为您提供了 torch.Size 的列表 这是获取整数列表的方法

    listofints = [int(x) for x in tensor.shape]
    

    【讨论】:

      【解决方案4】:

      torch.Size 对象是 a subclass of tuple,并继承其通常的属性,例如它可以被索引:

      v = torch.tensor([[1,2], [3,4]])
      v.shape[0]
      >>> 2
      

      注意它的条目已经是int类型。


      如果你真的想要一个列表,只需使用 list 构造函数,就像任何其他可迭代对象一样:

      list(v.shape)
      

      【讨论】:

        猜你喜欢
        • 1970-01-01
        • 2017-04-01
        • 1970-01-01
        • 2020-08-05
        • 2019-07-29
        • 2019-02-21
        • 2020-10-04
        • 1970-01-01
        相关资源
        最近更新 更多