【发布时间】:2018-11-08 10:54:44
【问题描述】:
如果标题不能完全反映我的问题(我认为确实如此,但我不确定),我很抱歉,我将在下面描述。
我正在努力将 Yolo 对象检测模型转换为 TensorFlow 冻结模型 .pb,然后使用该模型在手机上进行预测。
我已经成功获得了一个有效的.pb 模型(即来自 Yolo 图表的冻结图)。但是由于网络的输出(其中有两个)不是边界框,所以我必须编写一个转换函数(这部分不是我的问题,我已经有了这个任务的工作函数):
def get_boxes_from_output(outputs_of_the_graph, anchors,
num_classes, input_image_shape,
score_threshold=score, iou_threshold=iou)
"""
Apply some operations on the outputs_of_the_graph to obtain bounding boxes information
"""
return boxes, scores, classes
所以管道很简单:我必须加载pb模型,然后将图像数据扔给它以获得两个输出,然后从这两个输出中,我应用上面的函数(包含张量操作)来获得边界框信息。代码如下所示:
model_path = 'model_data/yolo.pb'
class_names = _get_class('model_data/classes.txt')
anchors = _get_anchors('model_data/yolo_anchors.txt')
score = 0.25
iou = 0.5
# Load the Tensorflow model into memory.
detection_graph = tf.Graph()
with detection_graph.as_default():
graph_def = tf.GraphDef()
with tf.gfile.GFile(model_path, 'rb') as fid:
graph_def.ParseFromString(fid.read())
tf.import_graph_def(graph_def, name='')
# Get the input and output nodes (there are two outputs)
l_input = detection_graph.get_tensor_by_name('input_1:0')
l_output = [detection_graph.get_tensor_by_name('conv2d_10/BiasAdd:0'),
detection_graph.get_tensor_by_name('conv2d_13/BiasAdd:0')]
#initialize_all_variables
tf.global_variables_initializer()
# Generate output tensor targets for filtered bounding boxes.
input_image_shape = tf.placeholder(dtype=tf.float32,shape=(2, ))
training = tf.placeholder(tf.bool, name='training')
boxes, scores, classes = get_boxes_from_output(l_output, anchors,
len(class_names), input_image_shape,
score_threshold=score, iou_threshold=iou)
image = Image.open('./data/image1.jpg')
image = preprocess_image(image)
image_data = np.array(image, dtype='float32')
image_data = np.expand_dims(image_data, 0) # Add batch dimension.
sess = tf.Session(graph=detection_graph)
# Run the session to get the output bounding boxes
out_boxes, out_scores, out_classes = sess.run(
[boxes, scores, classes],
feed_dict={
l_input: image_data,
input_image_shape: [image.size[1], image.size[0]],
training: False
})
# Now how do I save a new model that outputs directly [boxes, scores, classes]
现在我的问题是如何从会话中保存一个新的.pb 模型,以便我可以在其他地方再次加载它并直接输出boxes, scores, classes?
我希望问题足够清楚。
非常感谢您的帮助!
【问题讨论】:
-
请提供您的
get_boxes_from_output功能。我正在尝试做类似的事情,但在理解模型的输出时遇到了问题。
标签: tensorflow