【问题标题】:model.train() and model.eval() causing nan valuesmodel.train() 和 model.eval() 导致 nan 值
【发布时间】:2021-07-14 08:31:01
【问题描述】:

嘿,所以我正在尝试使用猴子物种数据集和 resnet50 进行图像分类/迁移学习,并使用经过修改的最终 fc 层来预测 10 个类别。在我使用 model.train() 和 model.eval() 之前,一切都在工作,然后在第一个 epoch 之后它开始返回 nans 并且准确性下降,如下所示。我很好奇为什么只有在切换到 train/eval 时才会这样......?

首先我导入模型并附加分类器并冻结参数

%%capture
resnet = models.resnet50(pretrained=True)

for param in resnet.parameters():
  param.required_grad = False

in_features = resnet.fc.in_features


# Build custom classifier
classifier = nn.Sequential(OrderedDict([('fc1', nn.Linear(in_features, 512)),
                                        ('relu', nn.ReLU()),
                                        ('drop', nn.Dropout(0.05)),
                                        ('fc2', nn.Linear(512, 10)),
                                        ]))

# ('output', nn.LogSoftmax(dim=1))
resnet.classifier = classifier

resnet.to(device)

然后设置我的损失函数、优化器和调度器

# Step : Define criterion and optimizer
criterion = nn.CrossEntropyLoss()
# pass the optimizer to the appended classifier layer
optimizer = torch.optim.SGD(resnet.parameters(), lr=0.01)
# Scheduler
scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[10], gamma=0.05)  

然后设置训练和验证循环

epochs = 20


tr_losses = []
avg_epoch_tr_loss = []
tr_accuracy = []


val_losses = []
avg_epoch_val_loss = []
val_accuracy = []
val_loss_min = np.Inf


resnet.train()
for epoch in range(epochs):
  for i, batch in enumerate(train_loader):
    # Pull the data and labels from the batch
    data, label = batch
    # If available push data and label to GPU
    if train_on_gpu:
      data, label = data.to(device), label.to(device)
    # Compute the logit
    logit = resnet(data)
    # Compte loss
    loss = criterion(logit, label)
    # Clearing the gradient
    resnet.zero_grad()
    # Backpropagate the gradients (accumulte the partial derivatives of loss)
    loss.backward()
    # Apply the updates to the optimizer step in the opposite direction to the gradient
    optimizer.step()
    # Store the losses of each batch
    # loss.item() seperates the loss from comp graph
    tr_losses.append(loss.item())
    # Detach and store the average accuracy of each batch
    tr_accuracy.append(label.eq(logit.argmax(dim=1)).float().mean())
    # Print the rolling batch training loss every 20 batches
    if i % 40 == 0 and not i == 1:
      print(f'Batch No: {i} \tAverage Training Batch Loss: {torch.tensor(tr_losses).mean():.2f}')
  # Print the average loss for each epoch
  print(f'\nEpoch No: {epoch + 1},Training Loss: {torch.tensor(tr_losses).mean():.2f}')
  # Print the average accuracy for each epoch
  print(f'Epoch No: {epoch + 1}, Training Accuracy: {torch.tensor(tr_accuracy).mean():.2f}\n')
  # Store the avg epoch loss for plotting
  avg_epoch_tr_loss.append(torch.tensor(tr_losses).mean())


  resnet.eval()
  for i, batch in enumerate(val_loader):
    # Pull the data and labels from the batch
    data, label = batch
    # If available push data and label to GPU
    if train_on_gpu:
      data, label = data.to(device), label.to(device)
    # Compute the logits without computing the gradients
    with torch.no_grad():
      logit = resnet(data)
    # Compte loss
    loss = criterion(logit, label)
    # Store test loss
    val_losses.append(loss.item())
    # Store the accuracy for each batch
    val_accuracy.append(label.eq(logit.argmax(dim=1)).float().mean())
    if i % 20 == 0 and not i == 1:
      print(f'Batch No: {i+1} \tAverage Val Batch Loss: {torch.tensor(val_losses).mean():.2f}')
  # Print the average loss for each epoch
  print(f'\nEpoch No: {epoch + 1}, Epoch Val Loss: {torch.tensor(val_losses).mean():.2f}')
  # Print the average accuracy for each epoch    
  print(f'Epoch No: {epoch + 1}, Epoch Val Accuracy: {torch.tensor(val_accuracy).mean():.2f}\n')
  # Store the avg epoch loss for plotting
  avg_epoch_val_loss.append(torch.tensor(val_losses).mean())

  # Checpoininting the model using val loss threshold
  if torch.tensor(val_losses).float().mean() <= val_loss_min:
    print("Epoch Val Loss Decreased... Saving model")
    # save current model
    torch.save(resnet.state_dict(), '/content/drive/MyDrive/1. Full Projects/Intel Image Classification/model_state.pt')
    val_loss_min = torch.tensor(val_losses).mean()
  # Step the scheduler for the next epoch
  scheduler.step()
  # Print the updated learning rate
  print('Learning Rate Set To: {:.5f}'.format(optimizer.state_dict()['param_groups'][0]['lr']),'\n')

模型开始训练,然后慢慢变成 nan 值

Batch No: 0     Average Training Batch Loss: 9.51
Batch No: 40    Average Training Batch Loss: 1.71
Batch No: 80    Average Training Batch Loss: 1.15
Batch No: 120   Average Training Batch Loss: 0.94

Epoch No: 1,Training Loss: 0.83
Epoch No: 1, Training Accuracy: 0.78

Batch No: 1     Average Val Batch Loss: 0.39
Batch No: 21    Average Val Batch Loss: 0.56
Batch No: 41    Average Val Batch Loss: 0.54
Batch No: 61    Average Val Batch Loss: 0.54

Epoch No: 1, Epoch Val Loss: 0.55
Epoch No: 1, Epoch Val Accuracy: 0.81

Epoch Val Loss Decreased... Saving model
Learning Rate Set To: 0.01000 

Batch No: 0     Average Training Batch Loss: 0.83
Batch No: 40    Average Training Batch Loss: nan
Batch No: 80    Average Training Batch Loss: nan

【问题讨论】:

  • 您的训练样本中是否有可能包含nan
  • 您可能会在this answer中找到有用的信息
  • @Shai no nans 在数据中它只是图像文件夹,我也会检查链接,谢谢。
  • 可能是应用于数据的转换创建了nans。 nan 是否总是同时/迭代/时代出现?如果显着降低学习率会怎样?
  • @Shai 我看到你在雷霍沃特,我是爱尔兰人,但我住在特拉维夫 ;)

标签: loops validation neural-network pytorch nan


【解决方案1】:

我看到resnet.zero_grad()logit = resnet(data) 之后,这会导致渐变在您的情况下爆炸。

请按如下方式进行:

# Clearing the gradient
optimizer.zero_grad()
logit = resnet(data)

# Compute loss
loss = criterion(logit, label)

【讨论】:

  • 我不确定这是这里的问题(基于 cmets,它的学习率很大)。 AFAIK 你应该在 backward() 之前 zero_grad()` 并且它是在前向传递之前还是之后都没有关系。
  • 试想一下,他正在计算前向传递中的梯度......并且他正在使用.zero_grad() 清除它们......这有意义吗?
  • 他向前,zero_grad,然后向后,一步。它看起来和 zero_grad, forward, back, step 一样好
  • 只是降低 lr 有帮助,但我也提高了 lr 在前向传球之前调用 .zero_grad() 并且两者都工作......不再有 nans。我认为在向前传球之前调用它更安全......
猜你喜欢
  • 2021-06-06
  • 2021-06-30
  • 1970-01-01
  • 1970-01-01
  • 2014-01-29
  • 2021-08-30
  • 1970-01-01
  • 1970-01-01
  • 1970-01-01
相关资源
最近更新 更多