【发布时间】:2019-10-01 16:38:09
【问题描述】:
我正在尝试拟合 CNN Keras 模型,并为其提供由 Tensorflow 的 Datasets API 处理的数据。然而,尽管遵循官方文档(参见there),我还是一次又一次地偶然发现了同一个异常:
ValueError: No data provided for "conv2d_8_input". Need data for each key in: ['conv2d_8_input']
# conv2d_8 is the first Conv2D layer of my model, see below
我正在使用来自 tensorflow-datasets 的 MNIST 数据集,图像被标准化,并且类标签被转换为 one-hot 编码。您可以看到以下代码的摘录。
test_data, train_data = tfds.load("mnist", split=Split.ALL.subsplit([1, 3]))
# [...] Images are normalized using Dataset.map method
# [...] Labels are converted into one-hot encodings as well, using tf.one_hot function
model = keras.Sequential([
keras.layers.Conv2D(
32,
kernel_size=5,
padding="same",
input_shape=(28, 28, 1),
activation="relu",
),
keras.layers.MaxPooling2D(
(2, 2),
padding="same"
),
keras.layers.Conv2D(
64,
kernel_size=5,
padding="same",
activation="relu"
),
keras.layers.MaxPooling2D(
(2, 2),
padding="same"
),
keras.layers.Flatten(),
keras.layers.Dense(
512,
activation="relu"
),
keras.layers.Dropout(rate=0.4),
keras.layers.Dense(10, activation="softmax")
])
model.compile(
optimizer=tf.train.AdamOptimizer(0.01),
loss="categorical_crossentropy",
metrics=["accuracy"]
)
train_data = train_data.batch(32).repeat()
test_data = test_data.batch(32).repeat()
model.fit(
train_data,
epochs=10,
steps_per_epoch=30,
validation_data=test_data,
validation_steps=3
) # The exception occurs at this step
我不明白为什么它不起作用,我尝试使用一次性迭代器而不是数据集来提供 fit 方法,但我得到了相同的结果。我不习惯 Keras 和 TensorFlow(我通常使用 PyTorch),所以我认为我可能遗漏了一些明显的东西。
【问题讨论】:
-
MNIST 数据集的形状与 28,28,1 不同。你在第一个 conv2d 之前的某个地方有重塑吗?例如,请参考tensorflow.org/tutorials/estimators/cnn
-
@Prabindh 好吧,看起来形状确实是 (28,28,1)。在
model.fit(...)之前的print(train_data)会产生以下输出:<DatasetV1Adapter shapes: {image: (?, 28, 28, 1), label: (?, 10)}, types: {image: tf.float64, label: tf.uint8}>。 (第一个未知维度是批次维度。) -
我会假设数据集已损坏。你测试过它是否非空/不只是返回
None吗? -
@xdurch0 我能够从数据集中检索有效数据(使用一次性迭代器)并在 pyplot 中显示相应的图像和标签。
标签: python tensorflow machine-learning keras tensorflow-datasets