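Preface
TensorFlow Lite provides all the tools needed to convert TensorFlow models and run them on mobile, embedded, and Internet of Things (IoT) devices. I previously wanted to deploy a TensorFlow model, which first required converting it to a TFLite model.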
Implementation
1. The conversion API differs slightly depending on the kind of model
# saved_model_dir, model and func are placeholders for your own objects.
from tensorflow import lite

# Converting a SavedModel to a TensorFlow Lite model.
converter = lite.TFLiteConverter.from_saved_model(saved_model_dir)
tflite_model = converter.convert()

# Converting a tf.Keras model to a TensorFlow Lite model.
converter = lite.TFLiteConverter.from_keras_model(model)
tflite_model = converter.convert()

# Converting ConcreteFunctions to a TensorFlow Lite model.
converter = lite.TFLiteConverter.from_concrete_functions([func])
tflite_model = converter.convert()
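To make the ConcreteFunction path concrete, here is a minimal sketch. The function scale_and_shift and the helper convert_concrete_function are hypothetical names made up for this illustration; the function is stateless (no variables), so nothing extra has to be kept alive during conversion. It assumes a TensorFlow 2.x installation.

```python
import tensorflow as tf


# A stateless function with a fixed input signature, so a single
# ConcreteFunction can be obtained without passing example inputs.
@tf.function(input_signature=[tf.TensorSpec(shape=[1, 4], dtype=tf.float32)])
def scale_and_shift(x):
    return 2.0 * x + 1.0


def convert_concrete_function():
    """Convert scale_and_shift to a serialized TFLite model (bytes)."""
    concrete_func = scale_and_shift.get_concrete_function()
    converter = tf.lite.TFLiteConverter.from_concrete_functions([concrete_func])
    return converter.convert()


if __name__ == "__main__":
    tflite_model = convert_concrete_function()
    print(f"converted model size: {len(tflite_model)} bytes")
```

The SavedModel and Keras paths follow the same pattern: build a converter with the matching from_* class method, then call convert().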
2. Complete implementation
import tensorflow as tf

saved_model_dir = './mobilenet/'
converter = tf.lite.TFLiteConverter.from_saved_model(saved_model_dir)
converter.experimental_new_converter = True  # use the MLIR-based converter
tflite_model = converter.convert()

# Write the serialized model to disk; the with block closes the file.
with open('model_tflite.tflite', 'wb') as f:
    f.write(tflite_model)
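Before shipping the written file to a device, a quick sanity check can catch an empty or wrong payload. The sketch below uses only the standard library. The helper save_tflite_model is a hypothetical name, and the check relies on the assumption that a serialized TFLite model is a FlatBuffer carrying the file identifier "TFL3" at byte offsets 4 to 8.

```python
from pathlib import Path

# Assumption: TFLite models are FlatBuffers whose file identifier "TFL3"
# is stored at byte offsets 4..8 of the serialized model.
TFLITE_FILE_IDENTIFIER = b"TFL3"


def save_tflite_model(tflite_model: bytes, path: str) -> int:
    """Write the serialized model to path and return the bytes written."""
    if tflite_model[4:8] != TFLITE_FILE_IDENTIFIER:
        raise ValueError("data does not look like a TFLite FlatBuffer")
    return Path(path).write_bytes(tflite_model)
```

Usage would look like save_tflite_model(tflite_model, 'model_tflite.tflite'), where tflite_model is the value returned by converter.convert().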
Here, from_saved_model has the following signature:
@classmethod
from_saved_model(
    cls,
    saved_model_dir,
    signature_keys=None,
    tags=None
)
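As an illustration of signature_keys and tags, the sketch below exports a tiny module as a SavedModel and then converts one named signature. The Doubler module and the helper convert_with_signature are hypothetical names invented for this example, and "serve" is assumed to be the default serving tag that tf.saved_model.save writes.

```python
import tempfile

import tensorflow as tf


class Doubler(tf.Module):
    """Hypothetical stateless module used only for this illustration."""

    @tf.function(input_signature=[tf.TensorSpec(shape=[1, 4], dtype=tf.float32)])
    def double(self, x):
        return {"doubled": 2.0 * x}


def convert_with_signature(saved_model_dir):
    """Export Doubler as a SavedModel, then convert one named signature."""
    module = Doubler()
    tf.saved_model.save(
        module,
        saved_model_dir,
        signatures={"serving_default": module.double},
    )
    converter = tf.lite.TFLiteConverter.from_saved_model(
        saved_model_dir,
        signature_keys=["serving_default"],  # which SignatureDef to convert
        tags={"serve"},  # tag set identifying the MetaGraphDef
    )
    return converter.convert()


if __name__ == "__main__":
    with tempfile.TemporaryDirectory() as export_dir:
        tflite_model = convert_with_signature(export_dir)
        print(f"converted model size: {len(tflite_model)} bytes")
```

For a SavedModel with a single serving signature, both arguments can usually be left at their defaults.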
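In addition, the documentation says the following. Note that the parameter names in this passage are those of the TF 1.x converter; the TF 2.x signature above accepts only signature_keys and tags.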
For more complex SavedModels, the optional parameters that can be passed into TFLiteConverter.from_saved_model() are input_arrays, input_shapes, output_arrays, tag_set and signature_key. Details of each parameter are available by running help(tf.lite.TFLiteConverter).
For how to inspect the operations (ops) in a model, see here.
Output of help(tf.lite.TFLiteConverter)
Help on class TFLiteConverterV2 in module tensorflow.lite.python.lite:

class TFLiteConverterV2(TFLiteConverterBase)
 |  TFLiteConverterV2(funcs, trackable_obj=None)
 |
 |  Converts a TensorFlow model into TensorFlow Lite model.
 |
 |  Attributes:
 |    allow_custom_ops: Boolean indicating whether to allow custom operations.
 |      When false any unknown operation is an error. When true, custom ops are
 |      created for any op that is unknown. The developer will need to provide
 |      these to the TensorFlow Lite runtime with a custom resolver.
 |      (default False)
 |    target_spec: Experimental flag, subject to change. Specification of target
 |      device.
 |    optimizations: Experimental flag, subject to change. A list of optimizations
 |      to apply when converting the model. E.g. `[Optimize.DEFAULT]
 |    representative_dataset: A representative dataset that can be used to
 |      generate input and output samples for the model. The converter can use the
 |      dataset to evaluate different optimizations.
 |    experimental_enable_mlir_converter: Experimental flag, subject to change.
 |      Enables the MLIR converter instead of the TOCO converter.
 |
 |  Example usage:
 |
 |    ```python
 |    # Converting a SavedModel to a TensorFlow Lite model.
 |    converter = lite.TFLiteConverter.from_saved_model(saved_model_dir)
 |    tflite_model = converter.convert()
 |
 |    # Converting a tf.Keras model to a TensorFlow Lite model.
 |    converter = lite.TFLiteConverter.from_keras_model(model)
 |    tflite_model = converter.convert()
 |
 |    # Converting ConcreteFunctions to a TensorFlow Lite model.
 |    converter = lite.TFLiteConverter.from_concrete_functions([func])
 |    tflite_model = converter.convert()
 |    ```
 |
 |  Method resolution order:
 |      TFLiteConverterV2
 |      TFLiteConverterBase
 |      builtins.object
 |
 |  Methods defined here:
 |
 |  __init__(self, funcs, trackable_obj=None)
 |      Constructor for TFLiteConverter.
 |
 |      Args:
 |        funcs: List of TensorFlow ConcreteFunctions. The list should not contain
 |          duplicate elements.
 |        trackable_obj: tf.AutoTrackable object associated with `funcs`. A
 |          reference to this object needs to be maintained so that Variables do not
 |          get garbage collected since functions have a weak reference to
 |          Variables. This is only required when the tf.AutoTrackable object is not
 |          maintained by the user (e.g. `from_saved_model`).
 |
 |  convert(self)
 |      Converts a TensorFlow GraphDef based on instance variables.
 |
 |      Returns:
 |        The converted data in serialized format.
 |
 |      Raises:
 |        ValueError:
 |          Multiple concrete functions are specified.
 |          Input shape is not specified.
 |          Invalid quantization parameters.
 |
 |  ----------------------------------------------------------------------
 |  Class methods defined here:
 |
 |  from_concrete_functions(funcs) from builtins.type
 |      Creates a TFLiteConverter object from ConcreteFunctions.
 |
 |      Args:
 |        funcs: List of TensorFlow ConcreteFunctions. The list should not contain
 |          duplicate elements.
 |
 |      Returns:
 |        TFLiteConverter object.
 |
 |      Raises:
 |        Invalid input type.
 |
 |  from_keras_model(model) from builtins.type
 |      Creates a TFLiteConverter object from a Keras model.
 |
 |      Args:
 |        model: tf.Keras.Model
 |
 |      Returns:
 |        TFLiteConverter object.
 |
 |  from_saved_model(saved_model_dir, signature_keys=None, tags=None) from builtins.type
 |      Creates a TFLiteConverter object from a SavedModel directory.
 |
 |      Args:
 |        saved_model_dir: SavedModel directory to convert.
 |        signature_keys: List of keys identifying SignatureDef containing inputs
 |          and outputs. Elements should not be duplicated. By default the
 |          `signatures` attribute of the MetaGraphdef is used. (default
 |          saved_model.signatures)
 |        tags: Set of tags identifying the MetaGraphDef within the SavedModel to
 |          analyze. All tags in the tag set must be present. (default set(SERVING))
 |
 |      Returns:
 |        TFLiteConverter object.
 |
 |      Raises:
 |        Invalid signature keys.
 |
 |  ----------------------------------------------------------------------
 |  Data descriptors inherited from TFLiteConverterBase:
 |
 |  __dict__
 |      dictionary for instance variables (if defined)
 |
 |  __weakref__
 |      list of weak references to the object (if defined)
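The optimizations attribute listed in the help output can be tried with a sketch like the one below. The function dense_layer, its fixed weights, and the helper convert are hypothetical and exist only for this illustration. Setting Optimize.DEFAULT without a representative dataset is assumed here to request the converter's default weight optimization; the actual effect on model size depends on the TensorFlow version and on the model.

```python
import numpy as np
import tensorflow as tf

# Hypothetical fixed weights (64 x 64), embedded in the graph as a constant.
_WEIGHTS = np.linspace(-1.0, 1.0, 64 * 64, dtype=np.float32).reshape(64, 64)


@tf.function(input_signature=[tf.TensorSpec(shape=[1, 64], dtype=tf.float32)])
def dense_layer(x):
    return tf.matmul(x, tf.constant(_WEIGHTS))


def convert(optimize):
    """Convert dense_layer, optionally with the default optimizations."""
    converter = tf.lite.TFLiteConverter.from_concrete_functions(
        [dense_layer.get_concrete_function()]
    )
    if optimize:
        converter.optimizations = [tf.lite.Optimize.DEFAULT]
    return converter.convert()


if __name__ == "__main__":
    print(f"without optimizations: {len(convert(optimize=False))} bytes")
    print(f"with Optimize.DEFAULT: {len(convert(optimize=True))} bytes")
```

Comparing the two printed sizes gives a rough idea of what the flag does for a given model; accuracy should be checked separately on real inputs.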