【问题标题】:TensorFlow use dataset to replace function feed_dictTensorFlow 使用数据集替换函数 feed_dict
【发布时间】:2018-10-23 00:10:09
【问题描述】:

当我学习一个tensorflow项目时,找一行代码:

cls_prob, box_pred = sess.run([output_cls_prob, output_box_pred], feed_dict={input_img: blob})

但是,这行代码花了很多时间。 (使用CPU需要15秒...┭┮﹏┭┮)

通过查阅资料,我发现使用函数'dataset'可以解决这个花费了很多时间的问题,我应该如何使用它?

“blob”的来源:

img = cv2.imread('./imgs/001.jpg')
img_scale = float(600) / min(img_data.shape[0], img_data.shape[1])
if np.round(img_scale * max(img_data.shape[0], img_data.shape[1])) > 1200:
    img_scale = float(1200) / max(img_data.shape[0], img_data.shape[1])
img_data = cv2.resize(img_data, None, None, fx=img_scale, fy=img_scale, interpolation=cv2.INTER_LINEAR)
img_orig = img_data.astype(np.float32, copy=True)
blob = np.zeros((1, img_data.shape[0], img_data.shape[1], 3),dtype=np.float32)
blob[0, 0:img_data.shape[0], 0:img_data.shape[1], :] = img_orig

'output_cls_prob'&'output_box_pred'&'input_img'的来源:

# Actually,read PB model...
input_img = sess.graph.get_tensor_by_name('Placeholder:0')
output_cls_prob = sess.graph.get_tensor_by_name('Reshape_2:0')
output_box_pred = sess.graph.get_tensor_by_name('rpn_bbox_pred/Reshape_1:0')

参数类型:

blob:type 'numpy.ndarray'

output_cls_prob:class 'tensorflow.python.framework.ops.Tensor'

output_box_pred:class 'tensorflow.python.framework.ops.Tensor'

input_img:class 'tensorflow.python.framework.ops.Tensor'

【问题讨论】:

  • tf.data 是 tensorflow 输入管道的推荐 API。这是 tensorflow.org 上的教程:tensorflow.org/guide/datasets。如果您提供有关如何获取值 blob 的更多信息,stackoverflow 上的人们可能能够提供更具体的代码 sn-p 来说明在这种情况下如何使用 tf.data
  • 谢谢提醒,代码一直在补充。

标签: python tensorflow


【解决方案1】:

tf.data 是 tensorflow 输入管道的推荐 API。这是tensorflow.org 的教程。对于您的示例,"Decoding image data and resizing it" 部分可能是最有用的。例如,您可以执行以下操作:

# Reads an image from a file, decodes it into a dense tensor, and resizes it
# to a fixed shape.
def _parse_function(filename):
  image_string = tf.read_file(filename)
  image_decoded = tf.image.decode_jpeg(image_string)
  image_resized = tf.image.resize_images(image_decoded, [new_width, new_height])
  image_resized = tf.expand_dims(image_resized, 0)  # Adds size 1 dimension
  return image_resized

# A vector of filenames.
filenames = tf.constant(["./imgs/001.jpg", ...])

dataset = tf.data.Dataset.from_tensor_slices(filenames)
dataset = dataset.map(_parse_function)

不要让input_img 成为占位符,而是更改:

input_img = tf.placeholder(tf.float32)
output_class_prob, output_class_pred = (... use input_img ...)

到:

iterator = dataset.make_one_shot_iterator()
input_img = iterator.get_next()
output_class_prob, output_class_pred = (... use input_img ...)

【讨论】:

  • 呃……我只有一张图片。但是图片太大了。所以,模型的计算时间太长了。呃..嗯……我错了吗?跨度>
  • 如果您不以全分辨率使用图像,是否需要对文件进行预处理以使其更小?
【解决方案2】:

首先你应该知道,在使用多个 GPU 时,使用 Dataset API 对性能的影响很大……否则几乎与 feed_dict 相同。我建议您阅读来自 TF 开发人员的 this other answer,它几乎包含了人们需要知道的所有内容,以便在脑海中想象这个新 API 的好处。

【讨论】:

  • 数据 API 的好处在分布式设置中更为明显,但它也会对单 GPU 环境产生重大影响,因为它减少了 Python 与 TensorFlow 运行时和 GPU 之间的内存交换,它可以在 GPU 处理过程中预取数据,并且可以并行读取数据。而且它对于内存也无法容纳的大型数据集非常有帮助。
  • 谢谢,但是'blob'真的很大,我认为它占用了很多内存。
猜你喜欢
  • 2019-12-10
  • 2017-08-02
  • 1970-01-01
  • 1970-01-01
  • 2018-09-20
  • 2018-11-19
  • 1970-01-01
  • 2017-12-21
  • 1970-01-01
相关资源
最近更新 更多