【问题标题】:How to iterate over multiple datasets in TensorFlow 2如何在 TensorFlow 2 中迭代多个数据集
【发布时间】:2020-10-03 01:29:31
【问题描述】:

我使用 TensorFlow 2.2.0。在我的数据管道中,我使用多个数据集来训练神经网络。比如:

# these are all tf.data.Dataset objects:
paired_data = get_dataset(id=0, repeat=False, shuffle=True)
unpaired_images = get_dataset(id=1, repeat=True, shuffle=True)
unpaired_masks = get_dataset(id=2, repeat=True, shuffle=True)

在训练循环中,我想遍历paired_data 来定义一个时期。但我也想迭代 unpaired_imagesunpaired_masks 以优化其他目标(用于语义分割的经典半监督学习,带有掩码鉴别器)。

为了做到这一点,我当前的代码如下:

def train_one_epoch(self, writer, step, paired_data, unpaired_images, unpaired_masks):

    unpaired_images = unpaired_images.as_numpy_iterator()
    unpaired_masks = unpaired_masks.as_numpy_iterator()

    for images, labels in paired_data:

        with tf.GradientTape() as sup_tape, \
                tf.GradientTape() as gen_tape, \
                tf.GradientTape() as disc_tape:

            # paired data (supervised cost):
            predictions = segmentor(images, training=True)
            sup_loss = weighted_cross_entropy(predictions, labels)

            # unpaired data (adversarial cost):
            pred_real = discriminator(next(unpaired_masks), training=True)
            pred_fake = discriminator(segmentor(next(unpaired_images), training=True), training=True)
            gen_loss = generator_loss(pred_fake)
            disc_loss = discriminator_loss(pred_real, pred_fake)

        gradients = sup_tape.gradient(sup_loss, self.segmentor.trainable_variables)
        generator_optimizer.apply_gradients(zip(gradients, self.segmentor.trainable_variables))

        gradients = gen_tape.gradient(gen_loss, self.segmentor.trainable_variables)
        generator_optimizer.apply_gradients(zip(gradients, self.segmentor.trainable_variables))

        gradients = disc_tape.gradient(disc_loss, self.discriminator.trainable_variables)
        discriminator_optimizer.apply_gradients(zip(gradients, self.discriminator.trainable_variables))

但是,这会导致错误:

main.py:275 train_one_epoch  *
        unpaired_images = unpaired_images.as_numpy_iterator()
    /home/venvs/conda/miniconda3/envs/tf-gpu/lib/python3.8/site-packages/tensorflow/python/data/ops/dataset_ops.py:476 as_numpy_iterator  **
        raise RuntimeError("as_numpy_iterator() is not supported while tracing "

    RuntimeError: as_numpy_iterator() is not supported while tracing functions

知道这有什么问题吗?这是在 tensorflow 2 中优化多个损失/数据集的正确方法吗?


我将我当前的解决方案添加到 cmets 中的问题。任何关于更优化方式的建议都非常受欢迎! :)

【问题讨论】:

  • 我注意到 pred_fake = discriminator(segmentor(next unpaired_images), training=True), training=True) 有语法错误。在调用 next 时,您忘记在 unpaired_images 周围加上括号。不知道这是否是问题。
  • 嗨,Richard X,感谢您注意到这一点。不幸的是,这只是复制粘贴时的一个错字......我要修复它

标签: python tensorflow iterator dataset


【解决方案1】:

我目前的解决方案:

def train_one_epoch(self, writer, step, paired_data, unpaired_images, unpaired_masks):

    # create a new dataset zipping the three original dataset objects
    dataset = tf.data.Dataset.zip((paired_data, unpaired_images, unpaired_masks))
    dataset = dataset.batch(1)

    for (images, labels), unpaired_images, unpaired_masks in dataset:

        # access the elements as the first and only element of the batched dataset
        images, labels, unpaired_images, unpaired_masks = \
            images[0], labels[0], unpaired_images[0], unpaired_masks[0]

        # go ahead and train:
        with tf.GradientTape() as tape:
            #[...]

【讨论】:

    猜你喜欢
    • 2018-04-14
    • 1970-01-01
    • 1970-01-01
    • 1970-01-01
    • 2020-06-27
    • 2018-01-26
    • 2019-05-12
    • 1970-01-01
    • 2019-04-03
    相关资源
    最近更新 更多