【问题标题】:Issues at using the Tensorflow Datasets API with Keras将 Tensorflow 数据集 API 与 Keras 一起使用时的问题
【发布时间】:2019-10-01 16:38:09
【问题描述】:

我正在尝试拟合 CNN Keras 模型,并为其提供由 Tensorflow 的 Datasets API 处理的数据。然而,尽管遵循官方文档(参见there),我还是一次又一次地偶然发现了同一个异常:

ValueError: No data provided for "conv2d_8_input". Need data for each key in: ['conv2d_8_input']
# conv2d_8 is the first Conv2D layer of my model, see below

我正在使用来自 tensorflow-datasets 的 MNIST 数据集,图像被标准化,并且类标签被转换为 one-hot 编码。您可以看到以下代码的摘录。

test_data, train_data = tfds.load("mnist", split=Split.ALL.subsplit([1, 3]))

# [...] Images are normalized using Dataset.map method
# [...] Labels are converted into one-hot encodings as well, using tf.one_hot function

model = keras.Sequential([
    keras.layers.Conv2D(
        32,
        kernel_size=5,
        padding="same",
        input_shape=(28, 28, 1),
        activation="relu",
    ),
    keras.layers.MaxPooling2D(
        (2, 2),
        padding="same"
    ),
    keras.layers.Conv2D(
        64,
        kernel_size=5,
        padding="same",
        activation="relu"
    ),
    keras.layers.MaxPooling2D(
        (2, 2),
        padding="same"
    ),
    keras.layers.Flatten(),
    keras.layers.Dense(
        512,
        activation="relu"
    ),
    keras.layers.Dropout(rate=0.4),
    keras.layers.Dense(10, activation="softmax")
])

model.compile(
    optimizer=tf.train.AdamOptimizer(0.01),
    loss="categorical_crossentropy",
    metrics=["accuracy"]
)

train_data = train_data.batch(32).repeat()
test_data = test_data.batch(32).repeat()

model.fit(
    train_data,
    epochs=10,
    steps_per_epoch=30,
    validation_data=test_data,
    validation_steps=3
) # The exception occurs at this step

我不明白为什么它不起作用,我尝试使用一次性迭代器而不是数据集来提供 fit 方法,但我得到了相同的结果。我不习惯 Keras 和 TensorFlow(我通常使用 PyTorch),所以我认为我可能遗漏了一些明显的东西。

【问题讨论】:

  • MNIST 数据集的形状与 28,28,1 不同。你在第一个 conv2d 之前的某个地方有重塑吗?例如,请参考tensorflow.org/tutorials/estimators/cnn
  • @Prabindh 好吧,看起来形状确实是 (28,28,1)。在model.fit(...) 之前的print(train_data) 会产生以下输出:<DatasetV1Adapter shapes: {image: (?, 28, 28, 1), label: (?, 10)}, types: {image: tf.float64, label: tf.uint8}>。 (第一个未知维度是批次维度。)
  • 我会假设数据集已损坏。你测试过它是否非空/不只是返回None吗?
  • @xdurch0 我能够从数据集中检索有效数据(使用一次性迭代器)并在 pyplot 中显示相应的图像和标签。

标签: python tensorflow machine-learning keras tensorflow-datasets


【解决方案1】:

对于那些在学习了关于加载图像的 TF 2.0 Beta 教程 (https://www.tensorflow.org/beta/tutorials/load_data/images) 后来到此页面的用户:

我能够通过在 preprocess_image 函数中返回一个元组来避免错误

def preprocess_image(image):
image = tf.image.decode_jpeg(image, channels=3)
image = tf.image.resize(image, [192, 192])
image /= 255.0  # normalize to [0,1] range
return (image,image)

我没有在我的用例中使用标签,因此您可能需要进行其他更改才能按照教程进行操作

【讨论】:

    【解决方案2】:

    您可以使用as_supervised直接从tensorflow-datasets加载数据作为元组

    test_data, train_data = tfds.load("mnist", split=tfds.Split.ALL.subsplit([1, 3]), as_supervised=True)
    

    【讨论】:

      【解决方案3】:

      好的,我明白了。我启用了 Eager Execution 以查看 Keras 是否会产生更精确的异常,结果如下:

      ValueError: Output of generator should be a tuple `(x, y, sample_weight)` or `(x, y)`. Found: {'image': <tf.Tensor: id=1012, shape=(32, 28, 28, 1), dtype=float64, numpy=array([...])>, 'label': <tf.Tensor: id=1013, shape=(32, 10), dtype=uint8, numpy=array([...]), dtype=uint8)>}
      

      确实,我的数据集的组件(图像及其相关标签)具有名称(“图像”和“标签”),因为这是 tensorflow_datasets 加载它们的方式。结果,数据集上的迭代器会产生一个包含两个值的字典:“图像”和“标签”。

      但是,Keras expects a tuple 的两个值 (inputs, targets)(或三个值 (inputs, targets, sample_wheights)),它不喜欢 Dataset 迭代器产生的字典(因此我得到了错误)。

      我在model.fit之前添加了以下代码:

      train_data = train_data.map(lambda x: tuple(x.values()))
      test_data = test_data.map(lambda x: tuple(x.values()))
      

      它有效。

      【讨论】:

      • 老实说,我不喜欢这个解决方案,它看起来并不干净。但是,这是我发现的唯一可以缓解 Keras 限制的解决方法。
      • 如果您使用的是 TF 版本的 keras,理想情况下不需要这样做。见github.com/tensorflow/tensorflow/issues/20698
      • @Prabindh 我不确定这是否完全相关,因为我没有多输入模型?我是否应该在 github 问题上提及我的案例?
      猜你喜欢
      • 1970-01-01
      • 1970-01-01
      • 1970-01-01
      • 1970-01-01
      • 1970-01-01
      • 2017-10-14
      • 1970-01-01
      • 2018-05-01
      • 1970-01-01
      相关资源
      最近更新 更多