【问题标题】:Having trouble with input dimensions for Pytorch LSTM with torchtext带有torchtext的Pytorch LSTM的输入尺寸有问题
【发布时间】:2020-08-31 15:09:24
【问题描述】:

问题

我正在尝试使用 LSTM 构建文本分类器网络。我得到的错误是:

RuntimeError: Expected hidden[0] size (4, 600, 256), got (4, 64, 256)

详情

数据是json,长这样:

{"cat": "music", "desc": "I'm in love with the song's intro!", "sent": "h"}

我正在使用torchtext 来加载数据。

from torchtext import data
from torchtext import datasets

TEXT = data.Field(fix_length = 600)
LABEL = data.Field(fix_length = 10)

BATCH_SIZE = 64

fields = {
    'cat': ('c', LABEL),
    'desc': ('d', TEXT),
    'sent': ('s', LABEL),
}

我的 LSTM 是这样的

EMBEDDING_DIM = 64
HIDDEN_DIM = 256
N_LAYERS = 4

MyLSTM(
  (embedding): Embedding(11967, 64)
  (lstm): LSTM(64, 256, num_layers=4, batch_first=True, dropout=0.5)
  (dropout): Dropout(p=0.3, inplace=False)
  (fc): Linear(in_features=256, out_features=8, bias=True)
  (sig): Sigmoid()
)

我最终得到inputslabels 的以下尺寸

batch = list(train_iterator)[0]
inputs, labels = batch
print(inputs.shape) # torch.Size([600, 64])
print(labels.shape) # torch.Size([100, 2, 64])

我初始化的隐藏张量看起来像:

hidden # [torch.Size([4, 64, 256]), torch.Size([4, 64, 256])]

问题

我试图了解每个步骤的尺寸应该是多少。 隐藏维度应该初始化为 (4, 600, 256) 还是 (4, 64, 256)?

【问题讨论】:

    标签: deep-learning pytorch lstm recurrent-neural-network torchtext


    【解决方案1】:

    nn.LSTM - Inputs 的文档解释了维度是什么:

    • h_0 形状 (num_layers * num_directions, batch, hidden_​​size):包含批次中每个元素的初始隐藏状态的张量。如果 LSTM 是双向的,num_directions 应该是 2,否则应该是 1。

    因此,您的隐藏状态的大小应该是 (4, 64, 256),所以您这样做是正确的。另一方面,您没有为输入提供正确的大小。

    • input 形状 (seq_len, batch, input_size):包含输入序列特征的张量。输入也可以是打包的可变长度序列。详情请参阅torch.nn.utils.rnn.pack_padded_sequence()torch.nn.utils.rnn.pack_sequence()

    虽然它说输入的大小需要为 (seq_len, batch, input_size),但您已经在 LSTM 中设置了 batch_first=True,它交换了 batchseq_len。因此,您的输入应该具有大小 (batch_size, seq_len, input_size),但事实并非如此,因为您的输入首先具有 seq_len (600) 和 batch em> second (64),这是 torchtext 中的默认值,因为这是更常见的表示,也符合 LSTM 的默认行为。

    你需要在你的 LSTM 中设置batch_first=False

    或者。如果您更喜欢将 batch 作为一般的第一个维度,torch.data.Field 也有 batch_first 选项。

    【讨论】:

      猜你喜欢
      • 2019-01-25
      • 2018-09-18
      • 2021-06-04
      • 1970-01-01
      • 2022-01-19
      • 2019-03-31
      • 2018-11-06
      • 2020-02-19
      • 2016-08-21
      相关资源
      最近更新 更多