【发布时间】:2017-08-14 07:08:04
【问题描述】:
我有一堆图像组织在子目录中,它们对应于图像的类标签。例如
- images/1/0000001.jpg, images/1/0000002.jpg, ... 用于 1 类图像
- images/2/0123456.jpg, images/2/0123457.jpg, ... 用于第 2 类图像
现在,我想知道在使用 tf.WholeFileReader() 时如何获得整数类标签,该方法具有将文件名作为张量生成的 read 方法。在图表之外,我可以简单地执行int('images/2/0123457.jpg'.split('/')[1]) 来获取整数标签,但是如何在图表内部执行此操作以便可以使用标签进行模型训练?下面是一个幼稚的例子,我基本上是在下面这个例子中寻找class_label = ... # get class label from file_name的解决方案:
import tensorflow as tf
g = tf.Graph()
with g.as_default():
filename_queue = tf.train.string_input_producer(
tf.train.match_filenames_once('images/*/*.jpg'))
image_reader = tf.WholeFileReader()
file_name, image_raw = image_reader.read(filename_queue)
file_name = tf.identity(file_name, name='file_name')
image = tf.image.decode_jpeg(image_raw, name='image')
class_label = ... # get class label from file_name
with tf.Session(graph=g) as sess:
sess.run(tf.local_variables_initializer())
coord = tf.train.Coordinator()
threads = tf.train.start_queue_runners(coord=coord)
image_tensor = sess.run('image:0')
print('Image shape:', image_tensor.shape)
file_name = sess.run('file_name:0')
print('File name:', file_name)
coord.request_stop()
coord.join(threads)
【问题讨论】:
标签: python tensorflow