【发布时间】:2020-03-14 11:56:02
【问题描述】:
我一直在 PyTorch 中使用 CNN,我需要为每个类保存图像路径及其相关的预测概率(在这种情况下,类是通过或失败)。这是我将 preds 保存到数据框的代码:
preds_df = pd.DataFrame()
class_labels = []
model_ft.eval()
for i, (inputs, labels) in enumerate(dataloaders['train']):
inputs = inputs.to(device)
labels = labels.to(device)
class_labels.append(labels.tolist())
output = model_ft(inputs)
sm = torch.nn.Softmax()
probabilities = sm(output)
arr = probabilities.data.cpu().numpy()
df = pd.DataFrame(arr)
preds_df = preds_df.append(df)
preds_df['prediction'] = preds_df.idxmax(axis=1)
class_list = [item for sublist in class_labels for item in sublist]
preds_df['label'] = class_list
preds_df.columns = ['pass (0)', 'fail (1)', 'prediction', 'label']
preds_df.to_csv('./zoom17CNN_preds.csv')
如何在数据加载器中保存每个文件的图像路径?谢谢!
【问题讨论】:
标签: python conv-neural-network pytorch