【发布时间】:2019-07-16 18:30:40
【问题描述】:
我在将 tensorflow 模型转换为 tflite 模型时尝试使用 UINT8 量化:
如果使用post_training_quantize = True,模型大小比原始 fp32 模型小 x4,所以我假设模型权重是 uint8,但是当我加载模型并通过 interpreter_aligner.get_input_details()[0]['dtype'] 获取输入类型时,它是 float32。量化模型的输出与原始模型大致相同。
converter = tf.contrib.lite.TFLiteConverter.from_frozen_graph(
graph_def_file='tflite-models/tf_model.pb',
input_arrays=input_node_names,
output_arrays=output_node_names)
converter.post_training_quantize = True
tflite_model = converter.convert()
转换模型的输入/输出:
print(interpreter_aligner.get_input_details())
print(interpreter_aligner.get_output_details())
[{'name': 'input_1_1', 'index': 47, 'shape': array([ 1, 128, 128, 3], dtype=int32), 'dtype': <class 'numpy.float32'>, 'quantization': (0.0, 0)}]
[{'name': 'global_average_pooling2d_1_1/Mean', 'index': 45, 'shape': array([ 1, 156], dtype=int32), 'dtype': <class 'numpy.float32'>, 'quantization': (0.0, 0)}]
另一种选择是明确指定更多参数: 模型大小比原始 fp32 模型小 x4,模型输入类型为 uint8,但模型输出更像垃圾。
converter = tf.contrib.lite.TFLiteConverter.from_frozen_graph(
graph_def_file='tflite-models/tf_model.pb',
input_arrays=input_node_names,
output_arrays=output_node_names)
converter.post_training_quantize = True
converter.inference_type = tf.contrib.lite.constants.QUANTIZED_UINT8
converter.quantized_input_stats = {input_node_names[0]: (0.0, 255.0)} # (mean, stddev)
converter.default_ranges_stats = (-100, +100)
tflite_model = converter.convert()
转换模型的输入/输出:
[{'name': 'input_1_1', 'index': 47, 'shape': array([ 1, 128, 128, 3], dtype=int32), 'dtype': <class 'numpy.uint8'>, 'quantization': (0.003921568859368563, 0)}]
[{'name': 'global_average_pooling2d_1_1/Mean', 'index': 45, 'shape': array([ 1, 156], dtype=int32), 'dtype': <class 'numpy.uint8'>, 'quantization': (0.7843137383460999, 128)}]
所以我的问题是:
- 仅设置
post_training_quantize = True时会发生什么?即为什么第一种情况可以正常工作,但第二种情况不行。 - 如何估计第二种情况的均值、标准差和范围参数?
- 看起来在第二种情况下模型推理更快,这取决于模型输入是 uint8 的事实吗?
- 第一种情况下的
'quantization': (0.0, 0)和第二种情况下的'quantization': (0.003921568859368563, 0),'quantization': (0.7843137383460999, 128)是什么意思? -
converter.default_ranges_stats是什么?
更新:
问题4的答案找到What does 'quantization' mean in interpreter.get_input_details()?
【问题讨论】:
-
@suharshs 看起来你与 tensorflow 的这一部分有关,你能详细说明一下吗?
-
4a。对于 float32 的 dtype,量化 被忽略
标签: python tensorflow deep-learning tensorflow-lite quantization