【问题标题】:How to save a model in tensorflow by using c++如何使用 c++ 在 tensorflow 中保存模型
【发布时间】:2017-11-19 09:55:45
【问题描述】:

如何使用 c++ 在 Tensorflow 中保存模型?我在谷歌和百度上搜索过,但没有找到任何解决方案。然后我看了tensorflow的api文档,介绍的少了关于C++的介绍

【问题讨论】:

    标签: c++ tensorflow model save


    【解决方案1】:

    模型保存仅在 Python 中实现。目前无法使用 C++ API 保存模型。 C++ API 允许您加载和使用模型,而不是训练或保存它们。

    【讨论】:

      【解决方案2】:

      假设您对 tensorflow C++ API 有基本的了解,并且知道如何使用 C++ API 构建图形。您可以使用这两个功能:

      1. tensorflow::WriteTextProto() :您可以从 tensorflow::Scope::ToGraphDef() 获得 tensorflow::GraphDef(代表您定义的所有操作,例如加、乘、均值 .... 等),将 tensorflow::GraphDef 保存到文本 protobuf 文件

      2. tensorflow::checkpoint::TensorSliceWriter 将参数矩阵的当前状态保存到外部文件(检查点),有点复杂,但对我来说效果很好

      首先,您必须通过调用tensorflow::Session::Run 来获取训练参数,这会将参数矩阵列表返回给output_tensor(参见下面的示例):

      std::vector<tensorflow::Tensor> output_tensor; 
      tensorflow::Session::Run({}, {"name_of_param_mtx_1", "name_of_param_mtx_2",}, {}, &output_tensor);
      

      上面的name_of_param_mtx_1name_of_param_mtx_2 应该是tensorflow::Variable 中的参数矩阵的名称,例如

      auto name_of_param_mtx_1 = tensorflow::ops::Variable (root.WithOpName("name_of_param_mtx_1"), {7, 17}, tensorflow::DT_FLOAT);
      

      那么你需要为tensorflow::checkpoint::TensorSliceWriter准备以下内容:

      • 通过调用tensorflow::Tensor.tensor_data().data()获取参数原始数据的基地址
      • 每个tensorflow::Tensor 的形状,通过调用tensorflow::Tensor::dim_size(NUM_DIMENSION)。例如 7x17 2D 参数矩阵,NUM_DIMENSION 可以是 0 和 1,其中 tensorflow::Tensor::dim_size(0) 为 7,tensorflow::Tensor::dim_size(1) 为 17。
      • 此检查点的名称,该名称必须与一个文件中的其他检查点不同
      • 通过调用tensorflow::TensorSlice::ParseOrDie("-:-")创建tensorflow::TensorSlice,似乎tensorflow::TensorSlice::ParseOrDie的唯一参数将在内部分析,例如-:- 表示获取矩阵的所有项。如果用户只想要训练过的参数矩阵的一部分,例如只取所有行的第二列,那么字符串参数可能是 -:2 ,我还没有想出 tensorflow::TensorSlice::ParseOrDie 的这种高级用法。

      希望对您有所帮助。

      【讨论】:

        猜你喜欢
        • 1970-01-01
        • 2017-11-09
        • 1970-01-01
        • 2023-02-09
        • 2019-08-30
        • 1970-01-01
        • 1970-01-01
        • 2021-05-06
        • 1970-01-01
        相关资源
        最近更新 更多