【问题标题】:How can a class label be extracted from the filename returned by tf.WholeFileReader.read() in TensorFlow?如何从 TensorFlow 中 tf.WholeFileReader.read() 返回的文件名中提取类标签?
【发布时间】:2017-08-14 07:08:04
【问题描述】:

我有一堆图像组织在子目录中,它们对应于图像的类标签。例如

  • images/1/0000001.jpg, images/1/0000002.jpg, ... 用于 1 类图像
  • images/2/0123456.jpg, images/2/0123457.jpg, ... 用于第 2 类图像

现在,我想知道在使用 tf.WholeFileReader() 时如何获得整数类标签,该方法具有将文件名作为张量生成的 read 方法。在图表之外,我可以简单地执行int('images/2/0123457.jpg'.split('/')[1]) 来获取整数标签,但是如何在图表内部执行此操作以便可以使用标签进行模型训练?下面是一个幼稚的例子,我基本上是在下面这个例子中寻找class_label = ... # get class label from file_name的解决方案:

import tensorflow as tf


g = tf.Graph()
with g.as_default():

    filename_queue = tf.train.string_input_producer(
        tf.train.match_filenames_once('images/*/*.jpg'))

    image_reader = tf.WholeFileReader()

    file_name, image_raw = image_reader.read(filename_queue)
    file_name = tf.identity(file_name, name='file_name')

    image = tf.image.decode_jpeg(image_raw, name='image')
    class_label = ... # get class label from file_name



with tf.Session(graph=g) as sess:

    sess.run(tf.local_variables_initializer())

    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(coord=coord)

    image_tensor = sess.run('image:0')
    print('Image shape:', image_tensor.shape)

    file_name = sess.run('file_name:0')
    print('File name:', file_name)    

    coord.request_stop()
    coord.join(threads)

【问题讨论】:

    标签: python tensorflow


    【解决方案1】:

    刚刚使用tf.split_stringtf.string_to_number 找到了我的问题的解决方案:

    class_label = tf.string_split([file_name], '/').values[1]
    class_label = tf.string_to_number(class_label, tf.int32)
    

    【讨论】:

      猜你喜欢
      • 1970-01-01
      • 2019-07-10
      • 1970-01-01
      • 1970-01-01
      • 2017-08-19
      • 1970-01-01
      • 1970-01-01
      • 1970-01-01
      • 2019-10-07
      相关资源
      最近更新 更多