【发布时间】:2018-11-12 23:14:41
【问题描述】:
MNIST 的 TensorFlow 文档推荐了多种不同的方式来加载 MNIST 数据集:
- https://www.tensorflow.org/tutorials/layers
- https://www.tensorflow.org/versions/r1.2/get_started/mnist/beginners
- https://www.tensorflow.org/versions/r1.2/get_started/mnist/pros
文档中描述的所有方式都会在 TensorFlow 1.8 中引发许多已弃用的警告。
我目前加载 MNIST 和创建训练批次的方式:
class MNIST:
def __init__(self, optimizer):
...
self.mnist_dataset = input_data.read_data_sets("/tmp/data/", one_hot=True)
self.test_data = self.mnist_dataset.test.images.reshape((-1, self.timesteps, self.num_input))
self.test_label = self.mnist_dataset.test.labels
...
def train_run(self, sess):
batch_input, batch_output = self.mnist_dataset.train.next_batch(self.batch_size, shuffle=True)
batch_input = batch_input.reshape((self.batch_size, self.timesteps, self.num_input))
_, loss = sess.run(fetches=[self.train_step, self.loss], feed_dict={self.input_placeholder: batch_input, self.output_placeholder: batch_output})
...
def test_run(self, sess):
loss = sess.run(fetches=[self.loss], feed_dict={self.input_placeholder: self.test_data, self.output_placeholder: self.test_label})
...
我怎么能做完全一样的事情,只是用目前的方法?
我找不到这方面的任何文档。
在我看来,新方法是这样的:
train, test = tf.keras.datasets.mnist.load_data()
self.mnist_train_ds = tf.data.Dataset.from_tensor_slices(train)
self.mnist_test_ds = tf.data.Dataset.from_tensor_slices(test)
但是如何在我的train_run 和test_run 方法中使用这些数据集?
【问题讨论】:
标签: tensorflow mnist