【发布时间】:2018-07-19 02:37:07
【问题描述】:
我有一个尺寸为(30, 35, 49) 的张量。我想将其重塑为(30, 35, 512),以便能够与另一个形状为(30, 35, 512) 的张量相乘。
我想用(30, 35, 49) 维度对张量进行填充,以使其成为(30, 35, 512) 维度。
如何做到这一点?
【问题讨论】:
标签: pytorch
我有一个尺寸为(30, 35, 49) 的张量。我想将其重塑为(30, 35, 512),以便能够与另一个形状为(30, 35, 512) 的张量相乘。
我想用(30, 35, 49) 维度对张量进行填充,以使其成为(30, 35, 512) 维度。
如何做到这一点?
【问题讨论】:
标签: pytorch
最简单的解决方案是使用您的填充值和目标尺寸分配一个张量,并分配您拥有数据的部分:
target = torch.zeros(30, 35, 512)
source = torch.ones(30, 35, 49)
target[:, :, :49] = source
请注意,不能保证用零填充您的张量,然后将其与另一个张量相乘最终有意义,这取决于您。
【讨论】:
虽然 @nemo 的解决方案工作正常,但有一个 pytorch 内部例程 torch.nn.functional.pad,它具有相同的功能 - 并且具有 torch.ones(*sizes)*pad_value 解决方案没有的几个属性(即其他形式的填充,如反射填充或复制padding ...它还检查一些与渐变相关的属性):
import torch.nn.functional as F
source = torch.rand((5,10))
# now we expand to size (7, 11) by appending a row of 0s at pos 0 and pos 6,
# and a column of 0s at pos 10
result = F.pad(input=source, pad=(0, 1, 1, 1), mode='constant', value=0)
参数的语义是:
input:源张量,pad:长度列表2 * len(source.shape) 的形式(开始最后一个轴,结束最后一个轴,开始第 2 个到最后一个轴,结束第 2 个到最后一个轴,从第 3 个到最后一个轴等),说明应该有多少维度被添加到每个轴的开头和结尾,mode:'constant'、'reflect' 或 'replicate'。默认值:'constant' 用于不同类型的填充value 用于常量填充。【讨论】:
torch.nn.ConstantPad1d 是一个可能更清晰、更适合这个问题的模块,例如
import torch
from torch import nn
x = torch.ones(30, 35, 49)
padded = nn.ConstantPad1d((0, 512 - 49), 0)(x)
【讨论】:
这里的想法是使用 torch.cat 用您想要的张量填充该特定维度。这个例子应该更清楚。
In [1]: import torch
In [2]: a = torch.randn(30, 35, 49)
In [3]: b = torch.randn(30, 35, 512)
In [4]: padder = torch.zeros(30,35,512 - 49)
In [5]: padded_a = torch.cat([a,padder], dim = 2) # Choose your desired dim
In [6]: padded_a.shape
Out[6]: torch.Size([30, 35, 512])
In [7]: target = torch.randn(30,35,512)
In [8]: target = torch.cat([target,padded_a], dim = 2)
In [9]: target.shape
Out[9]: torch.Size([30, 35, 1024])
【讨论】:
只是想说明@ghchoi 给出的答案。因为我在跟踪它时遇到了一点麻烦。
由于内核大小限制,我想将尺寸为 (N,1,28,28) 的标准 mnist 中的图像拟合到 LeNet(早在 1998 年提出),预计输入的形状为 (N,1,32,32)。所以假设我们尝试通过填充来缓解这个问题。
在填充单个图像之前,它的大小为(1,28,28).
因此我们有三个维度。
在 padding 之后,创建一个大小为(1,32,32) 的图像。注意pad=(2,2,2,2,0,0)
这是因为我在第一个 (2,2) 之前和之后在 x 轴上添加了两个零,在 yaxis (2,2) 之后添加了两个零,因此单独留下了通道列 (0,0)。 value 表示填充为 0。
谢谢!
【讨论】: