【发布时间】:2020-05-04 20:23:04
【问题描述】:
我正在为一些实验构建神经网络模型。我使用 PyTorch,每次训练模型时都使用以下代码:
def train_and_evaluate(net, optimizer, criterion):
start_time = time.time()
train_losses, test_losses, train_acc, test_acc = [], [], [], []
# net.double()
for epoch in range(num_epochs):
epoch_start_time = time.time()
running_loss_train, running_loss_test = 0.0, 0.0
total, correct, test_total, test_correct = 0, 0, 0, 0
# Train mode
net.train()
#Loop batches
for i, (images, labels) in enumerate(train_loader):
if use_cuda and torch.cuda.is_available():
images = images.cuda()
labels = labels.cuda()
optimizer.zero_grad()
outputs = net(images.detach())
# import pdb; pdb.set_trace()
loss = criterion(outputs, labels)
loss.backward()
optimizer.step()
running_loss_train += loss.item()
# Calculate training accuracy for epoch
_, predicted = torch.max(outputs.data, 1) # Get th
total += labels.size(0)
correct += (predicted == labels).sum().item()
# Calculate test loss and accuracy for epoch
net.eval()
with torch.no_grad():
for images, labels in test_loader:
if use_cuda and torch.cuda.is_available():
images = images.cuda()
labels = labels.cuda()
outputs = net(images)
test_loss = criterion(outputs, labels)
running_loss_test += test_loss.item()
_, predicted = torch.max(outputs.data, 1)
test_total += labels.size(0)
test_correct += (predicted == labels).sum().item()
if (epoch + 1) % 10 == 0 or epoch == 0:
print(f'Epoch [{epoch+1:02d}/{num_epochs}]\tTime: {time.time() - epoch_start_time:.2f}\tTrain Loss: {(running_loss_train / len(train_loader)):.4f}\tTrain Acc: {(correct / total):.0%}\tTest Loss: {running_loss_test / len(test_loader):.4f}\tTest Accuracy: {(test_correct / test_total):.0%}'.expandtabs(4))
train_losses.append(running_loss_train / len(train_loader))
test_losses.append(running_loss_test / len(test_loader))
train_acc.append(correct / total)
test_acc.append(test_correct / test_total)
print(f'Training time: {(time.time() - start_time)/60:5.2f} min.')
return train_losses, test_losses, train_acc, test_acc
我只是从其他项目中复制并粘贴它,我发现每次都使用相同的代码是多余的,它几乎在我测试的每个模型中都很有用,并且通常看起来非常相似。我想知道是否有实现此功能的快捷方式或某些模块来减少代码行数并使其更具可读性。比如:
model = NeuralNetwork()
train_losses, test_losses, train_acc, test_acc = train(model,
epochs=50,
verbose=True,
cuda=True,
optimizer=optimizer,
criterion=criterion)
我可以只传递参数进行训练,而不是每次都明确地“编写”代码。
【问题讨论】: