【发布时间】:2018-03-07 19:18:08
【问题描述】:
def train():
# Model
model = Model()
# Loss, Optimizer
global_step = tf.Variable(1, dtype=tf.int32, trainable=False, name='global_step')
loss_fn = model.loss()
optimizer = tf.train.AdamOptimizer(learning_rate=TrainConfig.LR).minimize(loss_fn, global_step=global_step)
# Summaries
summary_op = summaries(model, loss_fn)
with tf.Session(config=TrainConfig.session_conf) as sess:
# Initialized, Load state
sess.run(tf.global_variables_initializer())
model.load_state(sess, TrainConfig.CKPT_PATH)
writer = tf.summary.FileWriter(TrainConfig.GRAPH_PATH, sess.graph)
# Input source
data = Data(TrainConfig.DATA_PATH)
loss = Diff()
for step in xrange(global_step.eval(), TrainConfig.FINAL_STEP):
mixed_wav, src1_wav, src2_wav, _ = data.next_wavs(TrainConfig.SECONDS, TrainConfig.NUM_WAVFILE, step)
mixed_spec = to_spectrogram(mixed_wav)
mixed_mag = get_magnitude(mixed_spec)
src1_spec, src2_spec = to_spectrogram(src1_wav), to_spectrogram(src2_wav)
src1_mag, src2_mag = get_magnitude(src1_spec), get_magnitude(src2_spec)
src1_batch, _ = model.spec_to_batch(src1_mag)
src2_batch, _ = model.spec_to_batch(src2_mag)
mixed_batch, _ = model.spec_to_batch(mixed_mag)
# Initializae our callback.
#early_stopping_cb = EarlyStoppingCallback(val_acc_thresh=0.5)
l, _, summary = sess.run([loss_fn, optimizer, summary_op],
feed_dict={model.x_mixed: mixed_batch, model.y_src1: src1_batch,
model.y_src2: src2_batch})
loss.update(l)
print('step-{}\td_loss={:2.2f}\tloss={}'.format(step, loss.diff * 100, loss.value))
writer.add_summary(summary, global_step=step)
# Save state
if step % TrainConfig.CKPT_STEP == 0:
tf.train.Saver().save(sess, TrainConfig.CKPT_PATH + '/checkpoint', global_step=step)
writer.close()
我有这个神经网络代码,可以将音乐与 .wav 文件中的声音分开。 如何引入提前停止算法来停止火车部分?我看到一些谈论 ValidationMonitor 的项目。有人可以帮我吗?
【问题讨论】:
-
从最新的 TensorFlow 文档(测试版)中,可以使用自定义回调实现提前停止。 tensorflow.org/beta/guide/keras/…
-
好吧,我提供的链接直接指向一个示例回调类,
EarlyStoppingAtMinLoss。该类的一个实例可以在训练期间作为回调传递给模型,并在训练期间用于在损失停止减少时提前停止。该示例给出了类的实现,以及在训练期间如何使用它。此外,文档中还提到了一个额外的回调,“tf.keras.callbacks.EarlyStopping 提供了更完整和通用的实现”。这是回调:tensorflow.org/versions/r2.0/api_docs/python/tf/keras/callbacks/…
标签: python tensorflow keras neural-network