【问题标题】:tensorflow.keras save model in python and loading in Javatensorflow.keras 在 python 中保存模型并在 Java 中加载
【发布时间】:2019-01-28 23:39:47
【问题描述】:

我有一个经过微调的 vgg 模型,我使用 tensorflow.keras 功能 API 创建了模型,并使用 tf.contrib.saved_model.save_keras_model 保存了模型。 因此模型以这种结构保存:包含 saved_model.json 文件、saved_model.pb 文件的 assets 文件夹和包含 checkpoint 的 variables 文件夹strong>、variables.data-00000-of-00001variables.index

我可以轻松地在 python 中加载我的模型并使用 tf.contrib.saved_model.load_keras_model(saved_model_path) 获得预测,但我不知道如何在 JAVA 中加载模型。我google了很多,发现这个How to export Keras .h5 to tensorflow .pb?导出为pb文件,然后按照这个链接Loading in Java加载它。我无法冻结图形,并且我尝试使用 simple_save,但 tensorflow.keras 不支持 simple_save(AttributeError: module 'tensorflow.contrib.saved_model' 没有属性 'simple_save')。那么有人可以帮我弄清楚在JAVA中加载我的模型(tensorflow.keras功能API模型)需要哪些步骤。

我拥有的 saved_model.pb 文件是否足以在 JAVA 端加载?我需要创建输入/输出占位符吗?那怎么导出呢?
感谢您的帮助。

【问题讨论】:

标签: java python tensorflow keras


【解决方案1】:

如果您有一个以 SavedModel 格式保存的模型(您似乎已经这样做了,tf.contrib.saved_model.save_keras_model 之类的东西可以帮助创建),那么在 Java 中您可以使用 SavedModelBundle.load 来加载和提供它。您不需要“冻结”模型。

您可能会发现以下有用:

但基本思想是您的代码将类似于:

try (SavedModelBundle model = SavedModelBundle.load("<directory>", "serve")) {
  try (Tensor<?> input = makeInputTensor();
       Tensor<?> output = model.session().runner().feed("INPUT_TENSOR", input).fetch("OUTPUT_TENSOR", output).run().get(0)) {
  // Use output
  }
}

其中"INPUT_TENSOR""OUTPUT_TENSOR" 是TensorFlow 图中输入和输出节点的名称。安装 TensorFlow for Python 时安装的 saved_model_cli 命令行工具可以显示模型中这些张量的名称。

请注意,使用 TensorFlow Java API 可能比其他评论者建议的使用 TensorFlow Lite 更适合服务器/桌面应用程序。这是因为 TensorFLow Lite 运行时虽然针对小型设备进行了优化(在内存占用等方面),但还不能导出所有模型。而 TensorFlow Java API 使用完全相同的运行时,因此具有与 TensorFlow for Python 完全相同的功能。

希望对您有所帮助。

【讨论】:

  • 我做了完全相同的事情,当我使用 Inception 模型作为预训练模型时,它最终运行良好。但是当我使用 VGG 模型作为基础时,我的模型无法在 JAVA 中加载。您是否看过任何加载 VGG 模型并对其进行微调然后在 JAVA 中加载的教程?我会将尝试 VGG 模型时遇到的错误发送给您。非常感谢您的帮助。
  • 这是我得到的错误:Matrix size-incompatible: In[0]: [1,8192], In[1]: [25088,256] [[{{node dense/MatMul }} = MatMul[T=DT_FLOAT, _output_shapes=[[?,256]], transpose_a=false, transpose_b=false, _device="/job:localhost/replica:0/task:0/device:CPU:0"] (展平/重塑,密集/MatMul/ReadVariableOp)]]
  • 响应中的示例链接不再有效。
猜你喜欢
  • 2020-05-09
  • 2022-12-03
  • 2014-10-13
  • 2013-01-23
  • 2019-06-11
  • 1970-01-01
  • 1970-01-01
  • 2021-10-11
相关资源
最近更新 更多