【发布时间】:2020-09-19 20:14:27
【问题描述】:
我是 PyTorch 和对抗网络的新手。我试图在 PyTorch 文档以及 PyTorch 和 StackOverflow 论坛中的先前讨论中寻找答案,但我找不到任何有用的东西。
我正在尝试使用生成器和判别器训练 GAN,但我无法理解整个过程是否正常。就我而言,我应该首先训练生成器,然后更新鉴别器的权重(类似于this)。我更新两个模型权重的代码是:
# computing loss_g and loss_d...
optim_g.zero_grad()
loss_g.backward()
optim_g.step()
optim_d.zero_grad()
loss_d.backward()
optim_d.step()
其中loss_g 是生成器损失,loss_d 是判别器损失,optim_g 是参考生成器参数的优化器,optim_d 是判别器优化器。
如果我像这样运行代码,我会得到一个错误:
RuntimeError: Trying to backward through the graph a second time, but the buffers have already been freed. Specify retain_graph=True when calling backward the first time.
所以我指定了loss_g.backward(retain_graph=True),我的疑问出现了:如果有两个具有两个不同图的网络,我为什么要指定retain_graph=True?我是不是搞错了?
【问题讨论】:
标签: python python-3.x pytorch torchvision