【问题标题】:TypeError: forward() missing 1 required positional argument with Tensorboard PyTorchTypeError: forward() 缺少 1 个必需的位置参数与 Tensorboard PyTorch
【发布时间】:2020-10-08 17:08:21
【问题描述】:

我正在尝试使用以下代码将我的模型写入tensorboard

model = SimpleLSTM(4, HIDDEN_DIM, HIDDEN_LAYERS, 1, BATCH_SIZE, device)
writer = tb.SummaryWriter(log_dir=tb_path)
sample_data = iter(trainloader).next()[0]
writer.add_graph(model, sample_data.to(device))

我收到错误:TypeError: forward() missing 1 required positional argument: 'batch_size'

我的模型如下所示:

class SimpleLSTM(nn.Module):

    def __init__(self, input_dims, hidden_units, hidden_layers, out, batch_size, device):
        super(SimpleLSTM, self).__init__()
        self.input_dims = input_dims
        self.hidden_units = hidden_units
        self.hidden_layers = hidden_layers
        self.batch_size = batch_size
        self.device = device
        self.lstm = nn.LSTM(self.input_dims, self.hidden_units, self.hidden_layers,
                            batch_first=True, bidirectional=False)
        self.output_layer = nn.Linear(self.hidden_units, out)

    def init_hidden(self, batch_size):

        hidden = torch.rand(self.hidden_layers, batch_size, self.hidden_units, device=self.device, dtype=torch.float32)
        cell = torch.rand(self.hidden_layers, batch_size, self.hidden_units, device=self.device, dtype=torch.float32)
        hidden = nn.init.xavier_normal_(hidden)
        cell = nn.init.xavier_normal_(cell)
        return (hidden, cell)

    def forward(self, input, batch_size):
        hidden = self.init_hidden(batch_size)  incomplete batch
        lstm_out, (h_n, c_n) = self.lstm(input, hidden)
        raw_out = self.output_layer(h_n[-1])
        return raw_out

如何将此模型写入 TensorBoard?

【问题讨论】:

    标签: python pytorch lstm tensorboard


    【解决方案1】:

    您的模型有两个参数inputbatch_size,但您只提供一个参数供add_graph 调用您的模型。

    输入(add_graph 的第二个参数)应该是带有 inputbatch_size 的元组:

    writer.add_graph(model, (sample_data.to(device), BATCH_SIZE))
    

    您实际上不需要向 forward 方法提供批量大小,因为您可以从输入中推断出它。由于您的 LSTM 使用batch_first=True,这意味着输入需要具有大小 [batch_size, seq_len, num_features],因此第一个维度的大小是当前的批量大小。

    def forward(self, input):
        batch_size = input.size(0)
        # ...
    

    【讨论】:

      猜你喜欢
      • 1970-01-01
      • 2013-07-06
      • 2019-06-25
      • 2013-10-02
      • 1970-01-01
      • 1970-01-01
      • 1970-01-01
      • 1970-01-01
      相关资源
      最近更新 更多