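【Posted】: 2023-02-06 02:12:40
【Question】:
I am trying to train a ViT model to classify some images. The script fails during training and I cannot tell where my mistake is. The classification label has only two targets, 0 and 1 (false and true). The batch size is 32 and there are only 3 epochs.
Below is the script that trains the model: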
import torch
import torch.utils.data as data
from torch.autograd import Variable
import numpy as np

train_loader = data.DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True, num_workers=2)
test_loader = data.DataLoader(test_ds, batch_size=BATCH_SIZE, shuffle=True, num_workers=2)
# Train the model
for epoch in range(EPOCHS):
    for step, (x, y) in enumerate(train_loader):
        # Change input array into list with each batch being one element
        x = np.split(np.squeeze(np.array(x)), BATCH_SIZE)
        # Remove unnecessary dimension
        for index, array in enumerate(x):
            x[index] = np.squeeze(array)
        # Apply feature extractor, stack back into 1 tensor and then convert to tensor
        x = torch.tensor(np.stack(feature_extractor(x)['pixel_values'], axis=0))
        # Send to GPU if available
        x = x.to(device)
        y = y.to(device)
        b_x = Variable(x)  # batch x (image)
        b_y = Variable(y)  # batch y (target)
        # Feed through model
        output = model(b_x, None)
        loss = output[0]
        # Calculate loss
        if loss is None:
            loss = loss_func(output, b_y)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        if step % 50 == 0:
            # Get the next batch for testing purposes
            test = next(iter(test_loader))
            test_x = test[0]
            # Reshape and get feature matrices as needed
            test_x = np.split(np.squeeze(np.array(test_x)), BATCH_SIZE)
            for index, array in enumerate(test_x):
                test_x[index] = np.squeeze(array)
            test_x = torch.tensor(np.stack(feature_extractor(test_x)['pixel_values'], axis=0))
            # Send to appropriate computing device
            test_x = test_x.to(device)
            test_y = test[1].to(device)
            # Get output (+ respective class) and compare to target
            test_output, loss = model(test_x, test_y)
            test_output = test_output.argmax(1)
            # Calculate accuracy
            accuracy = (test_output == test_y).sum().item() / BATCH_SIZE
            print('Epoch: ', epoch, '| train loss: %.4f' % loss, '| test accuracy: %.2f' % accuracy)
The error message is: ValueError: array split does not result in an equal division. The traceback points at the line x = np.split(np.squeeze(np.array(x)), BATCH_SIZE).
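One likely cause, assuming the number of images is not a multiple of BATCH_SIZE: the last batch yielded by the loader is shorter than 32, and np.split raises this ValueError whenever the first axis cannot be divided into the requested number of equal sections. The sketch below reproduces this with NumPy alone. The array shapes and the helper name split_batch are hypothetical and not part of the original script.

```python
import numpy as np

BATCH_SIZE = 32  # same value as in the question


def split_batch(batch):
    """Split a batch array into a list of per-image arrays.

    Uses the actual batch length rather than BATCH_SIZE, so a short
    final batch is handled as well.
    """
    batch = np.asarray(batch)
    return [np.squeeze(a, axis=0) for a in np.split(batch, len(batch))]


# Hypothetical data: 70 images of shape (3, 8, 8) give batches of 32, 32 and 6
images = np.zeros((70, 3, 8, 8), dtype=np.float32)
batches = [images[i:i + BATCH_SIZE] for i in range(0, len(images), BATCH_SIZE)]

last = batches[-1]
try:
    np.split(last, BATCH_SIZE)  # 6 images cannot be cut into 32 equal parts
    split_failed = False
except ValueError as err:
    split_failed = True
    print("np.split failed:", err)

parts = split_batch(last)
print(len(parts), parts[0].shape)
```

If this is indeed the cause, another option is to pass drop_last=True to data.DataLoader so the incomplete final batch is skipped. The accuracy line would then also be safer dividing by len(test_y) instead of BATCH_SIZE.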
【Discussion】:
Tags: python pytorch torch training-data image-classification