【问题标题】:Debugging GAN covergence error调试 GAN 覆盖错误
【发布时间】:2018-11-18 14:56:24
【问题描述】:

构建 GAN 以生成图像。图像有 3 个颜色通道,96 x 96。

生成器一开始生成的图像都是黑色的,这是一个在统计上极不可能出现的问题。

此外,两个网络的损失都没有改善。

我在下面发布了整个代码,并进行了评论以使其易于阅读。这是我第一次构建 GAN,而且我是 Pytorch 的新手,因此非常感谢任何帮助!

谢谢。

import torch
from torch.optim import Adam
from torch.utils.data import DataLoader
from torch.autograd import Variable

import numpy as np
import os
import cv2
from collections import deque

# training params
batch_size = 100
epochs = 1000

# loss function
loss_fx = torch.nn.BCELoss()

# processing images
X = deque()
for img in os.listdir('pokemon_images'):
    if img.endswith('.png'):
        pokemon_image = cv2.imread(r'./pokemon_images/{}'.format(img))
        if pokemon_image.shape != (96, 96, 3):
            pass
        else:
            X.append(pokemon_image)

# data loader for processing in batches
data_loader = DataLoader(X, batch_size=batch_size)

# covert output vectors to images if flag is true, else input images to vectors
def images_to_vectors(data, reverse=False):
    if reverse:
        return data.view(data.size(0), 3, 96, 96)
    else:
        return data.view(data.size(0), 27648)

# Generator model
class Generator(torch.nn.Module):
    def __init__(self):
        super(Generator, self).__init__()
        n_features = 1000
        n_out = 27648

        self.model = torch.nn.Sequential(
                torch.nn.Linear(n_features, 128),
                torch.nn.ReLU(),
                torch.nn.Linear(128, 256),
                torch.nn.ReLU(),
                torch.nn.Linear(256, 512),
                torch.nn.ReLU(),
                torch.nn.Linear(512, 1024),
                torch.nn.ReLU(),
                torch.nn.Linear(1024, n_out),
                torch.nn.Tanh()
        )


    def forward(self, x):
        img = self.model(x)
        return img

    def noise(self, s):
       x = Variable(torch.randn(s, 1000))
       return x


# Discriminator model
class Discriminator(torch.nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        n_features = 27648
        n_out = 1

        self.model = torch.nn.Sequential(
                torch.nn.Linear(n_features, 512),
                torch.nn.ReLU(),
                torch.nn.Linear(512, 256),
                torch.nn.ReLU(),
                torch.nn.Linear(256, n_out),
                torch.nn.Sigmoid()
        )


    def forward(self, img):
        output = self.model(img)
        return output


# discriminator training
def train_discriminator(discriminator, optimizer, real_data, fake_data):
    N = real_data.size(0)
    optimizer.zero_grad()

    # train on real
    # get prediction
    pred_real = discriminator(real_data)

    # calculate loss
    error_real = loss_fx(pred_real, Variable(torch.ones(N, 1)))

    # calculate gradients
    error_real.backward()

    # train on fake
    # get prediction
    pred_fake = discriminator(fake_data)

    # calculate loss
    error_fake = loss_fx(pred_fake, Variable(torch.ones(N, 0)))

    # calculate gradients
    error_fake.backward()

    # update weights
    optimizer.step()

    return error_real + error_fake, pred_real, pred_fake


# generator training
def train_generator(generator, optimizer, fake_data):
    N = fake_data.size(0)

    # zero gradients
    optimizer.zero_grad()

    # get prediction
    pred = discriminator(generator(fake_data))

    # get loss
    error = loss_fx(pred, Variable(torch.ones(N, 0)))

    # compute gradients
    error.backward()

    # update weights
    optimizer.step()

    return error


# Instance of generator and discriminator
generator = Generator()
discriminator = Discriminator()

# optimizers
g_optimizer = torch.optim.Adam(generator.parameters(), lr=0.001)
d_optimizer = torch.optim.Adam(discriminator.parameters(), lr=0.001)

# training loop
for epoch in range(epochs):
     for n_batch, batch in enumerate(data_loader, 0):
         N = batch.size(0)

         # Train Discriminator

         # REAL
         real_images = Variable(images_to_vectors(batch)).float()

         # FAKE
         fake_images = generator(generator.noise(N)).detach()

         # TRAIN
         d_error, d_pred_real, d_pred_fake = train_discriminator(
                 discriminator,
                 d_optimizer,
                 real_images,
                 fake_images
         )

         # Train Generator

         # generate noise
         fake_data = generator.noise(N)

         # get error based on discriminator
         g_error = train_generator(generator, g_optimizer, fake_data)

         # convert generator output to image and preprocess to show
         test_img = np.array(images_to_vectors(generator(fake_data), reverse=True).detach())
         test_img = test_img[0, :, :, :]
         test_img = test_img[..., ::-1]

         # show example of generated image
         cv2.imshow('GENERATED', test_img[0])
         if cv2.waitKey(1) & 0xFF == ord('q'):
             break

     print('EPOCH: {0}, D error: {1}, G error: {2}'.format(epoch, d_error, g_error))


cv2.destroyAllWindows()

# save weights
# torch.save('weights.pth')

【问题讨论】:

    标签: neural-network statistics deep-learning pytorch generative-adversarial-network


    【解决方案1】:

    如果没有数据等,就无法真正轻松地调试您的训练,但一个可能的问题是您的生成器的最后一层是Tanh(),这意味着输出值介于-11 之间。你可能想要:

    1. 将您的真实图像标准化到相同的范围,例如在train_discriminator()

      # train on real
      pred_real = discriminator(real_data * 2. - 1.) # supposing real_data in [0, 1]
      
    2. 在可视化/使用之前将生成的数据重新规范化为[0, 1]

      # convert generator output to image and preprocess to show
      test_img = np.array(
          images_to_vectors(generator(fake_data), reverse=True).detach())
      test_img = test_img[0, :, :, :]
      test_img = test_img[..., ::-1]
      test_img = (test_img + 1.) / 2.
      

    【讨论】:

    • 感谢您的回复!我最近尝试在最后一层将 tanh 更改为 relu 激活,结果一直在改善。更改为 relu 应该消除标准化的需要,是吗?我还将学习率大幅降低了 10^2
    • 您可能指的是sigmoid 而不是relu(因为sigmoid 的输出在[0,1] 中,与relu 不同)。我个人会保留tanh 并进行预标准化和后期标准化,因为tanh 作为一些数值优势(例如参见"tanh activation function vs sigmoid activation function")。否则,微调学习率是个好主意。
    • 是的意思是 sigmoid。感谢您的链接和帮助。
    • 你的 BCE 损失不是在生成器的输出上计算的,而是在鉴别器的输出上计算的,它已经被一个 sigmoid 标准化。
    猜你喜欢
    • 1970-01-01
    • 2019-12-31
    • 1970-01-01
    • 1970-01-01
    • 1970-01-01
    • 1970-01-01
    • 1970-01-01
    • 2018-09-04
    • 1970-01-01
    相关资源
    最近更新 更多