【问题标题】:How to load a tflite model in script?如何在脚本中加载 tflite 模型?
【发布时间】:2018-10-30 18:56:44
【问题描述】:

我已使用 bazel.pb 文件转换为 tflite 文件。现在我想在我的 python 脚本中加载这个tflite 模型只是为了测试天气这是否给了我正确的输出?

【问题讨论】:

    标签: python tensorflow tensorflow-lite


    【解决方案1】:

    在 Python 中使用 TensorFlow lite 模型:

    TensorFlow Lite 的冗长功能强大,因为它允许您进行更多控制,但在许多情况下您只想传递输入并获得输出,因此我制作了一个包装此逻辑的类:

    以下适用于 tfhub.dev 中的分类模型,例如:https://tfhub.dev/tensorflow/lite-model/mobilenet_v2_1.0_224/1/metadata/1

    # Usage
    model = TensorflowLiteClassificationModel("path/to/model.tflite")
    (label, probability) = model.run_from_filepath("path/to/image.jpeg")
    
    import tensorflow as tf
    import numpy as np
    from PIL import Image
    
    
    class TensorflowLiteClassificationModel:
        def __init__(self, model_path, labels, image_size=224):
            self.interpreter = tf.lite.Interpreter(model_path=model_path)
            self.interpreter.allocate_tensors()
            self._input_details = self.interpreter.get_input_details()
            self._output_details = self.interpreter.get_output_details()
            self.labels = labels
            self.image_size=image_size
    
        def run_from_filepath(self, image_path):
            input_data_type = self._input_details[0]["dtype"]
            image = np.array(Image.open(image_path).resize((self.image_size, self.image_size)), dtype=input_data_type)
            if input_data_type == np.float32:
                image = image / 255.
    
            if image.shape == (1, 224, 224):
                image = np.stack(image*3, axis=0)
    
            return self.run(image)
    
        def run(self, image):
            """
            args:
              image: a (1, image_size, image_size, 3) np.array
    
            Returns list of [Label, Probability], of type List<str, float>
            """
    
            self.interpreter.set_tensor(self._input_details[0]["index"], image)
            self.interpreter.invoke()
            tflite_interpreter_output = self.interpreter.get_tensor(self._output_details[0]["index"])
            probabilities = np.array(tflite_interpreter_output[0])
    
            # create list of ["label", probability], ordered descending probability
            label_to_probabilities = []
            for i, probability in enumerate(probabilities):
                label_to_probabilities.append([self.labels[i], float(probability)])
            return sorted(label_to_probabilities, key=lambda element: element[1])
    

    注意

    但是,您需要对其进行修改以支持不同的用例,因为我将图像作为输入传递,并获得分类([标签,概率])输出。如果您需要文本输入 (NLP) 或其他输出(对象检测输出边界框、标签和概率)、分类(仅标签)等。

    此外,如果您希望输入不同大小的图像,那么您必须更改输入大小并重新分配模型 (self.interpreter.allocate_tensors())。这很慢(效率低下)。最好使用平台大小调整功能(例如 Android 图形库)而不是使用 TensorFlow lite 模型来进行大小调整。或者,您可以使用单独的模型来调整模型的大小,这样allocate_tensors() 会更快。

    【讨论】:

      【解决方案2】:

      您可以使用 TensorFlow Lite Python 解释器在 python shell 中加载 tflite 模型,并使用您的输入数据对其进行测试。

      代码会是这样的:

      import numpy as np
      import tensorflow as tf
      
      # Load TFLite model and allocate tensors.
      interpreter = tf.lite.Interpreter(model_path="converted_model.tflite")
      interpreter.allocate_tensors()
      
      # Get input and output tensors.
      input_details = interpreter.get_input_details()
      output_details = interpreter.get_output_details()
      
      # Test model on random input data.
      input_shape = input_details[0]['shape']
      input_data = np.array(np.random.random_sample(input_shape), dtype=np.float32)
      interpreter.set_tensor(input_details[0]['index'], input_data)
      
      interpreter.invoke()
      
      # The function `get_tensor()` returns a copy of the tensor data.
      # Use `tensor()` in order to get a pointer to the tensor.
      output_data = interpreter.get_tensor(output_details[0]['index'])
      print(output_data)
      

      以上代码来自TensorFlow Lite官方指南更多详细信息,请阅读this

      【讨论】:

      • 使用了哪个 tensorflow 版本?口译员现在不在场。
      • 正如我刚刚使用 tensorflow 1.14.0 测试的那样,tflite Interpreter 已从 tf.contrib.lite.Interpreter 移至 tf.lite.Interpreter,请参阅上面的更新答案。
      • 这真的很棒。我修改了文件以实际测试图像,我发现我的 .tflite 文件一定是无效的。如果你熟悉对象检测,可以看看stackoverflow.com/questions/59736600/…吗?
      • 如何在测试数据上测试而不是随机数据
      • 我们如何对所有数据集进行预测?像“.predict(x_test)”?
      猜你喜欢
      • 2022-01-16
      • 2021-10-06
      • 2021-05-04
      • 1970-01-01
      • 1970-01-01
      • 1970-01-01
      • 1970-01-01
      • 2013-10-08
      • 2023-03-21
      相关资源
      最近更新 更多