【问题标题】:How to retrain with weights in Keras如何在 Keras 中重新训练权重
【发布时间】:2020-02-20 11:29:57
【问题描述】:

我正在 Colab 中训练模型,但是,我关闭了计算机,训练停止了。每 5 个 epoch 我保存权重。我认为是,但我不知道如何。如何使用之前保存的权重继续训练?

谢谢。

【问题讨论】:

    标签: python keras model artificial-intelligence google-colaboratory


    【解决方案1】:

    colab 中训练模型时,训练不会在您关闭计算机时停止,而是会在一段时间后停止。

    如果您在colab 中保存权重,当colab 关闭时,所有内容都会被删除。

    如果您已将gdrive 安装在colab and 中,则您将权重保存在gdrive 中,您的权重将在那里。

    如果您的权重在您的 gdrive 中,您可以通过将存储的权重加载到您的 keras 模型来继续训练

    model.load_weights('path_to_weights')
    

    【讨论】:

    • 对不起,我不知道怎么在前面的评论中提及。
    【解决方案2】:

    感谢您的回答,@Ioannis Nasios。是的,我的体重在“gdrive”中。我正在训练一个 GAN 网络,我试图弄清楚如何加载这些权重并继续训练。我保存了鉴别器和生成器的权重以及 gan_loss 和 discriminator_loss。好吧,我是否必须编译生成器和鉴别器网络,加载权重并编译带有损失的 gan 网络?我认为这可能是一个愚蠢的问题。这是我第一次训练 GAN 网络。 这里我贴出代码:

    # Combined network
    def get_gan_network(discriminator, shape, generator, optimizer, loss):
        discriminator.trainable = False
        gan_input = Input(shape=shape)
        x = generator(gan_input)
        gan_output = discriminator(x)
        gan = Model(inputs=gan_input, outputs=[x,gan_output])
        gan.compile(loss=[loss, "binary_crossentropy"],
                    loss_weights=[1., 1e-3],
                    optimizer=optimizer)
    
        return gan
    
    def train(x_train_lr, x_train_hr, x_test_lr, x_test_hr, epochs, batch_size, output_dir, model_save_dir, weights_save_dir):
    
    
        loss = VGG_LOSS(image_shape)  
    
        batch_count = int(x_train_hr.shape[0] / batch_size)
        #### SI LAS IMAGENES NO SON CUADRADAS ESTO DEBERIA CAMBIAR
        shape_lr = (image_shape[0]//downscale_factor, image_shape[1]//downscale_factor, image_shape[2])
        shape_hr = x_train_hr[0].shape
        ####
        generator = Generator(shape_lr, shape_hr).generator()
        discriminator = Discriminator(image_shape).discriminator()
    
        optimizer = Utils_model.get_optimizer()
        generator.compile(loss=loss.vgg_loss, optimizer=optimizer)
        discriminator.compile(loss="binary_crossentropy", optimizer=optimizer)
    
        gan = get_gan_network(discriminator, shape_lr, generator, optimizer, loss.vgg_loss)
    
        loss_file = open(model_save_dir + '/losses.txt' , 'w+')
        loss_file.close()
    
        for e in range(1, epochs+1):
            print ('-'*15, 'Epoch %d' % e, '-'*15)
            for _ in tqdm(range(batch_count)):
    
                rand_nums = np.random.randint(0, x_train_hr.shape[0], size=batch_size)
    
                image_batch_hr = x_train_hr[rand_nums]
                image_batch_lr = x_train_lr[rand_nums]
    
                generated_images_sr = generator.predict(image_batch_lr)
    
                real_data_Y = np.ones(batch_size) - np.random.random_sample(batch_size)*0.2
                fake_data_Y = np.random.random_sample(batch_size)*0.2
    
                discriminator.trainable = True
    
                d_loss_real = discriminator.train_on_batch(image_batch_hr, real_data_Y)
                d_loss_fake = discriminator.train_on_batch(generated_images_sr, fake_data_Y)
                discriminator_loss = 0.5 * np.add(d_loss_fake, d_loss_real)
    
                rand_nums = np.random.randint(0, x_train_hr.shape[0], size=batch_size)
                image_batch_hr = x_train_hr[rand_nums]
                image_batch_lr = x_train_lr[rand_nums]
    
                gan_Y = np.ones(batch_size) - np.random.random_sample(batch_size)*0.2
                discriminator.trainable = False
                gan_loss = gan.train_on_batch(image_batch_lr, [image_batch_hr,gan_Y])
    
    
            print("discriminator_loss : %f" % discriminator_loss)
            print("gan_loss :", gan_loss)
            gan_loss = str(gan_loss)
    
            loss_file = open(model_save_dir + 'losses.txt' , 'a')
            loss_file.write('epoch%d : gan_loss = %s ; discriminator_loss = %f\n' %(e, gan_loss, discriminator_loss) )
            loss_file.close()
    
            if e == 1 or e % 5 == 0:
                Utils.plot_generated_images(output_dir, e, generator, x_test_hr, x_test_lr)
                generator.save_weights(weights_save_dir + '%d_gen_weights.h5' % e)
                discriminator.save_weights(weights_save_dir + '%d_dis_weights.h5' % e)
    
            if e % 500 == 0 or e == epochs+1:
                generator.save(model_save_dir + 'gen_model%d.h5' % e)
                discriminator.save(model_save_dir + 'dis_model%d.h5' % e)
    

    【讨论】:

      猜你喜欢
      • 2019-06-27
      • 2018-09-24
      • 2019-06-18
      • 1970-01-01
      • 2017-10-07
      • 2021-01-31
      • 2018-11-14
      • 1970-01-01
      • 1970-01-01
      相关资源
      最近更新 更多