title: ckpt模型转pb模型手册
tags: tensorflow,ckpt,pb

为什么要把ckpt格式转为pb格式?

ckpt格式的模型图结构和权重参数模型是分开保存在不同的文件里,而pb模型的图结构和权重参数的模型是保存在同一个文件中的,很多时候我们需要一个既包含图又包含变量的模型文件

ckpt格式模型包含的文件如下:

我们在checkpoint_dir目录下保存的文件结构如下:
tensorflow: ckpt模型转成pb模型

1 meta文件保存的是图结构,meta文件是pb(protocol buffer)格式文件,包含变量、op、集合等。
2 checkpoint_dir目录下还有checkpoint文件,该文件是个文本文件,里面记录了保存的最新checkpoint文件以及其它checkpoint文件列表。在inference时,可以通过修改这个文件,指定使用哪个model。
3 data-00000-of-00001文件和.index文件就是ckpt文件,.data里面的内容存储的就是权重、偏置等内容。在TensorFlow0.11之前,使用ckpt一个后缀文件存储,以后的TensorFlow版本都是使用这两个文件共同存储模型参数。
4 index文件是一个不可变得字符串表,每一个键都是张量的名称,它的值是一个序列化的BundleEntryProto。 每个BundleEntryProto描述张量的元数据:“数据”文件中的哪个文件包含张量的内容,该文件的偏移量,校验和,一些辅助数据等等。

转换过程:

tensorflow: ckpt模型转成pb模型

1. 将我们的模型保存为ckpt模型,生成对应的文件
(1)获取saver对象实例化saver的时候,如果不对tf.train.Saver指定任何参数,默认会保存所有变量。如果你只想保存一部分变量,可以通过指定variables/collections。在创建tf.train.Saver实例时,通过将需要保存的变量构造list或者dictionary,传入到Saver中;而对saver对象其max_to_keep,keep_checkpoint_every_n_hours参数的设置来控制模型文件的生成个数和时间,也可设置global_step 来控制。
(2)先导出图结构:saver.export_meta_graph(model_path + meta_graph_name) , 注意此处代码执行一次,因为图结构是不变的,导出一次就行.
(3)导出权重参数等的模型的数据文件data: saver.save(sess, model_path + model_name, global_step=global_step, write_meta_graph=False)。

2. 将我们的ckpt模型改为pb模型
(1)先导入图结构,构造网络图:saver=tf.train.import_meta_graph(’./checkpoint_dir/MyModel-1000.meta’),即导入meta文件路径
(2)加载参数:saver.restore(sess,tf.train.latest_checkpoint(’./checkpoint_dir’)) , 光有图不行,咱还需要模型的参数呀,
而参数是变量,必须依靠Session,所以加载参数时,要先创建Session。
(3)获取我们刚刚构造好的图并序列化:graphdef = tf.get_default_graph().as_graph_def()
(4)把变量变为常量,固化我们的这张图:frozen_graph = tf.graph_util.convert_variables_to_constants(sess, graphdef, output_node_names),若output_node可以有多个,且应包含命名空间,即命名空间/output_node_names.
(5)使用tf.graph_util.remove_training_nodes(frozen_graph)将在训练阶段才使用的变量去除,也就是一些gradients。
(6)最后,通过tf.gfile.GFile打开一个指定文件f,将图反序列化写入该文件:f.write(graph_def.SerializeToString())

3.使用模型
(1)将保存的模型文件解析为GraphDef :graph_def.ParseFromString(gfile.FastGFile(“model.pb”,‘rb’).read()),这里分为两步,一:通过gfile.FastGFile(“model.pb”,‘rb’).read()获取保存的pb模型对象并读取文件,二: graph_def = tf.GraphDef()创建graph_def对象,然后 graph_def.ParseFromString()解析为二进制放进graph_def对象中。
(2)导入我们上一步创建的图为默认图,这里我们需要指定张量的名称而不是节点的名称: tf.import_graph_def(graph_def,return_elements=[“add:0”]) ,此处"add:0"为张量名称。
(3)现在我们就可以通过get_tensor_by_name方法来获取tensor,执行我们的pb模型了: sess.graph.get_tensor_by_name(“input:0”)

参考文章:
https://blog.csdn.net/zongza/article/details/88540652
https://blog.csdn.net/huachao1001/article/details/78501928
https://blog.csdn.net/lwplwf/article/details/62419087
https://blog.csdn.net/gzj_1101/article/details/80299610
https://www.jianshu.com/p/06548e3e8f4b
https://blog.csdn.net/guyuealian/article/details/82218092

相关文章:

  • 2022-12-23
  • 2022-12-23
  • 2019-11-28
  • 2022-12-23
  • 2022-02-04
  • 2022-01-29
  • 2021-12-31
  • 2022-12-23
猜你喜欢
  • 2019-07-04
  • 2022-12-23
  • 2022-12-23
  • 2022-12-23
  • 2022-12-23
  • 2022-01-03
  • 2022-12-23
相关资源
相似解决方案