【问题标题】:Confused about the use of validation set here对此处验证集的使用感到困惑
【发布时间】:2019-01-22 18:55:03
【问题描述】:

对于px2graph项目的main.py,训练和验证部分如下图:

splits = [s for s in ['train', 'valid'] if opt.iters[s] > 0]
start_round = opt.last_round - opt.num_rounds

# Main training loop
for round_idx in range(start_round, opt.last_round):
    for split in splits:

        print("Round %d: %s" % (round_idx, split))
        loader.start_epoch(sess, split, train_flag, opt.iters[split] * opt.batchsize)

        flag_val = split == 'train'

        for step in tqdm(range(opt.iters[split]), ascii=True):
            global_step = step + round_idx * opt.iters[split]
            to_run = [sample_idx, summaries[split], loss, accuracy]
            if split == 'train': to_run += [optim]

            # Do image summaries at the end of each round
            do_image_summary = step == opt.iters[split] - 1
            if do_image_summary: to_run[1] = image_summaries[split]

            # Start with lower learning rate to prevent early divergence
            t = 1/(1+np.exp(-(global_step-5000)/1000))
            lr_start = opt.learning_rate / 15
            lr_end = opt.learning_rate
            tmp_lr = (1-t) * lr_start + t * lr_end

            # Run computation graph
            result = sess.run(to_run, feed_dict={train_flag:flag_val, lr:tmp_lr})

            out_loss = result[2]
            out_accuracy = result[3]
            if sum(out_loss) > 1e5:
                print("Loss diverging...exiting before code freezes due to NaN values.")
                print("If this continues you may need to try a lower learning rate, a")
                print("different optimizer, or a larger batch size.")
                return

            time_str = datetime.now().strftime('%Y-%m-%d %H:%M:%S')
            print("{}: step {}, loss {:g}, acc {:g}".format(time_str, global_step, out_loss, out_accuracy))

            # Log data
            if split == 'valid' or (split == 'train' and step % 20 == 0) or do_image_summary:
                writer.add_summary(result[1], global_step)
                writer.flush()

    # Save training snapshot
    saver.save(sess, 'exp/' + opt.exp_id + '/snapshot')
    with open('exp/' + opt.exp_id + '/last_round', 'w') as f:
        f.write('%d\n' % round_idx)

好像作者只得到了每批验证集的结果。我想知道,如果我想观察模型是否在改进或达到最佳性能,我应该在整个验证集上使用结果吗?

【问题讨论】:

    标签: validation tensorflow train-test-split


    【解决方案1】:

    如果验证集足够小,我们可以在训练期间计算整个验证集的损失、准确度以观察性能。但是,如果验证集太大,最好按批次计算验证结果并进行多个步骤。

    【讨论】:

      猜你喜欢
      • 2013-06-02
      • 1970-01-01
      • 2021-04-06
      • 2021-10-22
      • 1970-01-01
      • 1970-01-01
      • 2023-04-06
      • 1970-01-01
      • 1970-01-01
      相关资源
      最近更新 更多