【发布时间】:2021-11-20 11:06:56
【问题描述】:
训练函数
def train(model, iterator, optimizer, criterion, clip):
    """Run one training epoch and return the mean per-batch loss.

    Args:
        model: Module called as ``model(batch.text)``; put into train mode.
        iterator: Iterable of batches exposing ``.text`` and ``.labels``.
            It does not need to support ``len()`` (a generator works).
        optimizer: Optimizer over ``model``'s parameters.
        criterion: Loss called as ``criterion(output, labels.unsqueeze(1))``.
        clip: Maximum total gradient norm passed to ``clip_grad_norm_``.

    Returns:
        The sum of ``loss.item()`` over the epoch divided by the number of
        batches actually processed.

    Raises:
        ValueError: If ``iterator`` yields no batches.
    """
    model.train()
    epoch_loss = 0.0
    n_batches = 0
    for batch in iterator:
        optimizer.zero_grad()
        output = model(batch.text)
        # Targets get a trailing dimension. Presumably labels are (batch,)
        # and the model emits (batch, 1) -- TODO confirm against the model
        # and the criterion used by the caller.
        loss = criterion(output, torch.unsqueeze(batch.labels, 1))
        loss.backward()
        # Rescale the gradients so their combined norm is at most ``clip``.
        torch.nn.utils.clip_grad_norm_(model.parameters(), clip)
        optimizer.step()
        epoch_loss += loss.item()
        n_batches += 1
    # Count batches ourselves instead of len(iterator): this also works for
    # iterators without __len__ and gives a clear error on empty input.
    if n_batches == 0:
        raise ValueError("iterator yielded no batches")
    return epoch_loss / n_batches
main_script
def _coerce_numeric_values(section):
    """Convert every value of the config dict ``section`` to a number, in place.

    Values below 1.0 stay floats (e.g. dropout rates). Values >= 1.0 become
    ints only when they are whole numbers (e.g. epoch counts, batch sizes).
    Fractional values such as a clip of 1.5 stay floats instead of being
    silently truncated.
    """
    for key, value in section.items():
        value = float(value)
        if value >= 1.0 and value.is_integer():
            section[key] = int(value)
        else:
            section[key] = value


def main(
    train_file,
    test_file,
    config_file,
    checkpoint_path,
    best_model_path
):
    """Load the config, build the data pipeline and model, and run training.

    Args:
        train_file: Training data path, forwarded to ``data_pipeline``.
        test_file: Test data path, forwarded to ``data_pipeline``.
        config_file: Path to a JSON file with ``"model"`` and ``"training"``
            sections whose values must all be convertible with ``float()``.
            An optional ``"training"`` key ``"lr"`` sets Adam's learning rate
            (default 1e-3, Adam's own default).
        checkpoint_path: Not used in the visible code.
        best_model_path: Not used in the visible code.
    """
    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    with open(config_file, 'r', encoding='utf-8') as f:
        config = json.load(f)
    _coerce_numeric_values(config['model'])
    _coerce_numeric_values(config['training'])

    train_itr, val_itr, test_itr, vocab_size = data_pipeline(
        train_file,
        test_file,
        config['training']['max_vocab'],
        config['training']['min_freq'],
        config['training']['batch_size'],
        device
    )

    model = CNNNLPModel(
        vocab_size,
        config['model']['emb_dim'],
        config['model']['hid_dim'],
        config['model']['model_layer'],
        config['model']['model_kernel_size'],
        config['model']['model_dropout'],
        device
    )
    # Move the model before constructing the optimizer, as the torch.optim
    # docs recommend.
    model = model.to(device)

    # Learning rate is configurable so a smaller value can be tried when the
    # loss diverges; without an "lr" key this matches Adam's default.
    optimizer = optim.Adam(
        model.parameters(), lr=config['training'].get('lr', 1e-3)
    )
    # NOTE(review): train() passes targets as labels.unsqueeze(1), i.e. shape
    # (batch, 1). CrossEntropyLoss expects class indices of shape (batch,)
    # or class probabilities with the same shape as the output. For a
    # single-logit binary output, BCEWithLogitsLoss is the usual pairing.
    # Confirm against CNNNLPModel's output shape.
    criterion = nn.CrossEntropyLoss()

    num_epochs = config['training']['n_epoch']
    clip = config['training']['clip']
    # NOTE(review): is_best, best_valid_loss, checkpoint_path and
    # best_model_path are unused in the visible code; checkpointing logic may
    # have been omitted from this excerpt.
    is_best = False
    best_valid_loss = float('inf')

    for epoch in tqdm(range(num_epochs)):
        train_loss = train(model, train_itr, optimizer, criterion, clip)
        valid_loss = evaluate(model, val_itr, criterion)
        # Report every second epoch.
        if (epoch + 1) % 2 == 0:
            print("training loss {}, validation_loss {}".format(train_loss, valid_loss))
我正在训练一个用于文本二分类的卷积神经网络:给定一个句子,判断它是否属于仇恨言论。前 5 个 epoch 的训练损失和验证损失都很正常,但在第 5 个 epoch 之后,两者突然从 0.2 猛增到 10,000。
损失突然变得这么大,可能是什么原因?
【问题讨论】:
-
你的梯度裁剪(gradient clipping)值是多少?
-
梯度裁剪值为 0.1。
-
如果没有模型定义、数据等信息,很难判断问题究竟出在哪里。一个很可能的原因是学习率过高。顺便问一下,你的学习率是多少?另外,请检查一下梯度和激活值的统计信息。
标签: deep-learning nlp pytorch conv-neural-network text-classification