【问题标题】:Hidden units saturate in a seq2seq model in PyTorchPyTorch 的 seq2seq 模型中的隐藏单元饱和
【发布时间】:2017-10-20 07:55:35
【问题描述】:

我正在尝试在PyTorch 中编写一个非常简单的机器翻译玩具示例。为了简单的问题,我把机器翻译任务变成了这个:

给定一个随机序列 ([4, 8, 9 ...]),预测其元素加 1 ([5, 9, 10, ...]) 的序列。 Id:0, 1, 2 将分别用作pad, bos, eos

在我的机器翻译任务中,我在这个玩具任务中观察到了同样的问题。为了调试,我使用了非常小的数据量n_data = 50,发现模型不会甚至过度拟合这些数据。查看模型,我发现encoder/decoder soon 的隐藏状态变得饱和,即所有 处于隐藏状态的单元变得非常接近1/-1由于tanh

-0.8987  0.9634  0.9993  ...  -0.8930 -0.4822 -0.9960
-0.9673  1.0000 -0.8007  ...   0.9929 -0.9992  0.9990
-0.9457  0.9290 -0.9260  ...  -0.9932  0.9851  0.9980
          ...             ⋱             ...
-0.9995  0.9997 -0.9350  ...  -0.9820 -0.9942 -0.9913
-0.9951  0.9488 -0.8894  ...  -0.9842 -0.9895 -0.9116
-0.9991  0.9769 -0.5871  ...   0.7557  0.9049  0.9881

另外,无论我如何调整学习率,或者将单元切换到 RNN/LSTM/GRU 单元,即使使用50 测试样本,损失值似乎也有一个下限。随着数据的增多,模型似乎根本没有收敛。

step: 0, loss: 2.313938
step: 10, loss: 1.435780
step: 20, loss: 0.779704
step: 30, loss: 0.395590
step: 40, loss: 0.281261
...
step: 480, loss: 0.231419
step: 490, loss: 0.231410

当我使用tensorflow 时,我可以轻松地使用 seq2seq 模型过拟合这样的数据集,并且损失值非常小。

这是我尝试过的:

  1. 手动将嵌入初始化为非常小的数字;
  2. 将渐变剪裁为固定范数,例如 1e-2、2、3、5、10;
  3. 在计算损失时排除填充索引(通过将ignore_index 添加到NLLLoss)。

我尝试过的所有方法都没有解决问题。

我怎样才能摆脱这个?任何帮助将不胜感激。

这里是代码,为了更好的阅读体验,它在gist

#!/usr/bin/env python3
# -*- coding: utf-8 -*-

import numpy as np
import torch
import torch.nn.functional as F
from torch import nn
from torch.autograd import Variable

np.random.seed(0)
torch.manual_seed(0)

_RECURRENT_FN_MAPPING = {
    'rnn': torch.nn.RNN,
    'gru': torch.nn.GRU,
    'lstm': torch.nn.LSTM,
}


def get_recurrent_cell(n_inputs,
                       num_units,
                       num_layers,
                       type_,
                       dropout=0.0,
                       bidirectional=False):
    cls = _RECURRENT_FN_MAPPING.get(type_)

    return cls(
        n_inputs,
        num_units,
        num_layers,
        dropout=dropout,
        bidirectional=bidirectional)


class Recurrent(nn.Module):

    def __init__(self,
                 num_units,
                 num_layers=1,
                 unit_type='gru',
                 bidirectional=False,
                 dropout=0.0,
                 embedding=None,
                 attn_type='general'):
        super(Recurrent, self).__init__()

        num_inputs = embedding.weight.size(1)
        self._num_inputs = num_inputs
        self._num_units = num_units
        self._num_layers = num_layers
        self._unit_type = unit_type
        self._bidirectional = bidirectional
        self._dropout = dropout
        self._embedding = embedding
        self._attn_type = attn_type
        self._cell_fn = get_recurrent_cell(num_inputs, num_units, num_layers,
                                           unit_type, dropout, bidirectional)

    def init_hidden(self, batch_size):
        direction = 1 if not self._bidirectional else 2
        h = Variable(
            torch.zeros(direction * self._num_layers, batch_size,
                        self._num_units))
        if self._unit_type == 'lstm':
            return (h, h.clone())
        else:
            return h

    def forward(self, x, h, len_x):
        # Sort by sequence lengths
        sorted_indices = np.argsort(-len_x).tolist()
        unsorted_indices = np.argsort(sorted_indices).tolist()
        x = x[:, sorted_indices]
        h = h[:, sorted_indices, :]
        len_x = len_x[sorted_indices].tolist()

        embedded = self._embedding(x)
        packed = torch.nn.utils.rnn.pack_padded_sequence(embedded, len_x)

        if self._unit_type == 'lstm':
            o, (h, c) = self._cell_fn(packed, h)
            o, _ = torch.nn.utils.rnn.pad_packed_sequence(o)
            return (o[:, unsorted_indices, :], (h[:, unsorted_indices, :],
                                                c[:, unsorted_indices, :]))
        else:
            o, hh = self._cell_fn(packed, h)
            o, _ = torch.nn.utils.rnn.pad_packed_sequence(o)
            return (o[:, unsorted_indices, :], hh[:, unsorted_indices, :])


class Encoder(Recurrent):
    pass


class Decoder(Recurrent):
    pass


class Seq2Seq(nn.Module):

    def __init__(self, encoder, decoder, num_outputs):
        super(Seq2Seq, self).__init__()
        self._encoder = encoder
        self._decoder = decoder
        self._out = nn.Linear(decoder._num_units, num_outputs)

    def forward(self, x, y, h, len_x, len_y):
        # Encode
        _, h = self._encoder(x, h, len_x)
        # Decode
        o, h = self._decoder(y, h, len_y)
        # Project
        o = self._out(o)

        return F.log_softmax(o)


def load_data(size,
              min_len=5,
              max_len=15,
              min_word=3,
              max_word=100,
              epoch=10,
              batch_size=64,
              pad=0,
              bos=1,
              eos=2):
    src = [
        np.random.randint(min_word, max_word - 1,
                          np.random.randint(min_len, max_len)).tolist()
        for _ in range(size)
    ]
    tgt_in = [[bos] + [xi + 1 for xi in x] for x in src]
    tgt_out = [[xi + 1 for xi in x] + [eos] for x in src]

    def _pad(batch):
        max_len = max(len(x) for x in batch)
        return np.asarray(
            [
                np.pad(
                    x, (0, max_len - len(x)),
                    mode='constant',
                    constant_values=pad) for x in batch
            ],
            dtype=np.int64)

    def _len(batch):
        return np.asarray([len(x) for x in batch], dtype=np.int64)

    for e in range(epoch):
        batch_start = 0

        while batch_start < size:
            batch_end = batch_start + batch_size

            s, ti, to = (src[batch_start:batch_end],
                         tgt_in[batch_start:batch_end],
                         tgt_out[batch_start:batch_end])
            lens, lent = _len(s), _len(ti)

            s, ti, to = _pad(s).T, _pad(ti).T, _pad(to).T

            yield (Variable(torch.LongTensor(s)),
                   Variable(torch.LongTensor(ti)),
                   Variable(torch.LongTensor(to)), lens, lent)

            batch_start += batch_size


def print_sample(x, y, yy):
    x = x.data.numpy().T
    y = y.data.numpy().T
    yy = yy.data.numpy().T

    for u, v, w in zip(x, y, yy):
        print('--------')
        print('S: ', u)
        print('T: ', v)
        print('P: ', w)


n_data = 50
min_len = 5
max_len = 10
vocab_size = 101
n_samples = 5

epoch = 100000
batch_size = 32
lr = 1e-2
clip = 3

emb_size = 50
hidden_size = 50
num_layers = 1
max_length = 15

src_embed = torch.nn.Embedding(vocab_size, emb_size)
tgt_embed = torch.nn.Embedding(vocab_size, emb_size)

eps = 1e-3
src_embed.weight.data.uniform_(-eps, eps)
tgt_embed.weight.data.uniform_(-eps, eps)

enc = Encoder(hidden_size, num_layers, embedding=src_embed)
dec = Decoder(hidden_size, num_layers, embedding=tgt_embed)
net = Seq2Seq(enc, dec, vocab_size)

optimizer = torch.optim.Adam(net.parameters(), lr=lr)
criterion = torch.nn.NLLLoss()

loader = load_data(
    n_data,
    min_len=min_len,
    max_len=max_len,
    max_word=vocab_size,
    epoch=epoch,
    batch_size=batch_size)

for i, (x, yin, yout, lenx, leny) in enumerate(loader):
    net.train()
    optimizer.zero_grad()

    logits = net(x, yin, enc.init_hidden(x.size()[1]), lenx, leny)
    loss = criterion(logits.view(-1, vocab_size), yout.contiguous().view(-1))

    loss.backward()

    torch.nn.utils.clip_grad_norm(net.parameters(), clip)
    optimizer.step()

    if i % 10 == 0:
        print('step: {}, loss: {:.6f}'.format(i, loss.data[0]))

    if i % 200 == 0 and i > 0:
        net.eval()
        x, yin, yout, lenx, leny = (x[:, :n_samples], yin[:, :n_samples],
                                    yout[:, :n_samples], lenx[:n_samples],
                                    leny[:n_samples])
        outputs = net(x, yin, enc.init_hidden(x.size()[1]), lenx, leny)
        _, preds = torch.max(outputs, 2)
        print_sample(x, yout, preds)

【问题讨论】:

    标签: lstm recurrent-neural-network pytorch


    【解决方案1】:

    我认为您没有在 tanh 范围内操作,因为您的输入非常大/小,因此会导致 1/-1 值。例如 tanh(5)=0.999 一个 tanh(-5)=-0.999。尝试在 tanh 可以处理的范围内对数据进行标准化,而不会出现极端情况(例如在 +1 到 -1 之间)。如果激活函数是 sigmoid,最好将 0 到 1 之间的数据归一化。

    【讨论】:

    • 我尝试将输入嵌入初始化为非常小的数字(大约 1e-4),但没有任何改变...
    • 我认为您没有使数据更接近 0,而是将数据规范化在 1 到 -1 的范围内。您可以使用 min-max 归一化来做到这一点。
    • 我试过了,但我没有更新问题。很抱歉。
    • 不清楚是否对1到-1范围内的数据进行归一化处理?如果是,那么请更新问题以便人们更好地理解您的问题。谢谢。
    • 你是对的。我已经更新了问题和 gist 上的代码。
    猜你喜欢
    • 2019-11-02
    • 1970-01-01
    • 1970-01-01
    • 2018-08-23
    • 1970-01-01
    • 2013-02-17
    • 2018-07-18
    • 2011-10-19
    • 2016-09-06
    相关资源
    最近更新 更多