【问题标题】:keras training on big datasets seperately keraskeras 在大数据集上单独训练
【发布时间】:2020-11-18 13:22:46
【问题描述】:

我正在研究对高维 X 射线图像进行去噪的 keras 去噪神经网络。这个想法是在一些数据集上进行训练,例如 1、2、3,在获得权重之后,另一个数据集例如 4、5、6 将开始新的训练,权重从之前的训练中初始化。在实现方面它可以工作,但是最后一次旋转产生的权重仅在用于在此轮换中训练的数据集上表现更好。其他轮换也是如此。

换句话说,从数据集训练得到的权重:4,5,6 在数据集 1 的图像上没有像在数据集:1,2 上训练的权重那样给出良好的结果, 3.这不应该是我打算做的事情

这个想法是应该调整权重以有效地处理所有数据集,因为对整个数据集的训练不适合内存。

我尝试了其他解决方案,例如创建自定义生成器,该生成器从磁盘获取图像并批量进行训练,这非常慢,因为它取决于磁盘上发生的 I/O 操作或处理函数的时间复杂度等因素发生在自定义 keras 生成器中!

下面是显示我在做什么的代码。我有 12 个数据集,分为 4 个检查点。数据被加载,训练开始并将最终模型保存到一个数组中,下一次训练从上一次旋转中获取权重并继续。

EPOCHES = 150
NUM_CHKPTS = 4
weights = []

for chk in range(1,NUM_CHKPTS+1):

    log_dir = os.path.join(os.getcwd(), 'resnet_checkpts_' + str(EPOCHES) + "_tl2_chkpt" + str(chk))
    if not os.path.isdir(log_dir):
        os.makedirs(log_dir)
    else:
        print('Training log directory already exists @ {}.'.format(log_dir))
    tb_output = TensorBoard(log_dir=log_dir, histogram_freq=1)

    print("Loading Data From CHKPT #" + str(chk))

    h5f = h5py.File('C:\\autoencoder\\datasets\\mix\\chk' + str(chk) + '.h5','r')
    org_patch = h5f['train_data'][:]
    noisy_patch = h5f['train_noisy'][:]
    h5f.close()

    input_patch, test_patch, noisy_patch, test_noisy_patch = train_test_split(org_patch, noisy_patch, train_size=0.8, shuffle=True)


    print("Reshaping")
    train_data = np.array([np.reshape(input_patch[i], (52, 52, 1)) for i in range(input_patch.shape[0])], dtype = np.float32)
    train_noisy_data = np.array([np.reshape(noisy_patch[i], (52, 52, 1)) for i in range(noisy_patch.shape[0])], dtype = np.float32)

    test_data = np.array([np.reshape(test_patch[i], (52, 52, 1)) for i in range(test_patch.shape[0])], dtype = np.float32)
    test_noisy_data = np.array([np.reshape(test_noisy_patch[i], (52, 52, 1)) for i in range(test_noisy_patch.shape[0])], dtype = np.float32)

    print('Number of training samples are:', train_data.shape[0])
    print('Number of test samples are:', test_data.shape[0])

    # IN = np.ones((len(XTRAINFILES), 52, 52, 1 ))

    if chk == 1:
        print("Generating the Model For The First Time..")
        autoencoder_model = model_autoencoder(train_noisy_data)
        print("Done!")
    else:
        autoencoder_model=load_model(weights[chk-2])


    checkpt_path = log_dir + r"\\cp-{epoch:04d}.ckpt"

    checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(filepath=checkpt_path, verbose=0, save_weights_only=True, save_freq='epoch')


    optimizer = tf.keras.optimizers.Adam(lr=0.0001)
    autoencoder_model.compile(loss='mse',optimizer=optimizer) 

    autoencoder_model.fit(train_noisy_data, train_data,
                        batch_size=128,
                        epochs=EPOCHES, shuffle=True, verbose=1,
                        validation_data=(test_noisy_data, test_data),
                        callbacks=[tb_output, checkpoint_callback])

    weight_dir = log_dir+'\\model_resnet_new_OL' + str(EPOCHES) + 'epochs.h5'
    weights.append(weight_dir)
    autoencoder_model.save(weight_dir)  # Defined saved model name by number of epochs.

Tensorboard Graphs,自上而下的旋转为 1,2,3,4:

【问题讨论】:

    标签: keras deep-learning neural-network autoencoder transfer-learning


    【解决方案1】:

    当您在新数据集上进行训练时,您的模型会忘记之前的数据集。

    我在强化学习中看到,当游戏用于训练深度强化学习(DRL)时,你必须创建记忆回放,它从不同回合的游戏中收集数据,因为每一轮游戏有不同的数据,然后随机选择其中一些数据来训练模型。这样 DRL 模型可以学习玩不同回合的游戏而不会忘记之前的回合。

    您可以尝试通过从每个数据集中抽取一些随机样本来创建单个数据集。

    当您在新数据集上训练模型时,确保所有先前轮换的数据都在当前轮换中。

    同样在迁移学习中,当你在新数据集上训练模型时,你必须冻结之前的层,这样模型就不会忘记之前的训练。您没有使用迁移学习,但是当您开始对第二个数据集进行训练时,您的第一个数据集将慢慢从权重内存中删除。

    您可以尝试冻结解码器的初始层,以便它们在提取特征时不会更新,假设所有数据集都包含相似的图像,这样您的模型就不会忘记之前的训练,就像在迁移学习中一样。但是当您在新数据集上进行训练时,之前的数据仍然会被遗忘。

    【讨论】:

    • 我不这么认为,张量板图的情况恰恰相反,我将添加显示损失的数字
    • 那么您使用的不同数据集之间可能存在巨大差异。
    • 尝试使用来自所有其他数据集的数据,特别是每个数据集的不同图像,创建 1 个数据集进行训练。
    • 当你训练很多时,模型会忘记之前的训练,在不同的数据集上一次又一次地训练是没有意义的。
    • 通常初始层提取特征,而最后一层根据初始层的特征做出决策。所以你必须冻结初始层。这将使模型在对新数据进行训练时从图像中提取特征而不更新权重。
    猜你喜欢
    • 2021-05-17
    • 2020-07-22
    • 1970-01-01
    • 1970-01-01
    • 2017-07-29
    • 2021-11-23
    • 1970-01-01
    • 2016-10-19
    • 2020-06-25
    相关资源
    最近更新 更多