【问题标题】:How to resize a batch of images for use with Pytorch Linear Regression?如何调整一批图像的大小以用于 Pytorch 线性回归?
【发布时间】:2020-11-24 00:40:08
【问题描述】:

我正在尝试创建一个用于批量图像的简单线性回归神经网络。输入维度为[BatchSize, 3, Width, Height],第二个维度表示输入图像的RGB通道。

这是我对该回归模型的(失败的)尝试:

class LinearNet(torch.nn.Module):
    def __init__(self, Chn, W,H, nHidden):
        """
        Input: A [BatchSize x Channels x Width x Height] set of images
        Output: A fitted regression model with weights dimension : [Width x Height]
        """
        super(LinearNet, self).__init__()
        self.Chn = Chn
        self.W = W
        self.H = H
        self.hidden = torch.nn.Linear(Chn*W*H,nHidden)   # hidden layer
        self.predict = torch.nn.Linear(nHidden, Chn*W*H)   # output layer

    def forward(self, x):
        torch.reshape(x, (-1,self.Chn*self.W*self.H)) # FAILS here
        # x = x.resize(-1,self.Chn*self.W*self.H)  
        x = F.relu(self.hidden(x))      # activation function for hidden layer
        x = self.predict(x)             # linear output
        x = x.resize(-1,self.Chn, self.W,self.H)
        return x

当发送一批尺寸为[128 x 3 x 96 x 128] 的图像时,在指示的行上失败:

RuntimeError: mat1 and mat2 shapes cannot be multiplied (36864x128 and 36864x256)

如何正确操作矩阵维度以使用这些 pytorch 函数?

更新基于(已删除)评论,我已更新代码以使用torch.reshape

【问题讨论】:

  • 错误是说您的 x 的大小为 128x3x96x128 并且它具有 4718592 元素,但是,您试图将其仅转换为 36864 元素。你错过了*128

标签: python pytorch


【解决方案1】:

解决方案 1 作为一种可能的解决方案,您可以从输入 xx.shape[0] 获取批量大小,并在以后在 reshape 中使用它

import torch

batch = torch.zeros([128, 3, 96, 128], dtype=torch.float32)

# -1 will compute last dimension automatically
batch_upd = torch.reshape(batch, (batch.shape[0], -1))

print(batch_upd.shape)

这段代码的输出是

torch.Size([128, 36864])

解决方案 2 作为另一种可能的解决方案,您可以使用flatten

batch_upd = batch.flatten(start_dim=1)

将产生相同的输出

至于你的下一个问题,考虑通过修改后的forward代码:

def forward(self, x):
    x = x.flatten(1)  # shape: [B, C, W, H] -> [B, C*W*H]
    x = F.relu(self.hidden(x))      # activation function for hidden layer
    x = self.predict(x)             # linear output
    x = x.reshape((-1, self.Chn, self.W, self.H)) # shape: [B, C*W*H] -> [B, C, W, H]
    return x

下面是成功的使用例子:

ln = LinearNet(3, 96, 128, 256)
batch = torch.zeros((128, 3, 96, 128))
res = ln(batch)
print(res.shape)  # torch.Size([128, 3, 96, 128])

【讨论】:

  • flatten 不会导致上下文丢失吗?线性不应该将所有行压缩成一个大向量:它需要同时解决所有 128 个(批量大小)行。 reshape() 似乎是正确的:但它不同于 flatten 不?
  • @StephenBoesch 看起来flatten 是在底层使用reshape 实现的,你可以检查它here。但是,如果您不确定,请随意使用plain reshape ofc。
  • 啊,我错过了 flatten 的参数 (1) 使批量大小保持不变
猜你喜欢
  • 2014-04-19
  • 1970-01-01
  • 2019-04-04
  • 1970-01-01
  • 2018-12-22
  • 2018-04-17
  • 1970-01-01
  • 1970-01-01
  • 2012-06-03
相关资源
最近更新 更多