The corpus is small and training time is short, so the model's performance is not great; the following walks through the process.

Corpus link: https://pan.baidu.com/s/1wpP4t_GSyPAD6HTsIoGPZg
Extraction code: jqq8

Data format (each line is an English sentence, then a tab, then the traditional-Chinese translation):
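
The figure is not reproduced here; as an illustration, one line of train.txt looks roughly like the following (the sentence pair is adapted from the dev examples shown further down), and splitting on the tab recovers the two halves exactly as load_data does:

# hypothetical sample line from train.txt (tab-separated)
line = "She put the magazine on the table.\t她把雜誌放在桌上。"
print(line.strip().split("\t"))   # ['She put the magazine on the table.', '她把雜誌放在桌上。']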

PyTorch seq2seq machine translation model (two versions: without attention and with attention)

The following code was run on Google Colab.

Imports:

import os
import sys
import math
from collections import Counter
import numpy as np
import random

import torch
import torch.nn as nn
import torch.nn.functional as F

import nltk
nltk.download('punkt')

1. Data preprocessing

1.1 Load the English and Chinese data

  • English is tokenized with nltk's word tokenizer and lowercased
  • Chinese uses individual characters as the basic unit
def load_data(in_file):
    cn = []
    en = []
    with open(in_file, 'r', encoding='utf-8') as f:
        for line in f:
            line = line.strip().split("\t")

            # English: lowercase, tokenize with nltk, wrap with BOS/EOS
            en.append(["BOS"] + nltk.word_tokenize(line[0].lower()) + ["EOS"])
            # Chinese: split into individual characters, wrap with BOS/EOS
            cn.append(["BOS"] + [c for c in line[1]] + ["EOS"])
    return en, cn

train_file = "nmt/en-cn/train.txt"
dev_file = "nmt/en-cn/dev.txt"
train_en, train_cn = load_data(train_file)
dev_en, dev_cn = load_data(dev_file)

Inspect the returned data:

print(dev_en[:2])
print(dev_cn[:2])

[['BOS', 'she', 'put', 'the', 'magazine', 'on', 'the', 'table', '.', 'EOS'], ['BOS', 'hey', ',', 'what', 'are', 'you', 'doing', 'here', '?', 'EOS']]

[['BOS', '她', '把', '雜', '誌', '放', '在', '桌', '上', '。', 'EOS'], ['BOS', '嘿', ',', '你', '在', '這', '做', '什', '麼', '?', 'EOS']]

1.2 Build the vocabulary

UNK_IDX = 0
PAD_IDX = 1
def build_dict(sentences, max_words=50000):
    word_count = Counter()
    for sentence in sentences:
        for s in sentence:
            word_count[s] += 1
    ls = word_count.most_common(max_words)
    total_words = len(ls) + 2
    word_dict = {w[0]: index + 2 for index, w in enumerate(ls)}
    word_dict["UNK"] = UNK_IDX
    word_dict["PAD"] = PAD_IDX
    return word_dict, total_words      # total_words: vocabulary size including UNK and PAD, at most 50002

en_dict, en_total_words = build_dict(train_en)
cn_dict, cn_total_words = build_dict(train_cn)
inv_en_dict = {v: k for k, v in en_dict.items()}    # English: index -> word
inv_cn_dict = {v: k for k, v in cn_dict.items()}    # Chinese: index -> character
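
As a quick check (my addition, not in the original post): known words map to indices >= 2, while anything out of vocabulary falls back to UNK_IDX, which is why encode below uses dict.get(w, 0).

# illustrative check; the exact indices depend on corpus word frequencies
print(en_dict["BOS"], en_dict["EOS"])          # frequent tokens get indices >= 2
print(en_dict.get("qwertyuiop", UNK_IDX))      # unknown word -> 0 (UNK)
print(inv_en_dict[PAD_IDX])                    # 'PAD'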

1.3 Convert all words to indices

The purpose of sort_by_len=True is to sort the sentences by length, so that sentences within a batch have similar lengths and less padding is wasted.

def encode(en_sentences, cn_sentences, en_dict, cn_dict, sort_by_len=True):

    out_en_sentences = [[en_dict.get(w, 0) for w in sent] for sent in en_sentences]
    out_cn_sentences = [[cn_dict.get(w, 0) for w in sent] for sent in cn_sentences]

    # sort sentences by word lengths
    def len_argsort(seq):
        return sorted(range(len(seq)), key=lambda x: len(seq[x]))

    # sort the Chinese and English sentences in the same order
    if sort_by_len:
        sorted_index = len_argsort(out_en_sentences)
        out_en_sentences = [out_en_sentences[i] for i in sorted_index]
        out_cn_sentences = [out_cn_sentences[i] for i in sorted_index]

    return out_en_sentences, out_cn_sentences

train_en, train_cn = encode(train_en, train_cn, en_dict, cn_dict)
dev_en, dev_cn = encode(dev_en, dev_cn, en_dict, cn_dict)

Inspect the returned data:

print(train_cn[2])
print([inv_cn_dict[i] for i in train_cn[2]])
print([inv_en_dict[i] for i in train_en[2]])

[2, 982, 2028, 8, 4, 3]

['BOS', '祝', '贺', '你', '。', 'EOS']

['BOS', 'congratulations', '!', 'EOS']

1.4 Split the sentences into batches

def get_minibatches(n, minibatch_size, shuffle=True):  # n is the total number of sentences
    idx_list = np.arange(0, n, minibatch_size)   # batch start indices: 0, minibatch_size, 2*minibatch_size, ...
    if shuffle:
        np.random.shuffle(idx_list)
    minibatches = []
    for idx in idx_list:
        minibatches.append(np.arange(idx, min(idx + minibatch_size, n)))
    return minibatches

Check what the function above does:

get_minibatches(100, 15)

[array([60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74]),
 array([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14]),
 array([75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89]),
 array([45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59]),
 array([30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44]),
 array([90, 91, 92, 93, 94, 95, 96, 97, 98, 99]),
 array([15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29])]

def prepare_data(seqs):   # seqs is one minibatch: a nested list of sentence index lists (here batch_size=64)

    lengths = [len(seq) for seq in seqs]
    n_samples = len(seqs)
    max_len = np.max(lengths)  # length of the longest sentence in this batch

    x = np.zeros((n_samples, max_len)).astype('int32')
    x_lengths = np.array(lengths).astype("int32")
    for idx, seq in enumerate(seqs):
        x[idx, :lengths[idx]] = seq
    return x, x_lengths

def gen_examples(en_sentences, cn_sentences, batch_size):
    minibatches = get_minibatches(len(en_sentences), batch_size)
    all_ex = []
    for minibatch in minibatches:
        mb_en_sentences = [en_sentences[t] for t in minibatch]
        mb_cn_sentences = [cn_sentences[t] for t in minibatch]
        mb_x, mb_x_len = prepare_data(mb_en_sentences)
        mb_y, mb_y_len = prepare_data(mb_cn_sentences)
        all_ex.append((mb_x, mb_x_len, mb_y, mb_y_len))
    return all_ex     # each element: (English indices, English lengths, Chinese indices, Chinese lengths) for one batch

batch_size = 64
train_data = gen_examples(train_en, train_cn, batch_size)
dev_data = gen_examples(dev_en, dev_cn, batch_size)
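
Not in the original post, but inspecting one element of train_data is a quick way to confirm what gen_examples returns (the exact max lengths depend on which batch happens to come first):

mb_x, mb_x_len, mb_y, mb_y_len = train_data[0]
print(mb_x.shape, mb_x_len.shape)   # e.g. (64, max_en_len) (64,)
print(mb_y.shape, mb_y_len.shape)   # e.g. (64, max_cn_len) (64,)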

2. Encoder-Decoder model (without attention)

2.1 Define the loss function

# masked cross entropy loss
class LanguageModelCriterion(nn.Module):
    def __init__(self):
        super(LanguageModelCriterion, self).__init__()

    def forward(self, input, target, mask):   # ignore the masked (padded) positions
        # input: (batch_size, seq_len, vocab_size) log-probabilities -> (batch_size * seq_len, vocab_size)
        input = input.contiguous().view(-1, input.size(2))
        # target and mask: (batch_size, seq_len) -> (batch_size * seq_len, 1)
        target = target.contiguous().view(-1, 1)
        mask = mask.contiguous().view(-1, 1)
        output = -input.gather(1, target) * mask
        output = torch.sum(output) / torch.sum(mask)

        return output

2.2 Encoder

The Encoder passes the input tokens through an embedding layer and a GRU; the resulting hidden states serve as the context vector for the decoder.

For background on nn.utils.rnn.pack_padded_sequence and nn.utils.rnn.pad_packed_sequence, see: http://www.mamicode.com/info-detail-2493083.html
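
For a quick feel of what these two functions do, here is a minimal, self-contained sketch on a made-up two-sentence batch (all sizes and values are purely illustrative, not from the corpus):

import torch
import torch.nn as nn

# two sequences padded to length 4; their real lengths are 4 and 2
batch = torch.tensor([[1, 2, 3, 4],
                      [5, 6, 0, 0]]).float().unsqueeze(-1)   # (batch=2, seq=4, feat=1)
lengths = torch.tensor([4, 2])                               # must be in descending order

rnn = nn.GRU(input_size=1, hidden_size=3, batch_first=True)
packed = nn.utils.rnn.pack_padded_sequence(batch, lengths, batch_first=True)
packed_out, hid = rnn(packed)            # hid holds the state at each sequence's true last step
out, out_lengths = nn.utils.rnn.pad_packed_sequence(packed_out, batch_first=True)
print(out.shape, hid.shape, out_lengths) # torch.Size([2, 4, 3]) torch.Size([1, 2, 3]) tensor([4, 2])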

class PlainEncoder(nn.Module):
    def __init__(self, vocab_size, hidden_size, dropout=0.2):       # assume embedding_size = hidden_size
        super(PlainEncoder, self).__init__()
        self.embed = nn.Embedding(vocab_size, hidden_size)
        self.rnn = nn.GRU(hidden_size, hidden_size, batch_first=True)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, lengths):   # lengths are needed to pick out each sentence's last real hidden state as the context vector
        sorted_len, sorted_idx = lengths.sort(0, descending=True)   # sort the sequences in the batch by length, descending
        x_sorted = x[sorted_idx.long()]
        embedded = self.dropout(self.embed(x_sorted))

        # the sentences are padded to the same length (real lengths are shorter); pack_padded_sequence
        # lets the RNN stop at each sentence's true length, so the final state is not computed on padding
        packed_embedded = nn.utils.rnn.pack_padded_sequence(embedded, sorted_len.long().cpu().data.numpy(), batch_first=True)
        packed_out, hid = self.rnn(packed_embedded)
        out, _ = nn.utils.rnn.pad_packed_sequence(packed_out, batch_first=True)    # back to padded form

        _, original_idx = sorted_idx.sort(0, descending=False)                     # restore the original order
        out = out[original_idx.long()].contiguous()
        hid = hid[:, original_idx.long()].contiguous()

        return out, hid[[-1]]   # hid[[-1]] keeps the layer dimension: shape [1, batch_size, hidden_size]

2.3 Decoder

The Decoder predicts the next output token from the tokens translated so far and the context vector.

class PlainDecoder(nn.Module):
    def __init__(self, vocab_size, hidden_size, dropout=0.2):
        super(PlainDecoder, self).__init__()
        self.embed = nn.Embedding(vocab_size, hidden_size)
        self.rnn = nn.GRU(hidden_size, hidden_size, batch_first=True)
        self.fc = nn.Linear(hidden_size, vocab_size)
        self.dropout = nn.Dropout(dropout)

    def forward(self, y, y_lengths, hid):    # same pattern as PlainEncoder.forward, except the initial hidden state is passed in instead of being zero
        sorted_len, sorted_idx = y_lengths.sort(0, descending=True)
        y_sorted = y[sorted_idx.long()]
        hid = hid[:, sorted_idx.long()]

        y_sorted = self.dropout(self.embed(y_sorted))             # [batch_size, y_lengths, embed_size=hidden_size]

        packed_seq = nn.utils.rnn.pack_padded_sequence(y_sorted, sorted_len.long().cpu().data.numpy(), batch_first=True)
        out, hid = self.rnn(packed_seq, hid)
        unpacked, _ = nn.utils.rnn.pad_packed_sequence(out, batch_first=True)

        _, original_idx = sorted_idx.sort(0, descending=False)
        output_seq = unpacked[original_idx.long()].contiguous()   # [batch_size, y_lengths, hidden_size]
        hid = hid[:, original_idx.long()].contiguous()            # [1, batch_size, hidden_size]

        output = F.log_softmax(self.fc(output_seq), -1)           # [batch_size, y_lengths, vocab_size]

        return output, hid

2.4 Build the Seq2Seq model

The Seq2Seq model chains the encoder and decoder together (no attention in this version).

class PlainSeq2Seq(nn.Module):
    def __init__(self, encoder, decoder):
        super(PlainSeq2Seq, self).__init__()
        self.encoder = encoder
        self.decoder = decoder

    def forward(self, x, x_lengths, y, y_lengths):
        encoder_out, hid = self.encoder(x, x_lengths)
        output, hid = self.decoder(y, y_lengths, hid)
        return output, None

    def translate(self, x, x_lengths, y, max_length=10):
        # greedy decoding: feed the previous prediction back in, one step at a time
        encoder_out, hid = self.encoder(x, x_lengths)
        preds = []
        batch_size = x.shape[0]
        for i in range(max_length):
            output, hid = self.decoder(y=y, y_lengths=torch.ones(batch_size).long().to(y.device), hid=hid)
            y = output.max(2)[1].view(batch_size, 1)
            preds.append(y)

        return torch.cat(preds, 1), None
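
The post never calls translate directly; as a sketch of how it could be used once the model has been created and trained in sections 3-5 below, a hypothetical helper (translate_sentence is my own name, not from the original; it reuses en_dict, cn_dict, inv_cn_dict, UNK_IDX, device and model defined elsewhere in the post) might look like this:

def translate_sentence(sentence, max_length=10):
    # preprocess exactly as load_data/encode do for English
    tokens = ["BOS"] + nltk.word_tokenize(sentence.lower()) + ["EOS"]
    x = torch.tensor([[en_dict.get(w, UNK_IDX) for w in tokens]]).long().to(device)
    x_len = torch.tensor([len(tokens)]).long().to(device)
    bos = torch.tensor([[cn_dict["BOS"]]]).long().to(device)   # decoding starts from BOS

    model.eval()
    with torch.no_grad():
        pred, _ = model.translate(x, x_len, bos, max_length)

    # map indices back to characters and cut at EOS
    chars = []
    for idx in pred[0].cpu().numpy():
        ch = inv_cn_dict[int(idx)]
        if ch == "EOS":
            break
        chars.append(ch)
    return "".join(chars)

print(translate_sentence("she put the magazine on the table ."))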

3. Create the model

Define the model, the loss function, and the optimizer.

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
dropout = 0.2
hidden_size = 100
encoder = PlainEncoder(vocab_size=en_total_words, hidden_size=hidden_size, dropout=dropout)
decoder = PlainDecoder(vocab_size=cn_total_words, hidden_size=hidden_size, dropout=dropout)
model = PlainSeq2Seq(encoder, decoder)
model = model.to(device)
loss_fn = LanguageModelCriterion().to(device)
optimizer = torch.optim.Adam(model.parameters())
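
As an optional sanity check (my addition, not in the original post), the number of trainable parameters can be printed before training:

num_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"trainable parameters: {num_params:,}")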

4. Evaluate the model

def evaluate(model, data):
    model.eval()
    total_num_words = total_loss = 0.
    with torch.no_grad():
        for it, (mb_x, mb_x_len, mb_y, mb_y_len) in enumerate(data):
            mb_x = torch.from_numpy(mb_x).to(device).long()
            mb_x_len = torch.from_numpy(mb_x_len).to(device).long()
            mb_input = torch.from_numpy(mb_y[:, :-1]).to(device).long()    # decoder input: all but the last token
            mb_output = torch.from_numpy(mb_y[:, 1:]).to(device).long()    # target: all but the first token (shifted by one)
            mb_y_len = torch.from_numpy(mb_y_len - 1).to(device).long()
            mb_y_len[mb_y_len <= 0] = 1

            mb_pred, attn = model(mb_x, mb_x_len, mb_input, mb_y_len)

            # mask out the padded positions beyond each sentence's real length
            mb_out_mask = torch.arange(mb_y_len.max().item(), device=device)[None, :] < mb_y_len[:, None]
            mb_out_mask = mb_out_mask.float()

            loss = loss_fn(mb_pred, mb_output, mb_out_mask)

            num_words = torch.sum(mb_y_len).item()
            total_loss += loss.item() * num_words
            total_num_words += num_words
    print("Evaluation loss", total_loss / total_num_words)
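
The printed loss is the average negative log-likelihood per non-padded target token, so exp(loss) is the model's perplexity; a tiny helper (my addition, using the math module imported at the top) makes that explicit:

# per-token cross entropy -> perplexity; a sketch, evaluate() would need to return the loss to use this directly
def perplexity(avg_loss):
    return math.exp(avg_loss)

print(perplexity(3.13))   # an evaluation loss of about 3.13 corresponds to a perplexity of roughly 22.9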

5. Train the model

def train(model, data, num_epochs=20):
    for epoch in range(num_epochs):
        model.train()
        total_num_words = total_loss = 0.
        for it, (mb_x, mb_x_len, mb_y, mb_y_len) in enumerate(data):
            mb_x = torch.from_numpy(mb_x).to(device).long()
            mb_x_len = torch.from_numpy(mb_x_len).to(device).long()
            mb_input = torch.from_numpy(mb_y[:, :-1]).to(device).long()
            mb_output = torch.from_numpy(mb_y[:, 1:]).to(device).long()
            mb_y_len = torch.from_numpy(mb_y_len - 1).to(device).long()
            mb_y_len[mb_y_len <= 0] = 1

            mb_pred, attn = model(mb_x, mb_x_len, mb_input, mb_y_len)

            mb_out_mask = torch.arange(mb_y_len.max().item(), device=device)[None, :] < mb_y_len[:, None]
            mb_out_mask = mb_out_mask.float()

            loss = loss_fn(mb_pred, mb_output, mb_out_mask)

            num_words = torch.sum(mb_y_len).item()
            total_loss += loss.item() * num_words
            total_num_words += num_words

            # update the model: backprop with gradient clipping
            optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 5.)
            optimizer.step()

            if it % 100 == 0:
                print("Epoch", epoch, "iteration", it, "loss", loss.item())

        print("Epoch", epoch, "Training loss", total_loss / total_num_words)
        if epoch % 5 == 0:
            evaluate(model, dev_data)

Train for 100 epochs:

train(model, train_data, num_epochs=100)

Training results (the training loss keeps decreasing; note that the evaluation loss bottoms out around epochs 40-45 and then starts rising, a sign of overfitting on this small corpus):

  1 Epoch 0 iteration 0 loss 8.084440231323242
  2 Epoch 0 iteration 100 loss 4.8448944091796875
  3 Epoch 0 iteration 200 loss 4.879772663116455
  4 Epoch 0 Training loss 5.477221919210141
  5 Evaluation loss 4.821030395389826
  6 Epoch 1 iteration 0 loss 4.69868278503418
  7 Epoch 1 iteration 100 loss 4.085171699523926
  8 Epoch 1 iteration 200 loss 4.312857151031494
  9 Epoch 1 Training loss 4.579521701350524
 10 Epoch 2 iteration 0 loss 4.193971633911133
 11 Epoch 2 iteration 100 loss 3.678673267364502
 12 Epoch 2 iteration 200 loss 4.019515514373779
 13 Epoch 2 Training loss 4.186071368925457
 14 Epoch 3 iteration 0 loss 3.8352835178375244
 15 Epoch 3 iteration 100 loss 3.3954527378082275
 16 Epoch 3 iteration 200 loss 3.774580240249634
 17 Epoch 3 Training loss 3.9222166424267986
 18 Epoch 4 iteration 0 loss 3.585063934326172
 19 Epoch 4 iteration 100 loss 3.215750217437744
 20 Epoch 4 iteration 200 loss 3.626997232437134
 21 Epoch 4 Training loss 3.722608096150466
 22 Epoch 5 iteration 0 loss 3.411375045776367
 23 Epoch 5 iteration 100 loss 3.0424859523773193
 24 Epoch 5 iteration 200 loss 3.492255926132202
 25 Epoch 5 Training loss 3.5699179079587195
 26 Evaluation loss 3.655821240952787
 27 Epoch 6 iteration 0 loss 3.273927927017212
 28 Epoch 6 iteration 100 loss 2.897022247314453
 29 Epoch 6 iteration 200 loss 3.355715036392212
 30 Epoch 6 Training loss 3.4411540739967426
 31 Epoch 7 iteration 0 loss 3.16508412361145
 32 Epoch 7 iteration 100 loss 2.7818763256073
 33 Epoch 7 iteration 200 loss 3.241000175476074
 34 Epoch 7 Training loss 3.330995073153501
 35 Epoch 8 iteration 0 loss 3.081458806991577
 36 Epoch 8 iteration 100 loss 2.692844867706299
 37 Epoch 8 iteration 200 loss 3.159105062484741
 38 Epoch 8 Training loss 3.237538761219645
 39 Epoch 9 iteration 0 loss 2.983361005783081
 40 Epoch 9 iteration 100 loss 2.5852301120758057
 41 Epoch 9 iteration 200 loss 3.076793670654297
 42 Epoch 9 Training loss 3.1542968146839754
 43 Epoch 10 iteration 0 loss 2.88155198097229
 44 Epoch 10 iteration 100 loss 2.504387617111206
 45 Epoch 10 iteration 200 loss 2.9708898067474365
 46 Epoch 10 Training loss 3.0766581801071924
 47 Evaluation loss 3.3804360915245204
 48 Epoch 11 iteration 0 loss 2.805739164352417
 49 Epoch 11 iteration 100 loss 2.417832612991333
 50 Epoch 11 iteration 200 loss 2.9001076221466064
 51 Epoch 11 Training loss 3.0072335865815747
 52 Epoch 12 iteration 0 loss 2.7389864921569824
 53 Epoch 12 iteration 100 loss 2.352132558822632
 54 Epoch 12 iteration 200 loss 2.864527702331543
 55 Epoch 12 Training loss 2.945309993148362
 56 Epoch 13 iteration 0 loss 2.6841001510620117
 57 Epoch 13 iteration 100 loss 2.2722346782684326
 58 Epoch 13 iteration 200 loss 2.8002915382385254
 59 Epoch 13 Training loss 2.8879525671218156
 60 Epoch 14 iteration 0 loss 2.641491651535034
 61 Epoch 14 iteration 100 loss 2.237807273864746
 62 Epoch 14 iteration 200 loss 2.7538034915924072
 63 Epoch 14 Training loss 2.833802188663957
 64 Epoch 15 iteration 0 loss 2.5613601207733154
 65 Epoch 15 iteration 100 loss 2.149299144744873
 66 Epoch 15 iteration 200 loss 2.671037435531616
 67 Epoch 15 Training loss 2.7850014679518598
 68 Evaluation loss 3.2569677577366516
 69 Epoch 16 iteration 0 loss 2.5330140590667725
 70 Epoch 16 iteration 100 loss 2.0988974571228027
 71 Epoch 16 iteration 200 loss 2.611022472381592
 72 Epoch 16 Training loss 2.7354116963192716
 73 Epoch 17 iteration 0 loss 2.485084295272827
 74 Epoch 17 iteration 100 loss 2.0532665252685547
 75 Epoch 17 iteration 200 loss 2.604226589202881
 76 Epoch 17 Training loss 2.6934350694497957
 77 Epoch 18 iteration 0 loss 2.4521820545196533
 78 Epoch 18 iteration 100 loss 2.0395381450653076
 79 Epoch 18 iteration 200 loss 2.5578808784484863
 80 Epoch 18 Training loss 2.651303096776386
 81 Epoch 19 iteration 0 loss 2.390338182449341
 82 Epoch 19 iteration 100 loss 1.9780246019363403
 83 Epoch 19 iteration 200 loss 2.5150232315063477
 84 Epoch 19 Training loss 2.611681331448251
 85 Epoch 20 iteration 0 loss 2.352649211883545
 86 Epoch 20 iteration 100 loss 1.9426053762435913
 87 Epoch 20 iteration 200 loss 2.4782586097717285
 88 Epoch 20 Training loss 2.5747013451744616
 89 Evaluation loss 3.194680030596711
 90 Epoch 21 iteration 0 loss 2.3205008506774902
 91 Epoch 21 iteration 100 loss 1.9143742322921753
 92 Epoch 21 iteration 200 loss 2.4607479572296143
 93 Epoch 21 Training loss 2.5404243457594116
 94 Epoch 22 iteration 0 loss 2.3100969791412354
 95 Epoch 22 iteration 100 loss 1.912932276725769
 96 Epoch 22 iteration 200 loss 2.4103682041168213
 97 Epoch 22 Training loss 2.507626390779296
 98 Epoch 23 iteration 0 loss 2.228956699371338
 99 Epoch 23 iteration 100 loss 1.8543353080749512
100 Epoch 23 iteration 200 loss 2.3663489818573
101 Epoch 23 Training loss 2.475231424650597
102 Epoch 24 iteration 0 loss 2.199277639389038
103 Epoch 24 iteration 100 loss 1.8272788524627686
104 Epoch 24 iteration 200 loss 2.3518714904785156
105 Epoch 24 Training loss 2.4439996520576863
106 Epoch 25 iteration 0 loss 2.198460817337036
107 Epoch 25 iteration 100 loss 1.7921738624572754
108 Epoch 25 iteration 200 loss 2.3299384117126465
109 Epoch 25 Training loss 2.416539151404694
110 Evaluation loss 3.1583419660450347
111 Epoch 26 iteration 0 loss 2.1647706031799316
112 Epoch 26 iteration 100 loss 1.725657343864441
113 Epoch 26 iteration 200 loss 2.268852710723877
114 Epoch 26 Training loss 2.3919890312051444
115 Epoch 27 iteration 0 loss 2.1400880813598633
116 Epoch 27 iteration 100 loss 1.7474910020828247
117 Epoch 27 iteration 200 loss 2.256742000579834
118 Epoch 27 Training loss 2.3595162004913086
119 Epoch 28 iteration 0 loss 2.0979115962982178
120 Epoch 28 iteration 100 loss 1.7000322341918945
121 Epoch 28 iteration 200 loss 2.2546005249023438
122 Epoch 28 Training loss 2.3335356415568618
123 Epoch 29 iteration 0 loss 2.1031572818756104
124 Epoch 29 iteration 100 loss 1.6599613428115845
125 Epoch 29 iteration 200 loss 2.2020833492279053
126 Epoch 29 Training loss 2.311978717884133
127 Epoch 30 iteration 0 loss 2.041980028152466
128 Epoch 30 iteration 100 loss 1.6663353443145752
129 Epoch 30 iteration 200 loss 2.1463098526000977
130 Epoch 30 Training loss 2.2902015222655807
131 Evaluation loss 3.133273747140961
132 Epoch 31 iteration 0 loss 2.0045719146728516
133 Epoch 31 iteration 100 loss 1.6515719890594482
134 Epoch 31 iteration 200 loss 2.1130664348602295
135 Epoch 31 Training loss 2.2633183437027657
136 Epoch 32 iteration 0 loss 1.9948643445968628
137 Epoch 32 iteration 100 loss 1.6262538433074951
138 Epoch 32 iteration 200 loss 2.1329450607299805
139 Epoch 32 Training loss 2.242057023454951
140 Epoch 33 iteration 0 loss 1.9623773097991943
141 Epoch 33 iteration 100 loss 1.6022558212280273
142 Epoch 33 iteration 200 loss 2.092766523361206
143 Epoch 33 Training loss 2.219300144243463
144 Epoch 34 iteration 0 loss 1.929176688194275
145 Epoch 34 iteration 100 loss 1.57985258102417
146 Epoch 34 iteration 200 loss 2.067972183227539
147 Epoch 34 Training loss 2.199957146669663
148 Epoch 35 iteration 0 loss 1.9449653625488281
149 Epoch 35 iteration 100 loss 1.5760831832885742
150 Epoch 35 iteration 200 loss 2.056731939315796
151 Epoch 35 Training loss 2.1790822226814464
152 Evaluation loss 3.13363336627263
153 Epoch 36 iteration 0 loss 1.8961074352264404
154 Epoch 36 iteration 100 loss 1.5195672512054443
155 Epoch 36 iteration 200 loss 2.0268213748931885
156 Epoch 36 Training loss 2.160204240618562
157 Epoch 37 iteration 0 loss 1.9172203540802002
158 Epoch 37 iteration 100 loss 1.495902180671692
159 Epoch 37 iteration 200 loss 1.9827772378921509
160 Epoch 37 Training loss 2.139063811380212
161 Epoch 38 iteration 0 loss 1.8988227844238281
162 Epoch 38 iteration 100 loss 1.5224453210830688
163 Epoch 38 iteration 200 loss 1.972291111946106
164 Epoch 38 Training loss 2.1211086652629887
165 Epoch 39 iteration 0 loss 1.8728121519088745
166 Epoch 39 iteration 100 loss 1.4476994276046753
167 Epoch 39 iteration 200 loss 1.9898269176483154
168 Epoch 39 Training loss 2.1024907934743258
169 Epoch 40 iteration 0 loss 1.8664008378982544
170 Epoch 40 iteration 100 loss 1.4997611045837402
171 Epoch 40 iteration 200 loss 1.9541966915130615
172 Epoch 40 Training loss 2.086313187411815
173 Evaluation loss 3.1282314096494708
174 Epoch 41 iteration 0 loss 1.865237832069397
175 Epoch 41 iteration 100 loss 1.4755399227142334
176 Epoch 41 iteration 200 loss 1.9337103366851807
177 Epoch 41 Training loss 2.068258631932244
178 Epoch 42 iteration 0 loss 1.790804147720337
179 Epoch 42 iteration 100 loss 1.4380069971084595
180 Epoch 42 iteration 200 loss 1.9523491859436035
181 Epoch 42 Training loss 2.0498001934027874
182 Epoch 43 iteration 0 loss 1.7979768514633179
183 Epoch 43 iteration 100 loss 1.436006784439087
184 Epoch 43 iteration 200 loss 1.9101322889328003
185 Epoch 43 Training loss 2.0354298580230195
186 Epoch 44 iteration 0 loss 1.7717180252075195
187 Epoch 44 iteration 100 loss 1.412601351737976
188 Epoch 44 iteration 200 loss 1.8883790969848633
189 Epoch 44 Training loss 2.0182710578663032
190 Epoch 45 iteration 0 loss 1.7614871263504028
191 Epoch 45 iteration 100 loss 1.3429900407791138
192 Epoch 45 iteration 200 loss 1.862486720085144
193 Epoch 45 Training loss 2.0034489605129595
194 Evaluation loss 3.13050353642062
195 Epoch 46 iteration 0 loss 1.753187656402588
196 Epoch 46 iteration 100 loss 1.3810824155807495
197 Epoch 46 iteration 200 loss 1.8526273965835571
198 Epoch 46 Training loss 1.9899710891643612
199 Epoch 47 iteration 0 loss 1.7567869424819946
200 Epoch 47 iteration 100 loss 1.3430988788604736
201 Epoch 47 iteration 200 loss 1.8135911226272583
202 Epoch 47 Training loss 1.9723690433387957
203 Epoch 48 iteration 0 loss 1.7263280153274536
204 Epoch 48 iteration 100 loss 1.3430798053741455
205 Epoch 48 iteration 200 loss 1.8229252099990845
206 Epoch 48 Training loss 1.9580909331705005
207 Epoch 49 iteration 0 loss 1.731834888458252
208 Epoch 49 iteration 100 loss 1.325390100479126
209 Epoch 49 iteration 200 loss 1.8075029850006104
210 Epoch 49 Training loss 1.9418853706725143
211 Epoch 50 iteration 0 loss 1.7218893766403198
212 Epoch 50 iteration 100 loss 1.2710607051849365
213 Epoch 50 iteration 200 loss 1.8196479082107544
214 Epoch 50 Training loss 1.9300463292027463
215 Evaluation loss 3.1402900424368902
216 Epoch 51 iteration 0 loss 1.701721429824829
217 Epoch 51 iteration 100 loss 1.2720820903778076
218 Epoch 51 iteration 200 loss 1.7759710550308228
219 Epoch 51 Training loss 1.9192517232508806
220 Epoch 52 iteration 0 loss 1.7286512851715088
221 Epoch 52 iteration 100 loss 1.2737478017807007
222 Epoch 52 iteration 200 loss 1.7545547485351562
223 Epoch 52 Training loss 1.906238278183267
224 Epoch 53 iteration 0 loss 1.6672327518463135
225 Epoch 53 iteration 100 loss 1.3138436079025269
226 Epoch 53 iteration 200 loss 1.8045201301574707
227 Epoch 53 Training loss 1.8922825534741075
228 Epoch 54 iteration 0 loss 1.617557168006897
229 Epoch 54 iteration 100 loss 1.22885262966156
230 Epoch 54 iteration 200 loss 1.7750707864761353
231 Epoch 54 Training loss 1.8807705430479014
232 Epoch 55 iteration 0 loss 1.66348135471344
233 Epoch 55 iteration 100 loss 1.2331219911575317
234 Epoch 55 iteration 200 loss 1.7303975820541382
235 Epoch 55 Training loss 1.867195544079556
236 Evaluation loss 3.145431456349013
237 Epoch 56 iteration 0 loss 1.6259342432022095
238 Epoch 56 iteration 100 loss 1.2141388654708862
239 Epoch 56 iteration 200 loss 1.6984847784042358
240 Epoch 56 Training loss 1.8548133653506713
241 Epoch 57 iteration 0 loss 1.605487585067749
242 Epoch 57 iteration 100 loss 1.1920335292816162
243 Epoch 57 iteration 200 loss 1.7253336906433105
244 Epoch 57 Training loss 1.8387836396466541
245 Epoch 58 iteration 0 loss 1.600136160850525
246 Epoch 58 iteration 100 loss 1.2192472219467163
247 Epoch 58 iteration 200 loss 1.6888371706008911
248 Epoch 58 Training loss 1.83046734055076
249 Epoch 59 iteration 0 loss 1.6042535305023193
250 Epoch 59 iteration 100 loss 1.2362377643585205
251 Epoch 59 iteration 200 loss 1.6654771566390991
252 Epoch 59 Training loss 1.8226244935892273
253 Epoch 60 iteration 0 loss 1.5602766275405884
254 Epoch 60 iteration 100 loss 1.201045036315918
255 Epoch 60 iteration 200 loss 1.6702684164047241
256 Epoch 60 Training loss 1.8102721190615219
257 Evaluation loss 3.154303393916162
258 Epoch 61 iteration 0 loss 1.5679781436920166
259 Epoch 61 iteration 100 loss 1.2105367183685303
260 Epoch 61 iteration 200 loss 1.6650742292404175
261 Epoch 61 Training loss 1.7970227477404426
262 Epoch 62 iteration 0 loss 1.5734565258026123
263 Epoch 62 iteration 100 loss 1.1602052450180054
264 Epoch 62 iteration 200 loss 1.583187222480774
265 Epoch 62 Training loss 1.787027303402099
266 Epoch 63 iteration 0 loss 1.563283920288086
267 Epoch 63 iteration 100 loss 1.1829460859298706
268 Epoch 63 iteration 200 loss 1.6458944082260132
269 Epoch 63 Training loss 1.7742324239103342
270 Epoch 64 iteration 0 loss 1.5429617166519165
271 Epoch 64 iteration 100 loss 1.1225509643554688
272 Epoch 64 iteration 200 loss 1.6353931427001953
273 Epoch 64 Training loss 1.7665018986396424
274 Epoch 65 iteration 0 loss 1.5284583568572998
275 Epoch 65 iteration 100 loss 1.1426113843917847
276 Epoch 65 iteration 200 loss 1.6138485670089722
277 Epoch 65 Training loss 1.7557591437816458
278 Evaluation loss 3.166533922994568
279 Epoch 66 iteration 0 loss 1.5184751749038696
280 Epoch 66 iteration 100 loss 1.127056360244751
281 Epoch 66 iteration 200 loss 1.611910343170166
282 Epoch 66 Training loss 1.7446940747065838
283 Epoch 67 iteration 0 loss 1.4880752563476562
284 Epoch 67 iteration 100 loss 1.1075133085250854
285 Epoch 67 iteration 200 loss 1.6138321161270142
286 Epoch 67 Training loss 1.7374662356132202
287 Epoch 68 iteration 0 loss 1.5260978937149048
288 Epoch 68 iteration 100 loss 1.12235689163208
289 Epoch 68 iteration 200 loss 1.6129950284957886
290 Epoch 68 Training loss 1.7253250324901928
291 Epoch 69 iteration 0 loss 1.5172449350357056
292 Epoch 69 iteration 100 loss 1.1174883842468262
293 Epoch 69 iteration 200 loss 1.551174283027649
294 Epoch 69 Training loss 1.7166664929363027
295 Epoch 70 iteration 0 loss 1.5006300210952759
296 Epoch 70 iteration 100 loss 1.0905342102050781
297 Epoch 70 iteration 200 loss 1.5446460247039795
298 Epoch 70 Training loss 1.70989819337649
299 Evaluation loss 3.1750113054724385
300 Epoch 71 iteration 0 loss 1.4726097583770752
301 Epoch 71 iteration 100 loss 1.086822509765625
302 Epoch 71 iteration 200 loss 1.5575647354125977
303 Epoch 71 Training loss 1.697000935158525
304 Epoch 72 iteration 0 loss 1.449334979057312
305 Epoch 72 iteration 100 loss 1.0667144060134888
306 Epoch 72 iteration 200 loss 1.530726671218872
307 Epoch 72 Training loss 1.6881878283419123
308 Epoch 73 iteration 0 loss 1.4603246450424194
309 Epoch 73 iteration 100 loss 1.0751914978027344
310 Epoch 73 iteration 200 loss 1.5088605880737305
311 Epoch 73 Training loss 1.6805761044806562
312 Epoch 74 iteration 0 loss 1.4748084545135498
313 Epoch 74 iteration 100 loss 1.0556395053863525
314 Epoch 74 iteration 200 loss 1.5206905603408813
315 Epoch 74 Training loss 1.6673887956853506
316 Epoch 75 iteration 0 loss 1.454646348953247
317 Epoch 75 iteration 100 loss 1.0396276712417603
318 Epoch 75 iteration 200 loss 1.518398404121399
319 Epoch 75 Training loss 1.6633919350661184
320 Evaluation loss 3.189181657332237
321 Epoch 76 iteration 0 loss 1.4616646766662598
322 Epoch 76 iteration 100 loss 0.9838554859161377
323 Epoch 76 iteration 200 loss 1.4613702297210693
324 Epoch 76 Training loss 1.6526747506920867
325 Epoch 77 iteration 0 loss 1.4646761417388916
326 Epoch 77 iteration 100 loss 1.0383753776550293
327 Epoch 77 iteration 200 loss 1.5081768035888672
328 Epoch 77 Training loss 1.6462943129725018
329 Epoch 78 iteration 0 loss 1.4008097648620605
330 Epoch 78 iteration 100 loss 1.0147686004638672
331 Epoch 78 iteration 200 loss 1.5017434358596802
332 Epoch 78 Training loss 1.6352284007247493
333 Epoch 79 iteration 0 loss 1.4189144372940063
334 Epoch 79 iteration 100 loss 1.0126101970672607
335 Epoch 79 iteration 200 loss 1.4195480346679688
336 Epoch 79 Training loss 1.628015456811747
337 Epoch 80 iteration 0 loss 1.4199804067611694
338 Epoch 80 iteration 100 loss 1.0256879329681396
339 Epoch 80 iteration 200 loss 1.4564563035964966
340 Epoch 80 Training loss 1.6227562783981957
341 Evaluation loss 3.2074876046135703
342 Epoch 81 iteration 0 loss 1.431972622871399
343 Epoch 81 iteration 100 loss 1.0110960006713867
344 Epoch 81 iteration 200 loss 1.4414775371551514
345 Epoch 81 Training loss 1.6157781071711008
346 Epoch 82 iteration 0 loss 1.4158073663711548
347 Epoch 82 iteration 100 loss 0.9702512621879578
348 Epoch 82 iteration 200 loss 1.4209394454956055
349 Epoch 82 Training loss 1.605166310639776
350 Epoch 83 iteration 0 loss 1.3871146440505981
351 Epoch 83 iteration 100 loss 1.0183656215667725
352 Epoch 83 iteration 200 loss 1.4292359352111816
353 Epoch 83 Training loss 1.5961119023327037
354 Epoch 84 iteration 0 loss 1.3919366598129272
355 Epoch 84 iteration 100 loss 0.9692129492759705
356 Epoch 84 iteration 200 loss 1.4092985391616821
357 Epoch 84 Training loss 1.5897755956223851
358 Epoch 85 iteration 0 loss 1.355398416519165
359 Epoch 85 iteration 100 loss 0.9916797280311584
360 Epoch 85 iteration 200 loss 1.423561453819275
361 Epoch 85 Training loss 1.5878568289810793
362 Evaluation loss 3.2138472480503295
363 Epoch 86 iteration 0 loss 1.351928472518921
364 Epoch 86 iteration 100 loss 0.9997824430465698
365 Epoch 86 iteration 200 loss 1.4049323797225952
366 Epoch 86 Training loss 1.5719682346027806
367 Epoch 87 iteration 0 loss 1.3508714437484741
368 Epoch 87 iteration 100 loss 0.9411044716835022
369 Epoch 87 iteration 200 loss 1.4019731283187866
370 Epoch 87 Training loss 1.5641802139809575
371 Epoch 88 iteration 0 loss 1.347946047782898
372 Epoch 88 iteration 100 loss 0.9493017792701721
373 Epoch 88 iteration 200 loss 1.3770906925201416
374 Epoch 88 Training loss 1.5587840858982533
375 Epoch 89 iteration 0 loss 1.320084571838379
376 Epoch 89 iteration 100 loss 0.9223963022232056
377 Epoch 89 iteration 200 loss 1.4065088033676147
378 Epoch 89 Training loss 1.5548267858027334
379 Epoch 90 iteration 0 loss 1.3534889221191406
380 Epoch 90 iteration 100 loss 0.9281108975410461
381 Epoch 90 iteration 200 loss 1.3821330070495605
382 Epoch 90 Training loss 1.5474867314671616
383 Evaluation loss 3.2276618163204667
384 Epoch 91 iteration 0 loss 1.3667511940002441
385 Epoch 91 iteration 100 loss 0.8797598481178284
386 Epoch 91 iteration 200 loss 1.3776274919509888
387 Epoch 91 Training loss 1.536482189982952
388 Epoch 92 iteration 0 loss 1.3355433940887451
389 Epoch 92 iteration 100 loss 0.9130176901817322
390 Epoch 92 iteration 200 loss 1.3042923212051392
391 Epoch 92 Training loss 1.5308507835779057
392 Epoch 93 iteration 0 loss 1.2953367233276367
393 Epoch 93 iteration 100 loss 0.9194003939628601
394 Epoch 93 iteration 200 loss 1.3469970226287842
395 Epoch 93 Training loss 1.519625581403501
396 Epoch 94 iteration 0 loss 1.322600245475769
397 Epoch 94 iteration 100 loss 0.9003701210021973
398 Epoch 94 iteration 200 loss 1.3512846231460571
399 Epoch 94 Training loss 1.5193673748787049
400 Epoch 95 iteration 0 loss 1.2789180278778076
401 Epoch 95 iteration 100 loss 0.9352515339851379
402 Epoch 95 iteration 200 loss 1.3609877824783325
403 Epoch 95 Training loss 1.5135782739054082
404 Evaluation loss 3.2474015759319284
405 Epoch 96 iteration 0 loss 1.3051612377166748
406 Epoch 96 iteration 100 loss 0.8885603547096252
407 Epoch 96 iteration 200 loss 1.3272497653961182
408 Epoch 96 Training loss 1.5079536183100883
409 Epoch 97 iteration 0 loss 1.2671339511871338
410 Epoch 97 iteration 100 loss 0.8706735968589783
411 Epoch 97 iteration 200 loss 1.305412769317627
412 Epoch 97 Training loss 1.4974833326540824
413 Epoch 98 iteration 0 loss 1.308292269706726
414 Epoch 98 iteration 100 loss 0.9079441428184509
415 Epoch 98 iteration 200 loss 1.2940715551376343
416 Epoch 98 Training loss 1.4928753682563118
417 Epoch 99 iteration 0 loss 1.276250958442688
418 Epoch 99 iteration 100 loss 0.890657901763916
419 Epoch 99 iteration 200 loss 1.3286609649658203
420 Epoch 99 Training loss 1.4852960116094391
