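[Posted]: 2021-05-31 13:48:44
[Question]:
Error: the TF Lite converter emits an "untraced functions" warning when converting a temporal CNN (built with the widely used Keras TCN library: https://github.com/philipperemy/keras-tcn), and throws a model-parsing error when attempting post-training quantization.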
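1. System information
- OS platform and distribution (e.g., Linux Ubuntu 16.04): Ubuntu 18.04
- TensorFlow installed from (pip package or built from source): pip (Python 3.8.8)
- TensorFlow library (version if pip package, or GitHub SHA if built from source): 2.3.0 (TF base), 2.4.0 (TF-GPU)
2. Code
Part 1: converting the pretrained TF model to a TF Lite model: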
import os
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"]="1"
from tensorflow.python.keras.backend import set_session
import tensorflow as tf
config = tf.compat.v1.ConfigProto()
config.gpu_options.allow_growth = True # dynamically grow the memory used on the GPU
config.log_device_placement = True # to log device placement (on which device the operation ran)
sess = tf.compat.v1.Session(config=config)
set_session(sess) # set this TensorFlow session as the default
import warnings
warnings.filterwarnings('ignore')
import pandas as pd
import numpy as np
from gtda.time_series import SlidingWindow
import matplotlib.pyplot as plt
from math import atan2, pi, sqrt
from tensorflow.keras.layers import Dense, MaxPooling1D, Flatten
from tensorflow.keras import Input, Model
from tensorflow.keras.callbacks import ModelCheckpoint
from tcn import TCN, tcn_full_summary
from tensorflow.keras.models import load_model
model = load_model('best_joint_new.hdf5',custom_objects={'TCN':TCN})
converter = tf.lite.TFLiteConverter.from_keras_model(model)
model_no_quant_tflite = converter.convert()
open('best_joint.tflite', "wb").write(model_no_quant_tflite)
Part 2: post-training quantization
trainX is an n × 200 × 6 matrix of float values, where n can be any positive integer.
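def representative_dataset():
    for i in range(trainX.shape[0]):
        # Keep a batch axis of 1 and cast to float32 so each sample matches the model input (1, 200, 6)
        yield [trainX[i:i + 1].astype(np.float32)]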
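The model consumes batches of shape (batch, 200, 6), so the shape and dtype of each representative sample matter. A minimal numpy sketch, assuming trainX was cut from one continuous 6-channel recording with a fixed stride (the random signal, the stride of 50 and the helper make_windows are hypothetical; the original may have built trainX with gtda's SlidingWindow instead), shows that trainX[i] drops the batch axis while trainX[i:i + 1] keeps it, and that numpy defaults to float64 whereas a Keras model's float input is normally float32.

```python
import numpy as np

WINDOW = 200  # samples per window, matching the model input
STRIDE = 50   # hypothetical stride, for illustration only


def make_windows(signal, window=WINDOW, stride=STRIDE):
    """Cut a (T, C) recording into overlapping windows of shape (n, window, C)."""
    starts = range(0, signal.shape[0] - window + 1, stride)
    return np.stack([signal[s:s + window] for s in starts])


# Hypothetical stand-in for the real data: 1000 time steps x 6 channels.
rng = np.random.default_rng(0)
raw = rng.standard_normal((1000, 6))  # float64 by default
trainX = make_windows(raw)

print(trainX.shape)       # (17, 200, 6)
print(trainX[0].shape)    # (200, 6): batch axis dropped
print(trainX[0:1].shape)  # (1, 200, 6): batch axis kept
print(trainX[0:1].astype(np.float32).dtype)  # float32
```

Slicing with i:i + 1 rather than adding an axis afterwards keeps the generator a one-liner and always yields exactly one window per step.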
The same converter as before is used:
# Set the optimization flag.
converter.optimizations = [tf.lite.Optimize.DEFAULT]
# Enforce integer only quantization
converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
converter.inference_input_type = tf.int8
converter.inference_output_type = tf.int8
# Provide a representative dataset to ensure we quantize correctly.
converter.representative_dataset = representative_dataset
model_quant_tflite = converter.convert()
# Save the model to disk
open('best_joint_quant.tflite', "wb").write(model_quant_tflite)
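With Optimize.DEFAULT and TFLITE_BUILTINS_INT8, the representative dataset is used to observe the float range of each activation, which is then stored as int8 through the affine mapping real = scale × (q − zero_point). A stdlib sketch of that mapping for a single tensor, assuming calibration observed the range [-1.0, 3.0] (hypothetical numbers; the helper names are illustrative and TFLite's own implementation differs in details, e.g. weights are quantized symmetrically per channel):

```python
QMIN, QMAX = -128, 127  # int8 range


def choose_params(rmin, rmax):
    """Map an observed float range [rmin, rmax] onto int8; returns (scale, zero_point)."""
    rmin, rmax = min(rmin, 0.0), max(rmax, 0.0)  # the range must contain 0.0
    scale = (rmax - rmin) / (QMAX - QMIN)
    if scale == 0.0:
        scale = 1.0  # degenerate all-zero tensor
    zero_point = int(round(QMIN - rmin / scale))
    return scale, max(QMIN, min(QMAX, zero_point))


def quantize(x, scale, zero_point):
    """float -> int8, saturating at both ends of the range."""
    q = int(round(x / scale)) + zero_point
    return max(QMIN, min(QMAX, q))


def dequantize(q, scale, zero_point):
    """int8 -> float, using real = scale * (q - zero_point)."""
    return scale * (q - zero_point)


# Hypothetical calibration result for one activation tensor.
scale, zp = choose_params(-1.0, 3.0)
print(scale, zp)
for x in (-1.0, 0.0, 1.3, 3.0, 5.0):
    q = quantize(x, scale, zp)
    print(f"{x:5.2f} -> {q:4d} -> {dequantize(q, scale, zp):.4f}")
```

Values inside the calibrated range come back within scale / 2 of the original, while 5.0 lies outside it and saturates at 127, which is why the representative data should cover the range seen at inference time.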
3. Failure after conversion
Part 1 (the conversion succeeds but produces the following warnings):
WARNING:absl:Found untraced functions such as residual_block_0_layer_call_and_return_conditional_losses, residual_block_0_layer_call_fn, residual_block_1_layer_call_and_return_conditional_losses, residual_block_1_layer_call_fn, residual_block_2_layer_call_and_return_conditional_losses while saving (showing 5 of 325). These functions will not be directly callable after loading.
WARNING:absl:Found untraced functions such as residual_block_0_layer_call_and_return_conditional_losses, residual_block_0_layer_call_fn, residual_block_1_layer_call_and_return_conditional_losses, residual_block_1_layer_call_fn, residual_block_2_layer_call_and_return_conditional_losses while saving (showing 5 of 325). These functions will not be directly callable after loading.
Part 2 (quantization fails):
---------------------------------------------------------------------------
TypeError Traceback (most recent call last)
~/anaconda3/envs/tflite/lib/python3.8/site-packages/tensorflow/lite/python/optimize/calibrator.py in __init__(self, model_content)
57 self._calibrator = (
---> 58 _calibration_wrapper.CalibrationWrapper(model_content))
59 except Exception as e:
TypeError: pybind11::init(): factory function returned nullptr
During handling of the above exception, another exception occurred:
ValueError Traceback (most recent call last)
<ipython-input-7-f34a9c965790> in <module>
7 # Provide a representative dataset to ensure we quantize correctly.
8 converter.representative_dataset = representative_dataset
----> 9 model_quant_tflite = converter.convert()
10 # Save the model to disk
11 open('best_joint_quant.tflite', "wb").write(model_quant_tflite)
~/anaconda3/envs/tflite/lib/python3.8/site-packages/tensorflow/lite/python/lite.py in convert(self)
871 graph=frozen_func.graph)
872
--> 873 return super(TFLiteKerasModelConverterV2,
874 self).convert(graph_def, input_tensors, output_tensors)
875
~/anaconda3/envs/tflite/lib/python3.8/site-packages/tensorflow/lite/python/lite.py in convert(self, graph_def, input_tensors, output_tensors)
630 calibrate_and_quantize, flags = quant_mode.quantizer_flags()
631 if calibrate_and_quantize:
--> 632 result = self._calibrate_quantize_model(result, **flags)
633
634 flags_modify_model_io_type = quant_mode.flags_modify_model_io_type(
~/anaconda3/envs/tflite/lib/python3.8/site-packages/tensorflow/lite/python/lite.py in _calibrate_quantize_model(self, result, inference_input_type, inference_output_type, activations_type, allow_float)
447 # Add intermediate tensors to the model if needed.
448 result = _calibrator.add_intermediate_tensors(result)
--> 449 calibrate_quantize = _calibrator.Calibrator(result)
450 if self._experimental_calibrate_only or self._experimental_new_quantizer:
451 calibrated = calibrate_quantize.calibrate(
~/anaconda3/envs/tflite/lib/python3.8/site-packages/tensorflow/lite/python/optimize/calibrator.py in __init__(self, model_content)
58 _calibration_wrapper.CalibrationWrapper(model_content))
59 except Exception as e:
---> 60 raise ValueError("Failed to parse the model: %s." % e)
61 if not self._calibrator:
62 raise ValueError("Failed to parse the model.")
ValueError: Failed to parse the model: pybind11::init(): factory function returned nullptr.
4. Main model architecture
Input: 200 samples × 6 channels of float values. Output: 2 scalars (float).
5. Link to the original HDF5 model:
https://drive.google.com/file/d/1GFRgMUkIVatSsUWgnzee_jjeF6D3l-A-/view?usp=sharing
[Discussion]:
Tags: python-3.x tensorflow keras deep-learning conv-neural-network