[Question title]: Modifying label_image.py in TensorFlow tutorial to classify multiple images
[Posted]: 2018-03-10 18:43:04
[Question]:

I have retrained an InceptionV3 model on my own data and am trying to adapt the code from the TensorFlow image classification tutorial here: https://www.tensorflow.org/tutorials/image_recognition

I tried reading the directory into a list and looping over it, but this doesn't work:

load_graph(FLAGS.graph)

filelist = os.listdir(FLAGS.image)

for i in filelist:
  # load image
  image_data = load_image(i)
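One more thing to check in this loop: os.listdir returns bare file names such as cat.jpg, not paths, so load_image(i) looks for each file in the current working directory rather than inside the --image folder. A minimal sketch of building full paths (list_image_paths is a hypothetical helper name, not part of the tutorial):

```python
import os


def list_image_paths(directory):
    """Return the full paths of the regular files in directory, sorted by name."""
    paths = []
    for name in sorted(os.listdir(directory)):
        # os.listdir yields names only, e.g. 'cat.jpg', so prefix the folder.
        full_path = os.path.join(directory, name)
        if os.path.isfile(full_path):  # skip subfolders
            paths.append(full_path)
    return paths
```

Inside main(), the loop would then iterate over list_image_paths(FLAGS.image) and pass each full path to load_image.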

I just get an error saying FLAGS is not defined, so I'm guessing FLAGS has to be used together with the load_image function? Here is the original program:
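A likely explanation, assuming the loop was placed at module level above the if __name__ == '__main__': block: in the tutorial script FLAGS is only assigned inside that block (FLAGS, unparsed = parser.parse_known_args()), and Python executes module-level statements top to bottom, so any statement that reads FLAGS before that assignment raises NameError. Code inside main() works because tf.app.run calls main() only after the assignment. A minimal sketch that passes the parsed arguments explicitly instead of relying on a global (build_parser and run are hypothetical helpers, and this main takes the parsed namespace rather than argv as the tutorial's does):

```python
import argparse


def build_parser():
    """A trimmed-down version of the script's command-line options."""
    parser = argparse.ArgumentParser()
    parser.add_argument('--image', required=True, type=str,
                        help='Directory containing the images to classify.')
    parser.add_argument('--num_top_predictions', type=int, default=5)
    return parser


def main(flags):
    # Everything that reads the options lives in a function that is only
    # called after parsing, so the parsed values always exist when it runs.
    return 'classifying images in %s (top %d)' % (flags.image,
                                                   flags.num_top_predictions)


def run(argv=None):
    # argv=None makes argparse read sys.argv[1:]; pass a list to try it out.
    # Unknown options are collected separately instead of causing an error.
    flags, _ = build_parser().parse_known_args(argv)
    return main(flags)
```

In a script, the entry point would be print(run()) under if __name__ == '__main__':.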

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import argparse
import sys
import os
import tensorflow as tf

parser = argparse.ArgumentParser()
parser.add_argument(
    '--image', required=True, type=str, help='Absolute path to image file.')
parser.add_argument(
    '--num_top_predictions',
    type=int,
    default=5,
    help='Display this many predictions.')
parser.add_argument(
    '--graph',
    required=True,
    type=str,
    help='Absolute path to graph file (.pb)')
parser.add_argument(
    '--labels',
    required=True,
    type=str,
    help='Absolute path to labels file (.txt)')
parser.add_argument(
    '--output_layer',
    type=str,
    default='final_result:0',
    help='Name of the result operation')
parser.add_argument(
    '--input_layer',
    type=str,
    default='DecodeJpeg/contents:0',
    help='Name of the input operation')


def load_image(filename):
  """Read in the image_data to be classified."""
  return tf.gfile.FastGFile(filename, 'rb').read()


def load_labels(filename):
  """Read in labels, one label per line."""
  return [line.rstrip() for line in tf.gfile.GFile(filename)]


def load_graph(filename):
  """Unpersists graph from file as default graph."""
  with tf.gfile.FastGFile(filename, 'rb') as f:
    graph_def = tf.GraphDef()
    graph_def.ParseFromString(f.read())
    tf.import_graph_def(graph_def, name='')


def run_graph(image_data, labels, input_layer_name, output_layer_name,
              num_top_predictions):
  with tf.Session() as sess:
    # Feed the image_data as input to the graph.
    #   predictions will contain a two-dimensional array, where one
    #   dimension represents the input image count, and the other has
    #   predictions per class
    softmax_tensor = sess.graph.get_tensor_by_name(output_layer_name)
    predictions, = sess.run(softmax_tensor, {input_layer_name: image_data})

    # Sort to show labels in order of confidence
    top_k = predictions.argsort()[-num_top_predictions:][::-1]
    for node_id in top_k:
      human_string = labels[node_id]
      score = predictions[node_id]
      print('%s (score = %.5f)' % (human_string, score))

    return 0


def main(argv):
  """Runs inference on an image."""
  if argv[1:]:
    raise ValueError('Unused Command Line Args: %s' % argv[1:])

  if not tf.gfile.Exists(FLAGS.image):
    tf.logging.fatal('image file does not exist %s', FLAGS.image)

  if not tf.gfile.Exists(FLAGS.labels):
    tf.logging.fatal('labels file does not exist %s', FLAGS.labels)

  if not tf.gfile.Exists(FLAGS.graph):
    tf.logging.fatal('graph file does not exist %s', FLAGS.graph)


  # load image
  image_data = load_image(FLAGS.image)

  # load labels
  labels = load_labels(FLAGS.labels)

  # load graph, which is stored in the default session
  load_graph(FLAGS.graph)

  run_graph(image_data, labels, FLAGS.input_layer, FLAGS.output_layer,
            FLAGS.num_top_predictions)


if __name__ == '__main__':
  FLAGS, unparsed = parser.parse_known_args()
  tf.app.run(main=main, argv=sys.argv[:1]+unparsed)
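The ranking step in run_graph, predictions.argsort()[-num_top_predictions:][::-1], orders the class indices by ascending score, keeps the last num_top_predictions of them (the highest scores) and reverses that slice so the best class comes first. A pure-Python sketch of the same idea, handy for checking the logic without TensorFlow or NumPy (top_k_indices is a hypothetical name, and ties may come out in a different order than with argsort):

```python
def top_k_indices(scores, k):
    """Return the indices of the k largest scores, highest score first."""
    if k <= 0:
        # Guard: the slice order[-0:] would return every index, not none.
        return []
    # Indices sorted by ascending score, like numpy's argsort().
    order = sorted(range(len(scores)), key=lambda i: scores[i])
    # Keep the k largest, then reverse so the best class comes first.
    return order[-k:][::-1]


labels = ["cat", "dog", "bird", "fish"]
scores = [0.10, 0.70, 0.05, 0.15]
for node_id in top_k_indices(scores, 2):
    print("%s (score = %.5f)" % (labels[node_id], scores[node_id]))
# prints "dog (score = 0.70000)" then "fish (score = 0.15000)"
```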

[Question discussion]:

    Tags: image loops tensorflow classification


    [Answer 1]:

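    Thanks for the help. FLAGS comes from the argparse module rather than the TensorFlow flags module, and it probably has to be accessed from inside a function. I ended up solving the problem by writing a separate function, so I think that's what was going on: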

    import glob

    def get_image_list(path):
        return glob.glob(os.path.join(path, '*.jpg'))
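Two caveats with a glob-based helper like this: the pattern only finds files inside the folder if there is a path separator between the folder and *.jpg (with plain string concatenation, /Users/photos becomes /Users/photos*.jpg), and glob's wildcard matching is case-sensitive on POSIX systems such as Linux, so *.jpg skips a file named IMG_0001.JPG. A sketch that handles both, assuming only JPEG files are wanted since the script's default input layer, DecodeJpeg/contents:0, decodes JPEG data (the JPEG_PATTERNS tuple is an illustrative choice):

```python
import glob
import os

# JPEG only: the default input layer DecodeJpeg/contents:0 expects JPEG bytes.
# Mixed-case spellings such as '.Jpg' are not covered by these patterns.
JPEG_PATTERNS = ('*.jpg', '*.jpeg', '*.JPG', '*.JPEG')


def get_image_list(path):
    """Return sorted, de-duplicated paths of JPEG files directly inside path."""
    found = set()
    for pattern in JPEG_PATTERNS:
        # os.path.join adds the separator, so 'photos' and 'photos/' both work.
        # The set removes duplicates where matching is case-insensitive (Windows).
        found.update(glob.glob(os.path.join(path, pattern)))
    return sorted(found)
```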
    

    Then, further down, call it in a loop:

    filelist = get_image_list(FLAGS.image)

    for i in filelist:
        image_data = load_image(i)
        run_graph(image_data, labels, FLAGS.input_layer, FLAGS.output_layer,
                  FLAGS.num_top_predictions)
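This loop relies on labels already being loaded and the graph already imported, once, before the loop, as the original main() does. Since run_graph prints its results, one possible variation is to have it return them and collect one entry per file. A minimal sketch of that loop structure with the TensorFlow parts passed in as functions; classify stands for a hypothetical run_graph variant that returns its predictions, and read_bytes stands for load_image:

```python
def classify_files(paths, read_bytes, classify):
    """Classify each file in paths and return a dict {path: result}.

    read_bytes plays the role of load_image(); classify is a hypothetical
    wrapper around run_graph() that returns predictions instead of printing.
    """
    results = {}
    for path in paths:
        image_data = read_bytes(path)          # raw bytes, as load_image() returns
        results[path] = classify(image_data)   # e.g. a list of (label, score) pairs
    return results


# Trying the loop with stand-in functions, no TensorFlow needed:
fake_files = {"a.jpg": b"\xff\xd8a", "b.jpg": b"\xff\xd8bb"}
summary = classify_files(
    sorted(fake_files),
    read_bytes=fake_files.__getitem__,
    classify=lambda data: [("length", float(len(data)))],
)
print(summary)  # {'a.jpg': [('length', 3.0)], 'b.jpg': [('length', 4.0)]}
```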
    

    [Discussion]:

      [Answer 2]:

      Try the following:

      import os
      import tensorflow as tf

      # Define this after your imports. This is similar to Python's argparse, except more verbose
      FLAGS = tf.app.flags.FLAGS

      tf.app.flags.DEFINE_string('image', '/Users/photos',
                                 """
                                 Define your 'image' folder here
                                 or as an argument to your script,
                                 e.g. test.py --image /Users/..
                                 """)

      # use listdir to list the images in the target folder
      filelist = os.listdir(FLAGS.image)

      # now iterate over the objects in the list
      for i in filelist:
          # listdir returns bare file names, so join each one with the folder;
          # load_image() is the function from the original script
          image_data = load_image(os.path.join(FLAGS.image, i))
      

      This should work. Hope it helps.

      [Discussion]:

        [Answer 3]:

        Try tf.flags.FLAGS, or put FLAGS = tf.flags.FLAGS at the top of the script.

        [Discussion]:
