【问题标题】:How to use TensorFlow Dataset API in GANs training?如何在 GAN 训练中使用 TensorFlow Dataset API?
【发布时间】:2019-01-28 23:29:20
【问题描述】:

我正在训练 GAN 模型。为了加载数据集,我使用的是 TensorFlow 的 Dataset API。

# train_dataset has image and label. z_train dataset has noise (z).
train_dataset = tf.data.TFRecordDataset(train_file)
z_train = tf.data.Dataset.from_tensor_slices(tf.random_uniform([total_training_samples, seq_length,  z_dim],
                                                                 minval=0, maxval=1, dtype=tf.float32))

train_dataset = tf.data.Dataset.zip((train_dataset, z_train))

创建迭代器:

iter = tf.data.Iterator.from_structure(train_dataset.output_types, train_dataset.output_shapes)

使用迭代器:

(img, label), z = iter.get_next()
train_init_op = iter.make_initializer(train_dataset)

在会话中训练 GAN 时:
首先训练判别器:

_, disc_loss = sess.run([disc_optim, disc_loss])

然后训练生成器:

_, gen_loss = sess.run([gen_optim, gen_loss])

这就是问题所在。因为,我在鉴别器和生成器图中都使用 label 作为条件 (CGAN),因此在同一运行期间使用两个 sess.run 会产生两组不同的 label批次。

for epoch in range(num_of_epochs):
    sess.run([tf.global_variables_initializer(), train_init_op.initializer])
    for batch in range(num_of_batches):
        _, disc_loss = sess.run([disc_optim, disc_loss])
        _, gen_loss = sess.run([gen_optim, gen_loss])

既然,我必须在生成器的会话运行中提供与鉴别器的会话运行相同的批次 label,我应该如何防止 Dataset API 在批次的同一个循环中产生两个不同的批次?
注意:我使用的是 TensorFlow v1.9
提前致谢。

【问题讨论】:

  • ...你能在同一个会话调用中运行所有操作吗? sess.run([disc_optim, disc_loss, get_optim, gen_loss])?不确定我是否完全理解这个问题,但也许可以在GANEstimator 中查看他们是如何做到的。另外,您的z_train 是数据集的一部分,而不仅仅是tf.random_uniform 的输出,有什么原因吗?
  • 在训练生成器之前,我正在检查判别器的损失是否大于 0.2,然后只训练判别器。为此,我需要运行两个不同的会话调用。我也会检查 GANEstimator。谢谢你的建议。 z_train 数据集实际上是嵌入在 Dataset 包装器中的 tf.random_uniform,因此 imglabelz 可以作为一个批次一起提取,而无需使用feed_dict
  • 通过数据集抽取 z 只会导致不必要的数据从 cpu 传输到 gpu。您可以让数据集在数据集之外处理 img、label 和 set z = tf.random...。您还可以使用tf.cond 进行条件操作 - 类似于gen_optim = tf.cond(disc_loss < 0.2, lambda: gen_optim, tf.no_op)cond 可能存在一些问题,期望 fn 输出具有相同的形状/类型,但这可以使用 tf.control_dependencies 或类似方法解决

标签: tensorflow tensorflow-datasets


【解决方案1】:

您可以为同一个数据集创建 2 个迭代器。如果您需要打乱数据集,您甚至可以通过将种子指定为张量来实现。请参见下面的示例。

import tensorflow as tf

seed_ts = tf.placeholder(tf.int64)
ds = tf.data.Dataset.from_tensor_slices([1,2,3,4,5]).shuffle(5, seed=seed_ts, reshuffle_each_iteration=True)
it1 = ds.make_initializable_iterator()
it2 = ds.make_initializable_iterator()

input1 = it1.get_next()
input2 = it2.get_next()

with tf.Session() as sess:
    for ep in range(10):
        sess.run(it1.initializer, feed_dict={seed_ts: ep})
        sess.run(it2.initializer, feed_dict={seed_ts: ep})

        print("Epoch" + str(ep))
        for i in range(5):
            x = sess.run(input1)
            y = sess.run(input2)
            print([x, y])

【讨论】:

    猜你喜欢
    • 2019-10-08
    • 1970-01-01
    • 2018-03-03
    • 2019-01-01
    • 2020-03-06
    • 1970-01-01
    • 1970-01-01
    • 1970-01-01
    • 2021-11-20
    相关资源
    最近更新 更多