【问题标题】:What is the use of a *.pb file in TensorFlow and how does it work?TensorFlow 中的 *.pb 文件有什么用,它是如何工作的?
【发布时间】:2018-12-19 01:34:54
【问题描述】:

我正在使用一些实现来创建使用this file 的面部识别:

"facenet.load_model("20170512-110547/20170512-110547.pb")"

这个文件有什么用?我不确定它是如何工作的。

控制台日志:

Model filename: 20170512-110547/20170512-110547.pb
distance = 0.72212267

代码实际拥有者的 Github 链接 https://github.com/arunmandal53/facematch

【问题讨论】:

    标签: tensorflow


    【解决方案1】:

    pb 代表 protobuf。在 TensorFlow 中,protbuf 文件包含图形定义以及模型的权重。因此,只需一个 pb 文件即可运行给定的训练模型。

    给定一个pb 文件,您可以按如下方式加载它。

    def load_pb(path_to_pb):
        with tf.gfile.GFile(path_to_pb, "rb") as f:
            graph_def = tf.GraphDef()
            graph_def.ParseFromString(f.read())
        with tf.Graph().as_default() as graph:
            tf.import_graph_def(graph_def, name='')
            return graph
    

    加载图表后,您基本上可以做任何事情。例如,您可以使用

    检索感兴趣的张量
    input = graph.get_tensor_by_name('input:0')
    output = graph.get_tensor_by_name('output:0')
    

    并使用常规的 TensorFlow 例程,例如:

    sess.run(output, feed_dict={input: some_data})
    

    【讨论】:

    • 从我用外行的语言理解的地方,它是某种图形 API。
    • 这只是一种将模型(例如神经网络)保存到磁盘以供以后恢复/重用的方法。
    • +1 以获取如何加载它的示例。一个小补充:如果 pb 文件不是二进制文件,我相信您需要使用“google.protobuf.text_format.Merge(f.read(), graph_def)”代替“graph_def.ParseFromString(f.read()) ",见tensorflow.org/guide/extend/model_files
    • 对于 TensorFlow v2.x,此方法已被弃用,但如果我们将 .compat.v1 中缀添加到 gfileGraphDef 名称,它仍然有效。例如,第一行变为:with tf.compat.v1.gfile.GFile(path_to_pb, "rb") as f
    【解决方案2】:

    说明

    .pb 格式是protocol buffer (protobuf) 格式,在 Tensorflow 中,这种格式用于保存模型。 Protobufs 是 Google 存储数据的一种通用方式,它更易于传输,因为它更有效地压缩数据并强制数据结构。在 TensorFlow 中使用时,它被称为 SavedModel 协议缓冲区,这是保存 Keras/Tensorflow 2.0 模型时的默认格式。有关此格式的更多信息,请访问 herehere

    例如,以下代码(特别是m.save)将创建一个名为my_new_model 的文件夹,并在其中保存saved_model.pbassets/ 文件夹和variables/ 文件夹。

    # first download a SavedModel from TFHub.dev, a website with models
    m = tf.keras.Sequential([
        hub.KerasLayer("https://tfhub.dev/google/imagenet/mobilenet_v2_130_224/classification/4")
    ])
    m.build([None, 224, 224, 3])  # Batch input shape.
    m.save("my_new_model") # defaults to save as SavedModel in tensorflow 2
    

    在某些地方,您可能还会看到 .h5 模型,这是 TF 1.X 的默认格式。 source


    额外信息:在 TensorFlow Lite 中,用于在移动和物联网设备上运行模型的库,而不是协议缓冲区,使用平面缓冲区。这是 TensorFlow Lite 转换器转换成的内容(.tflite 格式)。这是另一种非常有效的 Google 格式:它允许访问消息的任何部分而无需反序列化(与 json、xml 不同)。对于内存 (RAM) 较少的设备,从模型文件中加载您需要的内容更有意义,而不是将整个内容加载到内存中进行反序列化。


    在 TensorFlow 2 中加载 SavedModels

    我注意到 BiBi 显示加载模型的答案很受欢迎,并且在 TF2 中有更短的方法可以做到这一点:

    import tensorflow as tf
    model_path = "/path/to/directory/inception_v1_224_quant_20181026"
    model = tf.saved_model.load(model_path)
    

    注意,

    • 目录(即inception_v1_224_quant_20181026)必须有saved_model.pbsaved_model.pbtxt,否则代码会崩溃。您不能指定.pb 路径,请指定目录
    • 对于旧型号,您可能会收到 TypeError: 'AutoTrackable' object is not callablefix here

    如果您加载 TF1 模型,我发现我没有收到任何错误,但加载的文件没有按预期运行。 (例如它没有任何功能,如预测)

    【讨论】:

      猜你喜欢
      • 2010-11-08
      • 2010-09-12
      • 2019-06-01
      • 2016-05-28
      • 2017-11-14
      • 1970-01-01
      • 2016-08-19
      相关资源
      最近更新 更多