【发布时间】:2022-01-11 06:24:33
【问题描述】:
我的代码在第 1 个 epoch 上运行正常，但进入下一个 epoch 时，会因为张量形状不一致而报错停止。你能帮我解决这个问题吗？非常感谢您的宝贵时间！
# Mini-batch size shared by both loaders below; change it here only, so the
# two loaders cannot silently drift apart.
BATCH_SIZE = 32

# ToTensor, then Normalize with mean 0.1307 / std 0.3081 (one-element tuples,
# i.e. a single channel).
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])
# Same pipeline as `transform`, built via the fully-qualified module path. It is
# not used by the code visible here; it is kept as a separate object so that any
# later modification of one pipeline does not leak into the other.
trainTransform = torchvision.transforms.Compose([torchvision.transforms.ToTensor(), torchvision.transforms.Normalize((0.1307,), (0.3081,))])
trainset = torchvision.datasets.FashionMNIST(root='{}/./data'.format(path_prefix), train = True, download = True, transform = transform)
# NOTE(review): the training loader does not shuffle; SGD-style training usually
# wants shuffle=True here. Left unchanged to preserve current behaviour.
train_loader = torch.utils.data.DataLoader(trainset, batch_size=BATCH_SIZE, shuffle=False, num_workers=4)
valset = torchvision.datasets.FashionMNIST(root='{}/./data'.format(path_prefix), train=False, download=True, transform=transform)
# NOTE(review): unless len(valset) is a multiple of BATCH_SIZE, the last
# validation batch is smaller than BATCH_SIZE. Fashion-MNIST is published with a
# 10,000-image test split, giving a final batch of 16 at BATCH_SIZE=32, whereas
# its 60,000-image training split divides evenly. If the model hard-codes the
# batch dimension anywhere (e.g. view(32, ...)), that last batch fails with a
# shape error. Make the model use x.size(0) instead, or pass drop_last=True here
# if skipping those samples during evaluation is acceptable.
val_loader = torch.utils.data.DataLoader(valset, batch_size=BATCH_SIZE, shuffle=False, num_workers=4)
def train(self, epoch):
    """Run one training epoch and report the mean per-batch loss.

    Note that both this trainer and ``self.loss_function`` need to be
    modified for the VAE model.

    Args:
        self: trainer object; this method reads ``self.model``,
            ``self.train_loader``, ``self.device``, ``self.optimizer`` and
            ``self.loss_function``.
        epoch: epoch number, used only in the progress message.

    Returns:
        The sum of ``loss.item()`` over all batches processed this epoch,
        divided by the number of those batches (0.0 if the loader yields no
        batches). Callers that ignore the return value are unaffected.
    """
    self.model.train()
    train_loss = 0.0
    num_batches = 0
    for data, _ in tqdm(self.train_loader, total=len(self.train_loader)):
        # Flatten each sample to a vector. Using the actual batch dimension
        # (data.shape[0]) keeps this correct for a final, smaller batch.
        data = data.view(data.shape[0], -1).to(self.device)
        self.optimizer.zero_grad()
        recon_batch = self.model(data)
        loss = self.loss_function(recon_batch, data)
        loss.backward()
        train_loss += loss.item()
        self.optimizer.step()
        num_batches += 1
    # Divide by the number of batches actually seen rather than
    # len(dataset) / 32. The hard-coded 32 goes wrong as soon as the loader's
    # batch size changes. max(..., 1) avoids dividing by zero on an empty loader.
    train_loss /= max(num_batches, 1)
    print('====> Epoch: {} Average loss: {:.4f}'.format(
        epoch, train_loss))
    return train_loss
【问题讨论】:
标签: reshape torch autoencoder epoch mnist