【问题标题】:Why iterations over the same tf.data.Dataset give different data each iteration?为什么对同一个 tf.data.Dataset 的迭代每次迭代都会给出不同的数据?
【发布时间】:2021-03-30 16:35:49
【问题描述】:

我正在尝试了解 tf.data.Dataset 的工作原理。

它在文档上说take 返回一个数据集,其中包含该数据集中的一定数量的元素。然后,您可以迭代单个样本(在本例中为批次):

import tensorflow.compat.v2 as tf
import tensorflow_datasets as tfds

# Construct a tf.data.Dataset
ds = tfds.load('mnist', split='train', shuffle_files=True)

# Build your input pipeline
ds = ds.shuffle(1024).batch(32).prefetch(tf.data.experimental.AUTOTUNE)

single_batch_dataset = ds.take(1)

for example in single_batch_dataset:
  image, label = example["image"], example["label"]
  print(label)
# ...

输出:

tf.Tensor([2 0 6 6 8 8 6 0 3 4 8 7 5 2 5 7 8 7 1 1 1 8 6 4 0 4 3 2 4 2 1 9], shape=(32,), dtype=int64)

但是,再次对其进行迭代,会给出不同的标签:(上一个代码的延续)

for example in single_batch_dataset:
  image, label = example["image"], example["label"]
  print(label)

for example in single_batch_dataset:
  image, label = example["image"], example["label"]
  print(label)

输出:

tf.Tensor([7 3 5 6 3 1 7 9 6 1 9 3 9 8 6 7 7 1 9 7 5 2 0 7 8 1 7 8 7 0 5 0], shape=(32,), dtype=int64)
tf.Tensor([1 3 6 1 8 8 0 4 1 3 2 9 5 3 8 7 4 2 1 8 1 0 8 5 4 5 6 7 3 4 4 1], shape=(32,), dtype=int64)

鉴于数据集相同,标签不应该相同吗?

【问题讨论】:

  • 当然不是,你是在打乱数据... edit: 两次
  • @NicolasGervais - 也许人们可以想象这样一种场景:改组定义一次并用于每次迭代。
  • @jakub 啊是的,从这个角度来看是有道理的

标签: python tensorflow tensorflow-datasets


【解决方案1】:

这是因为数据文件被打乱了,数据集被打乱了dataset.shuffle()

使用dataset.shuffle(),默认情况下,每次迭代都会以不同的方式打乱数据。

可以删除shuffle_files=True 并设置参数reshuffle_each_iteration=False 以防止在不同的迭代中重新洗牌。

.take() 函数并不意味着确定性。它只会按照数据集给出的顺序从数据集中取出 N 个项目。

# Construct a tf.data.Dataset
ds = tfds.load('mnist', split='train', shuffle_files=False)

# Build your input pipeline
ds = ds.shuffle(1024, reshuffle_each_iteration=False).batch(32).prefetch(tf.data.experimental.AUTOTUNE)

single_batch_dataset = ds.take(1)

for example in single_batch_dataset:
    image, label = example["image"], example["label"]
    print(label)
    
for example in single_batch_dataset:
    image, label = example["image"], example["label"]
    print(label)

输出:

tf.Tensor([4 6 8 5 1 4 5 8 1 4 6 6 8 6 6 9 4 2 3 0 5 9 2 1 3 1 8 6 4 4 7 1], shape=(32,), dtype=int64)
tf.Tensor([4 6 8 5 1 4 5 8 1 4 6 6 8 6 6 9 4 2 3 0 5 9 2 1 3 1 8 6 4 4 7 1], shape=(32,), dtype=int64)

【讨论】:

    猜你喜欢
    • 1970-01-01
    • 1970-01-01
    • 1970-01-01
    • 1970-01-01
    • 2019-03-17
    • 1970-01-01
    • 1970-01-01
    • 1970-01-01
    • 1970-01-01
    相关资源
    最近更新 更多