Preface

TensorFlow Lite provides all the tools needed to convert TensorFlow models and run them on mobile, embedded and IoT devices. To deploy a TensorFlow model on such a device, it first has to be converted into a TFLite model.

Implementation

1. The converter entry point differs slightly depending on the type of model

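import tensorflow as tf

# Converting a SavedModel to a TensorFlow Lite model
# (saved_model_dir is the path of an exported SavedModel directory).
converter = tf.lite.TFLiteConverter.from_saved_model(saved_model_dir)
tflite_model = converter.convert()

# Converting a tf.keras model to a TensorFlow Lite model
# (model is an existing tf.keras.Model instance).
converter = tf.lite.TFLiteConverter.from_keras_model(model)
tflite_model = converter.convert()

# Converting ConcreteFunctions to a TensorFlow Lite model
# (func is a ConcreteFunction, e.g. obtained via some_tf_function.get_concrete_function()).
converter = tf.lite.TFLiteConverter.from_concrete_functions([func])
tflite_model = converter.convert()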
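The entry points above can be exercised end to end with a small self-contained sketch. It assumes TensorFlow 2.x; the Scale module and its weight are hypothetical stand-ins for a real model. It covers the ConcreteFunction and SavedModel paths; from_keras_model is called the same way once a tf.keras.Model exists.

```python
import tempfile

import tensorflow as tf


class Scale(tf.Module):
    """Hypothetical toy model: multiplies its input by a trainable weight."""

    def __init__(self):
        super().__init__()
        self.w = tf.Variable(2.0)

    @tf.function(input_signature=[tf.TensorSpec(shape=[1, 3], dtype=tf.float32)])
    def __call__(self, x):
        return x * self.w


module = Scale()

# Path A: trace a ConcreteFunction and convert it directly.
func = module.__call__.get_concrete_function()
tflite_from_func = tf.lite.TFLiteConverter.from_concrete_functions([func]).convert()

# Path B: export a SavedModel first, then convert the directory.
saved_model_dir = tempfile.mkdtemp()
tf.saved_model.save(module, saved_model_dir, signatures=module.__call__)
tflite_from_saved = tf.lite.TFLiteConverter.from_saved_model(saved_model_dir).convert()

# Both results are serialized TFLite flatbuffers (Python bytes).
print(len(tflite_from_func), len(tflite_from_saved))
```

Both paths produce the same kind of output; the SavedModel route is usually preferable when the model has to be exported anyway, while the ConcreteFunction route avoids writing anything to disk.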

2. Complete example

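import tensorflow as tf

saved_model_dir = './mobilenet/'  # directory containing the exported SavedModel
converter = tf.lite.TFLiteConverter.from_saved_model(saved_model_dir)
converter.experimental_new_converter = True  # use the MLIR-based converter (already the default in recent TF 2.x)
tflite_model = converter.convert()

# Write the serialized flatbuffer to disk; "with" makes sure the file is closed.
with open('model_tflite.tflite', 'wb') as f:
    f.write(tflite_model)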
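Before copying the .tflite file to a device, it can be worth checking on the desktop that the converted model loads and produces the expected output with tf.lite.Interpreter. The sketch below assumes TensorFlow 2.x and uses a hypothetical AddOne module in place of ./mobilenet/ so that it runs without any downloaded files.

```python
import tempfile

import numpy as np
import tensorflow as tf


class AddOne(tf.Module):
    """Hypothetical stand-in for ./mobilenet/: adds 1 to every input element."""

    @tf.function(input_signature=[tf.TensorSpec(shape=[1, 4], dtype=tf.float32)])
    def __call__(self, x):
        return {"y": x + 1.0}


saved_model_dir = tempfile.mkdtemp()
module = AddOne()
tf.saved_model.save(module, saved_model_dir, signatures=module.__call__)

converter = tf.lite.TFLiteConverter.from_saved_model(saved_model_dir)
tflite_model = converter.convert()

# Load the flatbuffer from memory; model_path='model_tflite.tflite' loads a file instead.
interpreter = tf.lite.Interpreter(model_content=tflite_model)
interpreter.allocate_tensors()

input_detail = interpreter.get_input_details()[0]
output_detail = interpreter.get_output_details()[0]

# Feed one sample, run inference, and read the result back.
x = np.array([[0.0, 1.0, 2.0, 3.0]], dtype=np.float32)
interpreter.set_tensor(input_detail["index"], x)
interpreter.invoke()
y = interpreter.get_tensor(output_detail["index"])
print(y)
```

The same set_tensor / invoke / get_tensor sequence applies to a real model; only the input shape and dtype reported by get_input_details() change.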

Here, the signature of from_saved_model is:

@classmethod
def from_saved_model(
    cls,
    saved_model_dir,
    signature_keys=None,
    tags=None
): ...
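signature_keys and tags matter when a SavedModel exports more than one signature. A minimal sketch, assuming TensorFlow 2.x: the hypothetical Multiplier module exports two signatures, the available keys are listed, and only "double" from the MetaGraph tagged tf.saved_model.SERVING ("serve") is converted.

```python
import tempfile

import numpy as np
import tensorflow as tf


class Multiplier(tf.Module):
    """Hypothetical module exporting two signatures: 'double' and 'triple'."""

    @tf.function(input_signature=[tf.TensorSpec(shape=[1, 2], dtype=tf.float32)])
    def double(self, x):
        return {"y": x * 2.0}

    @tf.function(input_signature=[tf.TensorSpec(shape=[1, 2], dtype=tf.float32)])
    def triple(self, x):
        return {"y": x * 3.0}


module = Multiplier()
saved_model_dir = tempfile.mkdtemp()
tf.saved_model.save(
    module,
    saved_model_dir,
    signatures={"double": module.double, "triple": module.triple},
)

# List the signature keys stored in the SavedModel before picking one.
available = sorted(tf.saved_model.load(saved_model_dir).signatures.keys())
print(available)

# Convert only the 'double' signature from the MetaGraph tagged 'serve'.
converter = tf.lite.TFLiteConverter.from_saved_model(
    saved_model_dir,
    signature_keys=["double"],
    tags={tf.saved_model.SERVING},
)
tflite_model = converter.convert()

# Run the converted model to confirm the selected signature was used.
interpreter = tf.lite.Interpreter(model_content=tflite_model)
interpreter.allocate_tensors()
inp = interpreter.get_input_details()[0]
out = interpreter.get_output_details()[0]
interpreter.set_tensor(inp["index"], np.array([[1.0, 5.0]], dtype=np.float32))
interpreter.invoke()
result = interpreter.get_tensor(out["index"])
print(result)
```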

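In addition:

For more complex SavedModels, the TensorFlow 1.x converter, tf.compat.v1.lite.TFLiteConverter.from_saved_model(), accepts the optional parameters input_arrays, input_shapes, output_arrays, tag_set and signature_key; the TF 2.x from_saved_model() shown above accepts only signature_keys and tags. Details of each parameter are available by running help(tf.lite.TFLiteConverter) (or help(tf.compat.v1.lite.TFLiteConverter) for the 1.x version).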

To see which operations (ops) a model contains, see here.
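One possible way to list the ops in a converted model (not necessarily the method behind the link above) is tf.lite.experimental.Analyzer, which is available in TensorFlow 2.7 and later. The sketch assumes that version; the scale function is a hypothetical toy graph containing a single multiplication.

```python
import contextlib
import io

import tensorflow as tf


@tf.function(input_signature=[tf.TensorSpec(shape=[1, 3], dtype=tf.float32)])
def scale(x):
    # Hypothetical toy graph: one multiplication, so one MUL op is expected.
    return x * 2.0


tflite_model = tf.lite.TFLiteConverter.from_concrete_functions(
    [scale.get_concrete_function()]
).convert()

# Analyzer.analyze() prints its report to stdout; capture it as a string.
buffer = io.StringIO()
with contextlib.redirect_stdout(buffer):
    tf.lite.experimental.Analyzer.analyze(model_content=tflite_model)
report = buffer.getvalue()
print(report)
```

For a file on disk, model_path='model_tflite.tflite' can be passed instead of model_content.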

Output of help(tf.lite.TFLiteConverter):

Help on class TFLiteConverterV2 in module tensorflow.lite.python.lite:

class TFLiteConverterV2(TFLiteConverterBase)
 |  TFLiteConverterV2(funcs, trackable_obj=None)
 |  
 |  Converts a TensorFlow model into TensorFlow Lite model.
 |  
 |  Attributes:
 |    allow_custom_ops: Boolean indicating whether to allow custom operations.
 |      When false any unknown operation is an error. When true, custom ops are
 |      created for any op that is unknown. The developer will need to provide
 |      these to the TensorFlow Lite runtime with a custom resolver.
 |      (default False)
 |    target_spec: Experimental flag, subject to change. Specification of target
 |      device.
 |    optimizations: Experimental flag, subject to change. A list of optimizations
 |      to apply when converting the model. E.g. `[Optimize.DEFAULT]
 |    representative_dataset: A representative dataset that can be used to
 |      generate input and output samples for the model. The converter can use the
 |      dataset to evaluate different optimizations.
 |    experimental_enable_mlir_converter: Experimental flag, subject to change.
 |      Enables the MLIR converter instead of the TOCO converter.
 |  
 |  Example usage:
 |  
 |    ```python
 |    # Converting a SavedModel to a TensorFlow Lite model.
 |    converter = lite.TFLiteConverter.from_saved_model(saved_model_dir)
 |    tflite_model = converter.convert()
 |  
 |    # Converting a tf.Keras model to a TensorFlow Lite model.
 |    converter = lite.TFLiteConverter.from_keras_model(model)
 |    tflite_model = converter.convert()
 |  
 |    # Converting ConcreteFunctions to a TensorFlow Lite model.
 |    converter = lite.TFLiteConverter.from_concrete_functions([func])
 |    tflite_model = converter.convert()
 |    ```
 |  
 |  Method resolution order:
 |      TFLiteConverterV2
 |      TFLiteConverterBase
 |      builtins.object
 |  
 |  Methods defined here:
 |  
 |  __init__(self, funcs, trackable_obj=None)
 |      Constructor for TFLiteConverter.
 |      
 |      Args:
 |        funcs: List of TensorFlow ConcreteFunctions. The list should not contain
 |          duplicate elements.
 |        trackable_obj: tf.AutoTrackable object associated with `funcs`. A
 |          reference to this object needs to be maintained so that Variables do not
 |          get garbage collected since functions have a weak reference to
 |          Variables. This is only required when the tf.AutoTrackable object is not
 |          maintained by the user (e.g. `from_saved_model`).
 |  
 |  convert(self)
 |      Converts a TensorFlow GraphDef based on instance variables.
 |      
 |      Returns:
 |        The converted data in serialized format.
 |      
 |      Raises:
 |        ValueError:
 |          Multiple concrete functions are specified.
 |          Input shape is not specified.
 |          Invalid quantization parameters.
 |  
 |  ----------------------------------------------------------------------
 |  Class methods defined here:
 |  
 |  from_concrete_functions(funcs) from builtins.type
 |      Creates a TFLiteConverter object from ConcreteFunctions.
 |      
 |      Args:
 |        funcs: List of TensorFlow ConcreteFunctions. The list should not contain
 |          duplicate elements.
 |      
 |      Returns:
 |        TFLiteConverter object.
 |      
 |      Raises:
 |        Invalid input type.
 |  
 |  from_keras_model(model) from builtins.type
 |      Creates a TFLiteConverter object from a Keras model.
 |      
 |      Args:
 |        model: tf.Keras.Model
 |      
 |      Returns:
 |        TFLiteConverter object.
 |  
 |  from_saved_model(saved_model_dir, signature_keys=None, tags=None) from builtins.type
 |      Creates a TFLiteConverter object from a SavedModel directory.
 |      
 |      Args:
 |        saved_model_dir: SavedModel directory to convert.
 |        signature_keys: List of keys identifying SignatureDef containing inputs
 |          and outputs. Elements should not be duplicated. By default the
 |          `signatures` attribute of the MetaGraphdef is used. (default
 |          saved_model.signatures)
 |        tags: Set of tags identifying the MetaGraphDef within the SavedModel to
 |          analyze. All tags in the tag set must be present. (default set(SERVING))
 |      
 |      Returns:
 |        TFLiteConverter object.
 |      
 |      Raises:
 |        Invalid signature keys.
 |  
 |  ----------------------------------------------------------------------
 |  Data descriptors inherited from TFLiteConverterBase:
 |  __dict__
 |      dictionary for instance variables (if defined)
 |  
 |  __weakref__
 |      list of weak references to the object (if defined)
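The optimizations and representative_dataset attributes described in the help output are typically combined for post-training quantization. A minimal sketch, assuming TensorFlow 2.x: the Dense module and its random 256x256 weights are hypothetical, and the random calibration inputs only stand in for real samples drawn from the training data.

```python
import numpy as np
import tensorflow as tf


class Dense(tf.Module):
    """Hypothetical toy model: y = x @ W + b with fixed random weights."""

    def __init__(self):
        super().__init__()
        rng = np.random.default_rng(0)
        self.w = tf.Variable(rng.normal(size=(256, 256)).astype(np.float32))
        self.b = tf.Variable(np.zeros(256, dtype=np.float32))

    @tf.function(input_signature=[tf.TensorSpec(shape=[1, 256], dtype=tf.float32)])
    def __call__(self, x):
        return tf.matmul(x, self.w) + self.b


module = Dense()
func = module.__call__.get_concrete_function()

# Baseline: plain float32 conversion.
float_model = tf.lite.TFLiteConverter.from_concrete_functions([func]).convert()


def representative_dataset():
    # Yields sample inputs so the converter can calibrate activation ranges.
    # Random data is a placeholder; real samples should be used in practice.
    rng = np.random.default_rng(1)
    for _ in range(50):
        yield [rng.normal(size=(1, 256)).astype(np.float32)]


# Quantized conversion: default optimizations plus calibration data.
converter = tf.lite.TFLiteConverter.from_concrete_functions([func])
converter.optimizations = [tf.lite.Optimize.DEFAULT]
converter.representative_dataset = representative_dataset
quant_model = converter.convert()

print(len(float_model), len(quant_model))
```

With Optimize.DEFAULT and a representative dataset the weights are stored as int8, so the quantized flatbuffer should come out noticeably smaller than the float32 one; accuracy still needs to be checked on real data.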
