【发布时间】:2018-08-17 14:02:33
【问题描述】:
跟随本教程:https://www.tensorflow.org/versions/r1.3/get_started/mnist/pros
我想自己解决一个带有标签图像的分类问题。由于我没有使用 MNIST 数据库,因此我花了几天时间在 tensorflow 中创建自己的数据集。它看起来像这样:
#variables
batch_size = 50
dimension = 784
stages = 10
#step 1 read Dataset
filenames = tf.constant(filenamesList)
labels = tf.constant(labelsList)
#step 2 create Dataset
dataset = tf.data.Dataset.from_tensor_slices((filenames, labels))
#step 3: parse every image in the dataset using `map`
def _parse_function(filename, label):
#convert label to one-hot encoding
one_hot = tf.one_hot(label, stages)
#read image file
image_string = tf.read_file(filename)
image_decoded = tf.image.decode_image(image_string, channels=3)
image = tf.cast(image_decoded, tf.float32)
return image, one_hot
#step 4 final input tensor
dataset = dataset.map(_parse_function)
dataset = dataset.batch(batch_size) #batch_size = 100
iterator = dataset.make_one_shot_iterator()
images, labels = iterator.get_next()
images = tf.reshape(images, [batch_size,dimension]).eval()
labels = tf.reshape(labels, [batch_size,stages]).eval()
for _ in range(10):
dataset = dataset.shuffle(buffer_size = 100)
dataset = dataset.batch(batch_size)
iterator = dataset.make_one_shot_iterator()
images, labels = iterator.get_next()
images = tf.reshape(images, [batch_size,dimension]).eval()
labels = tf.reshape(labels, [batch_size,stages]).eval()
train_step.run(feed_dict={x: images, y_:labels})
不知何故,使用 更高的 batch_size 会破坏 python。我正在尝试做的是在每次迭代中用新批次训练我的神经网络。这就是为什么我也使用 dataset.shuffle(...)。使用 dataset.shuffle 也会破坏我的 Python。
我想做的(因为随机播放中断)是批处理整个数据集。通过评估 ('.eval()') 我将得到一个 numpy 数组。然后我将使用 numpy.random.shuffle(images) 对数组进行洗牌,然后挑选一些第一个元素来训练它。
例如
for _ in range(1000):
images = tf.reshape(images, [batch_size,dimension]).eval()
labels = tf.reshape(labels, [batch_size,stages]).eval()
#shuffle
np.random.shuffle(images)
np.random.shuffle(labels)
train_step.run(feed_dict={x: images[0:train_size], y_:labels[0:train_size]})
但随之而来的问题是我无法对整个数据集进行批处理。看起来数据太大了,python 无法使用。 我应该如何以不同的方式解决这个问题?
由于我没有使用 MNIST 数据库,因此没有像 mnist.train.next_batch(100) 这样对我很方便的函数。
【问题讨论】:
标签: python tensorflow dataset mnist tensorflow-datasets