【发布时间】:2017-03-09 22:44:16
【问题描述】:
我已经成功浏览了official tutorial,它解释了如何重新训练 inception-v3 模型,后来成功地重新训练了相同的模型 o 为特定目的训练模型。
然而,与其他更简单的模型(例如 inception-v1)相比,该模型复杂且速度慢,对于某些任务来说,它的准确性已经足够好。具体来说,我想重新训练模型以在 Android 上使用它,理想情况下,速度方面的性能应该与原始 TensorFlow Android demo 相当。无论如何,我尝试从 this link 重新训练 inception-v1 模型,并在 retrain.py 中进行了以下修改:
BOTTLENECK_TENSOR_NAME = 'avgpool0/reshape:0'
BOTTLENECK_TENSOR_SIZE = 2048
MODEL_INPUT_WIDTH = 224
MODEL_INPUT_HEIGHT = 224
MODEL_INPUT_DEPTH = 3
JPEG_DATA_TENSOR_NAME = 'input'
RESIZED_INPUT_TENSOR_NAME = 'input'
与 inception v3 不同,inception v1 没有任何 decodeJpeg 或 resize 节点:
inception v3 节点:
DecodeJpeg/contents
DecodeJpeg
Cast
ExpandDims/dim
ExpandDims
ResizeBilinear/size
ResizeBilinear
...
pool_3
pool_3/_reshape/shape
pool_3/_reshape
softmax/weights
softmax/biases
softmax/logits/MatMul
softmax/logits
softmax
inception v1 节点:
input
conv2d0_w
conv2d0_b
conv2d1_w
conv2d1_b
conv2d2_w
conv2d2_b
...
softmax1_pre_activation
softmax1
avgpool0/reshape/shape
avgpool0/reshape
softmax2_pre_activation/matmul
softmax2_pre_activation
softmax2
output
output1
output2
所以我猜图像在被输入图表之前必须重新整形。
现在点击以下函数时会出现错误:
def run_bottleneck_on_image(sess, image_data, image_data_tensor,
bottleneck_tensor):
"""Runs inference on an image to extract the 'bottleneck' summary layer.
Args:
sess: Current active TensorFlow Session.
image_data: Numpy array of image data.
image_data_tensor: Input data layer in the graph.
bottleneck_tensor: Layer before the final softmax.
Returns:
Numpy array of bottleneck values.
"""
bottleneck_values = sess.run(
bottleneck_tensor,
{image_data_tensor: image_data})
bottleneck_values = np.squeeze(bottleneck_values)
return bottleneck_values
错误:
TypeError:无法将 feed_dict 键解释为张量:无法转换 对张量的操作。
我猜想在 inception v3 中传递以下节点后,必须对 inception v1 图的输入节点上的数据进行重新整形以匹配数据:
DecodeJpeg/contents
DecodeJpeg
Cast
ExpandDims/dim
ExpandDims
ResizeBilinear/size
ResizeBilinear
如果有人已经设法重新训练 inception v1 模型或知道如何重塑 inception v1 案例中的数据以匹配 inception v3,我将非常感谢任何提示或建议。
【问题讨论】:
-
你好。我现在正在尝试解决同样的问题。你找到解决办法了吗?
-
您是否设法为 v1 创建了重新训练脚本?如果是这样,请您分享,因为下面的答案不起作用。
标签: tensorflow