【发布时间】:2019-01-28 23:29:20
【问题描述】:
我正在训练 GAN 模型。为了加载数据集,我使用的是 TensorFlow 的 Dataset API。
# train_dataset has image and label. z_train dataset has noise (z).
train_dataset = tf.data.TFRecordDataset(train_file)
z_train = tf.data.Dataset.from_tensor_slices(tf.random_uniform([total_training_samples, seq_length, z_dim],
minval=0, maxval=1, dtype=tf.float32))
train_dataset = tf.data.Dataset.zip((train_dataset, z_train))
创建迭代器:
iter = tf.data.Iterator.from_structure(train_dataset.output_types, train_dataset.output_shapes)
使用迭代器:
(img, label), z = iter.get_next()
train_init_op = iter.make_initializer(train_dataset)
在会话中训练 GAN 时:
首先训练判别器:
_, disc_loss = sess.run([disc_optim, disc_loss])
然后训练生成器:
_, gen_loss = sess.run([gen_optim, gen_loss])
这就是问题所在。因为,我在鉴别器和生成器图中都使用 label 作为条件 (CGAN),因此在同一运行期间使用两个 sess.run 会产生两组不同的 label批次。
for epoch in range(num_of_epochs):
sess.run([tf.global_variables_initializer(), train_init_op.initializer])
for batch in range(num_of_batches):
_, disc_loss = sess.run([disc_optim, disc_loss])
_, gen_loss = sess.run([gen_optim, gen_loss])
既然,我必须在生成器的会话运行中提供与鉴别器的会话运行相同的批次 label,我应该如何防止 Dataset API 在批次的同一个循环中产生两个不同的批次?
注意:我使用的是 TensorFlow v1.9
提前致谢。
【问题讨论】:
-
...你能在同一个会话调用中运行所有操作吗?
sess.run([disc_optim, disc_loss, get_optim, gen_loss])?不确定我是否完全理解这个问题,但也许可以在GANEstimator 中查看他们是如何做到的。另外,您的z_train是数据集的一部分,而不仅仅是tf.random_uniform的输出,有什么原因吗? -
在训练生成器之前,我正在检查判别器的损失是否大于 0.2,然后只训练判别器。为此,我需要运行两个不同的会话调用。我也会检查 GANEstimator。谢谢你的建议。 z_train 数据集实际上是嵌入在 Dataset 包装器中的 tf.random_uniform,因此 img、label 和 z 可以作为一个批次一起提取,而无需使用feed_dict
-
通过数据集抽取 z 只会导致不必要的数据从 cpu 传输到 gpu。您可以让数据集在数据集之外处理 img、label 和 set z = tf.random...。您还可以使用
tf.cond进行条件操作 - 类似于gen_optim = tf.cond(disc_loss < 0.2, lambda: gen_optim, tf.no_op)。cond可能存在一些问题,期望 fn 输出具有相同的形状/类型,但这可以使用tf.control_dependencies或类似方法解决
标签: tensorflow tensorflow-datasets