【问题标题】:How to convert TensorFlow checkpoint files to TensorFlowJS?如何将 TensorFlow 检查点文件转换为 TensorFlowJS?
【发布时间】:2021-05-05 15:24:55
【问题描述】:

我认为我有一个在 TensorFlow v1 上开发的项目。它在 Python 3.8 中像这样工作:

 ...
 saver = tf.train.Saver(var_list=vars)
 ...
 saver.restore(self.sess, tf.train.latest_checkpoint(checkpoint_dir))
 ...

检查点文件位于“checkpoint_dir”中

我想将它与 TFjs 一起使用,但我不知道如何将检查点文件转换为可以使用 TFjs 加载的文件。

我该怎么办?

谢谢,

约翰

【问题讨论】:

  • 模型的格式是什么?
  • 我如何找到它?我刚从 github 上得到项目,它包含代码和检查点文件
  • 检查文件类型是不是model.h5
  • 不,有一个文件 'checkpoint' 和另外 2 个文件 'model-NNN.data0000-of-0001' 和 model-NNN.index' 猜测这些是中间训练检查点
  • 也许我在这里遗漏了一些东西,但是有没有办法在 tensorflowJS 中使用 tf.v1 模型 + 检查点?

标签: python tensorflow tensorflow.js tensorflowjs-converter tensorflow1.15


【解决方案1】:

好的,我想通了。希望这对像我这样的其他初学者也有帮助。

检查点文件不包含模型,它们只包含模型的值(权重等)。

模型实际上是在代码中构建的。因此,以下是将 Tensorflow v1 检查点文件转换为 TensorflowJS 可加载模型的步骤:

  1. 首先我再次保存了检查点,因为缺少一个文件(.meta 文件),其中包含有关检查点中值的一些元信息。为了使用 meta 保存检查点,我在 saver.restore(... 调用之后立即使用了此代码,如下所示:
...
saver.save(self.sess,save_path='./newcheckpoint/')
...
  1. 将模型另存为冻结模型文件,如下所示:
import tensorflow.compat.v1 as tf

meta_path = './newcheckpoint/.meta' # Your .meta file
output_node_names = ['name_of_the_output_node']    # Output nodes

with tf.Session() as sess:
    # Restore the graph
    saver = tf.train.import_meta_graph(meta_path)

    # Load weights
    saver.restore(sess,tf.train.latest_checkpoint('./newcheckpoint/'))

    # Freeze the graph
    frozen_graph_def = tf.graph_util.convert_variables_to_constants(
        sess,
        sess.graph_def,
        output_node_names)

    # Save the frozen graph
    with open('./freeze/output_graph.pb', 'wb') as f:
      f.write(frozen_graph_def.SerializeToString())

这会将模型保存到./freeze/output_graph.pb

  1. 使用 tensorflowjs_converter 将冻结模型转换为 Web 模型,如下所示:

tensorflowjs_converter --input_format=tf_frozen_model --output_node_names='final_add' --skip_op_check ./freeze/output_graph.pb ./web_model/

由于在尝试转换时缺少一些操作错误/警告,不得不使用 --skip_op_check

作为第 3 步的结果,./webmodel/ 文件夹将包含 TensorflowJS 库所需的 JSON 和二进制文件。

这是我使用 tfjs 2.x 加载模型的方式:

model=await tf.loadGraphModel('web_model/model.json');

【讨论】:

    猜你喜欢
    • 1970-01-01
    • 1970-01-01
    • 2018-10-08
    • 2020-01-15
    • 1970-01-01
    • 2020-08-04
    • 2023-02-25
    • 1970-01-01
    • 1970-01-01
    相关资源
    最近更新 更多