【问题标题】:Batch prediction using a trained Object Detection APIs model and TF 2使用经过训练的对象检测 API 模型和 TF 2 进行批量预测
【发布时间】:2020-09-02 09:40:50
【问题描述】:

我在 TPU 上使用 TF 2 的对象检测 API 成功训练了一个模型,该模型保存为 .pb(SavedModel 格式)。然后我使用tf.saved_model.load 将其加载回来,当使用转换为形状为(1, w, h, 3) 的张量的单个图像预测框时,它可以正常工作。

import tensorflow as tf
import numpy as np

# Load Object Detection APIs model
detect_fn = tf.saved_model.load('/path/to/saved_model/')

image = tf.io.read_file(image_path)
image_np = tf.image.decode_jpeg(image, channels=3).numpy()
input_tensor = np.expand_dims(image_np, 0)
detections = detect_fn(input_tensor) # This works fine

问题是我需要进行批量预测以将其扩展到 50 万张图像,但该模型的输入签名似乎仅限于处理形状为 (1, w, h, 3) 的数据。 这也意味着我不能在 Tensorflow Serving 中使用批处理。 我怎么解决这个问题?我可以只更改模型签名来处理批量数据吗?

所有工作(加载模型 + 预测)均在使用 Object Detection API 发布的官方容器内执行(来自here

【问题讨论】:

    标签: tensorflow object batch-processing prediction detection


    【解决方案1】:

    我最近遇到了这个问题。当您使用exporter_main_v2.py 将检查点文件转换为.pb 文件时,它将调用exporter_lib_v2.py。我发现在文件exporter_lib_v2.py (here) 中,TF2 硬固定了形状为[1, None, None, 3] 的输入签名。我们得把它改成[None, None, None, 3]

    需要将该文件中的这些行(138162170185)从 1 修改为 None。然后重建 TF2 Object Detector API Repo (link) 并使用新构建的版本再次导出.pb

    【讨论】:

      猜你喜欢
      • 2020-07-04
      • 1970-01-01
      • 1970-01-01
      • 2018-06-11
      • 1970-01-01
      • 1970-01-01
      • 1970-01-01
      • 2022-08-19
      • 1970-01-01
      相关资源
      最近更新 更多