【发布时间】:2020-02-14 15:15:24
【问题描述】:
我对 TensorFlow 和 Keras 完全陌生,我正在尝试尝试一些我在网上找到的代码。
特别是我正在使用时尚MNIST - 由 60000 个示例和 10000 个示例的测试集组成。它们中的每一个都是 28x28 灰度图像。
我正在关注本教程“https://towardsdatascience.com/building-your-first-neural-network-in-tensorflow-2-tensorflow-for-hackers-part-i-e1e2f1dfe7a0”,直到定义
history = model.fit(
train_dataset.repeat(),
epochs=10,
steps_per_epoch=500,
validation_data=val_dataset.repeat(),
validation_steps=2)
只要我理解,我需要使用 train_dataset.repeat() 作为输入数据集,否则我将没有足够的训练示例使用这些值作为超参数(epochs、steps_per_epochs)。
我的问题是:如何避免必须使用 .repeat()? 我需要如何更改超参数?
为了简单起见,我在这里处理代码:
def preprocess(x,y):
x = tf.cast(x,tf.float32) / 255.0
y = tf.cast(y, tf.float32)
return x,y
def create_dataset(xs, ys, n_classes=10):
ys = tf.one_hot(ys, depth=n_classes)
return tf.data.Dataset.from_tensor_slices((xs, ys)).map(preprocess).shuffle(len(ys)).batch(128)
model.compile(optimizer = 'adam', loss =tf.losses.CategoricalCrossentropy(from_logits= True), metrics =['accuracy'])
history1 = model.fit(train_dataset.repeat(),
epochs=10,
steps_per_epoch=500,
validation_data=val_dataset.repeat(),
validation_steps=2)
谢谢!
【问题讨论】:
标签: python tensorflow keras neural-network