【发布时间】:2018-10-23 00:10:09
【问题描述】:
当我学习一个tensorflow项目时,找一行代码:
cls_prob, box_pred = sess.run([output_cls_prob, output_box_pred], feed_dict={input_img: blob})
但是,这行代码花了很多时间。 (使用CPU需要15秒...┭┮﹏┭┮)
通过查阅资料,我发现使用函数'dataset'可以解决这个花费了很多时间的问题,我应该如何使用它?
“blob”的来源:
img = cv2.imread('./imgs/001.jpg')
img_scale = float(600) / min(img_data.shape[0], img_data.shape[1])
if np.round(img_scale * max(img_data.shape[0], img_data.shape[1])) > 1200:
img_scale = float(1200) / max(img_data.shape[0], img_data.shape[1])
img_data = cv2.resize(img_data, None, None, fx=img_scale, fy=img_scale, interpolation=cv2.INTER_LINEAR)
img_orig = img_data.astype(np.float32, copy=True)
blob = np.zeros((1, img_data.shape[0], img_data.shape[1], 3),dtype=np.float32)
blob[0, 0:img_data.shape[0], 0:img_data.shape[1], :] = img_orig
'output_cls_prob'&'output_box_pred'&'input_img'的来源:
# Actually,read PB model...
input_img = sess.graph.get_tensor_by_name('Placeholder:0')
output_cls_prob = sess.graph.get_tensor_by_name('Reshape_2:0')
output_box_pred = sess.graph.get_tensor_by_name('rpn_bbox_pred/Reshape_1:0')
参数类型:
blob:type 'numpy.ndarray'
output_cls_prob:class 'tensorflow.python.framework.ops.Tensor'
output_box_pred:class 'tensorflow.python.framework.ops.Tensor'
input_img:class 'tensorflow.python.framework.ops.Tensor'
【问题讨论】:
-
tf.data是 tensorflow 输入管道的推荐 API。这是 tensorflow.org 上的教程:tensorflow.org/guide/datasets。如果您提供有关如何获取值blob的更多信息,stackoverflow 上的人们可能能够提供更具体的代码 sn-p 来说明在这种情况下如何使用tf.data。 -
谢谢提醒,代码一直在补充。
标签: python tensorflow