【问题标题】:How to do product of matrices in PyTorch如何在 PyTorch 中做矩阵的乘积
【发布时间】:2017-11-15 10:23:23
【问题描述】:

在 numpy 中,我可以像这样进行简单的矩阵乘法:

a = numpy.arange(2*3).reshape(3,2)
b = numpy.arange(2).reshape(2,1)
print(a)
print(b)
print(a.dot(b))

但是,当我尝试使用 PyTorch 张量时,这不起作用:

a = torch.Tensor([[1, 2, 3], [1, 2, 3]]).view(-1, 2)
b = torch.Tensor([[2, 1]]).view(2, -1)
print(a)
print(a.size())

print(b)
print(b.size())

print(torch.dot(a, b))

此代码引发以下错误:

RuntimeError:张量大小不一致 /Users/soumith/code/builder/wheel/pytorch-src/torch/lib/TH/generic/THTensorMath.c:503

知道如何在 PyTorch 中进行矩阵乘法吗?

【问题讨论】:

    标签: python matrix pytorch


    【解决方案1】:

    您可以使用“@”来计算 pytorch 中两个张量之间的点积。

    a = torch.tensor([[1,2],
                      [3,4]])
    b = torch.tensor([[5,6],
                      [7,8]])
    c = a@b #For dot product
    c
    
    d = a*b #For elementwise multiplication 
    d
    

    【讨论】:

      【解决方案2】:

      你正在寻找

      torch.mm(a,b)
      

      请注意,torch.dot() 的行为与 np.dot() 不同。关于什么是可取的here 进行了一些讨论。具体来说,torch.dot()ab 都视为一维向量(无论它们的原始形状如何)并计算它们的内积。抛出错误,因为这种行为使您的 a 成为长度为 6 的向量,而您的 b 成为长度为 2 的向量;因此无法计算它们的内积。对于 PyTorch 中的矩阵乘法,请使用 torch.mm()。 Numpy 的np.dot() 相比之下更加灵活;它计算一维数组的内积并为二维数组执行矩阵乘法。

      根据普遍需求,如果两个参数都是2D,函数torch.matmul 将执行矩阵乘法,如果两个参数都是1D,则计算它们的点积。对于此类维度的输入,其行为与np.dot 相同。它还可以让您批量进行广播或matrix x matrixmatrix x vectorvector x vector 操作。有关详细信息,请参阅其docs

      # 1D inputs, same as torch.dot
      a = torch.rand(n)
      b = torch.rand(n)
      torch.matmul(a, b) # torch.Size([])
      
      # 2D inputs, same as torch.mm
      a = torch.rand(m, k)
      b = torch.rand(k, j)
      torch.matmul(a, b) # torch.Size([m, j])
      

      【讨论】:

      • 既然这是公认的答案,我认为你应该包括torch.matmul。它对一维数组执行点积,对二维数组执行矩阵乘法。
      【解决方案3】:

      如果您想进行矩阵(2 阶张量)乘法,您可以通过四种等效方式进行:

      AB = A.mm(B) # computes A.B (matrix multiplication)
      # or
      AB = torch.mm(A, B)
      # or
      AB = torch.matmul(A, B)
      # or, even simpler
      AB = A @ B # Python 3.5+
      

      有一些微妙之处。来自PyTorch documentation

      torch.mm 不广播。对于广播矩阵产品, 参见 torch.matmul()。

      例如,您不能将两个一维向量与torch.mm 相乘,也不能将批处理矩阵相乘(等级 3)。为此,您应该使用更通用的torch.matmul。有关torch.matmul 的广播行为的详细列表,请参阅documentation

      对于逐元素乘法,您可以简单地做(如果 A 和 B 具有相同的形状)

      A * B # element-wise matrix multiplication (Hadamard product)
      

      【讨论】:

      • 喜欢单字符@ 运算符。 w @ x 将是我的转到
      【解决方案4】:

      使用torch.mm(a, b)torch.matmul(a, b)
      两者都是一样的。

      >>> torch.mm
      <built-in method mm of type object at 0x11712a870>
      >>> torch.matmul
      <built-in method matmul of type object at 0x11712a870>
      

      还有一个可能很高兴知道的选项。 那是@ 运算符。 @Simon H.

      >>> a = torch.randn(2, 3)
      >>> b = torch.randn(3, 4)
      >>> a@b
      tensor([[ 0.6176, -0.6743,  0.5989, -0.1390],
              [ 0.8699, -0.3445,  1.4122, -0.5826]])
      >>> a.mm(b)
      tensor([[ 0.6176, -0.6743,  0.5989, -0.1390],
              [ 0.8699, -0.3445,  1.4122, -0.5826]])
      >>> a.matmul(b)
      tensor([[ 0.6176, -0.6743,  0.5989, -0.1390],
              [ 0.8699, -0.3445,  1.4122, -0.5826]])    
      

      三者给出相同的结果。

      相关链接:
      Matrix multiplication operator
      PEP 465 -- A dedicated infix operator for matrix multiplication

      【讨论】:

      • torch.mm(a,b)torch.matmul(a,b)a@b 是否等效?我在 @ 运算符上找不到任何文档。
      • 是的,似乎没有任何关于 @ 运算符的文档。但是,文档中有几个符号,其中包括 @,它们给出了矩阵乘法的语义。所以我认为@运算符已经被PyTorch重载了矩阵乘法的意思。
      • 添加了@运算符的链接。
      猜你喜欢
      • 2018-03-02
      • 2019-04-21
      • 2019-04-29
      • 2023-04-08
      • 2018-11-22
      • 1970-01-01
      • 1970-01-01
      • 2019-03-05
      • 1970-01-01
      相关资源
      最近更新 更多