【问题标题】:How to visualize detected boxes from TFLite model (How to get category index from TFLite model?)如何从 TFLite 模型中可视化检测到的框(如何从 TFLite 模型中获取类别索引?)
【发布时间】:2021-04-28 17:49:13
【问题描述】:

我有一个对象检测 TFLite 模型保存为 model.tflite 文件。我可以运行它

interpreter = tf.lite.Interpreter("model.tflite")

input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()
interpreter.set_tensor(input_details[0]['index'], input_image)

interpreter.invoke()

然后得到输出为

detection_boxes = interpreter.get_tensor(output_details[0]['index'])
detection_classes = interpreter.get_tensor(output_details[1]['index'])
detection_scores = interpreter.get_tensor(output_details[2]['index'])
num_boxes = interpreter.get_tensor(output_details[3]['index'])

我想在图片中绘制具有给定类别的检测框。最简单的解决方案似乎是使用工具viz_utils.visualize_boxes_and_labels_on_image_array as。

viz_utils.visualize_boxes_and_labels_on_image_array(
        image_np_with_detections,
        detection_boxes,
        detection_classes,
        detection_scores,
        category_index,
        use_normalized_coordinates=True,
        max_boxes_to_draw=20,
        min_score_thresh=.1,
        agnostic_mode=False

但是,为此需要category_index(将类索引转换为人类可读的标签)。通常,您可以从包含标签的文件中加载它,如果我没记错的话,如果是 .tflite 模型,应该将其包含/打包在 .tflite 文件中。

但是,我不知道该怎么做,或者我应该使用哪些函数(我还查看了 tflite_support 库,但不知道如何从关联文件中提取类别)。

使用 .tflite 文件可视化检测到的带有标签的框的正确方法是什么?它不必使用viz_utils。任何帮助表示赞赏。谢谢。

【问题讨论】:

  • @JaesungChung 在那个例子中,他们指的是 tensorflow,而不是 TFLite,但即便如此,我的问题是我没有 path_to_labels(来自那个例子)。我只有打包的 tflite 模型文件 - 我认为标签应该可以从该文件中提取。但是,也许这就是我错的地方。
  • for idx, box in enumerate(detection_boxes'][0]): if output_dict['detection_scores'][0][idx] > treshold: class_name = category_index[int(output_dict['detection_classes' ][0][idx])]['name']

标签: tensorflow tensorflow-lite


【解决方案1】:
# labels variable contains the list of the names of the category and
# it generates by reading the labels.txt
with open("labels.txt", "r") as f:
  txt = f.read()

labels = txt.splitlines()

for idx, box in enumerate(detection_boxes[0]):
    if detection_scores[0][idx] > threshold:
        class_name = labels[int(detection_classes[0][idx])]

我根据https://github.com/tensorflow/models/issues/7458#issuecomment-523904465创建了这段代码sn-p。

【讨论】:

  • 如果你只有 mode.tflite 和 dict.txt,你从哪里得到category_index
  • 你是说我在detection_classes中得到的标签是由labels.txt中的行号索引的吗?
  • 是的,detection_classes 存储检测到的类索引。
  • 我想你误解了我在问什么。但无论如何,你似乎是对的。标签存储在 dict.txt 中的顺序对应于从模型返回的类 ID。很有意思。我在任何文档中都找不到它。
猜你喜欢
  • 1970-01-01
  • 1970-01-01
  • 1970-01-01
  • 2021-07-08
  • 2022-01-05
  • 1970-01-01
  • 1970-01-01
  • 1970-01-01
  • 2022-08-08
相关资源
最近更新 更多