【发布时间】:2017-11-19 09:55:45
【问题描述】:
如何使用 c++ 在 Tensorflow 中保存模型?我在谷歌和百度上搜索过,但没有找到任何解决方案。然后我看了tensorflow的api文档,介绍的少了关于C++的介绍
【问题讨论】:
标签: c++ tensorflow model save
如何使用 c++ 在 Tensorflow 中保存模型?我在谷歌和百度上搜索过,但没有找到任何解决方案。然后我看了tensorflow的api文档,介绍的少了关于C++的介绍
【问题讨论】:
标签: c++ tensorflow model save
模型保存仅在 Python 中实现。目前无法使用 C++ API 保存模型。 C++ API 允许您加载和使用模型,而不是训练或保存它们。
【讨论】:
假设您对 tensorflow C++ API 有基本的了解,并且知道如何使用 C++ API 构建图形。您可以使用这两个功能:
tensorflow::WriteTextProto() :您可以从 tensorflow::Scope::ToGraphDef() 获得 tensorflow::GraphDef(代表您定义的所有操作,例如加、乘、均值 .... 等),将 tensorflow::GraphDef 保存到文本 protobuf 文件
tensorflow::checkpoint::TensorSliceWriter 将参数矩阵的当前状态保存到外部文件(检查点),有点复杂,但对我来说效果很好
首先,您必须通过调用tensorflow::Session::Run 来获取训练参数,这会将参数矩阵列表返回给output_tensor(参见下面的示例):
std::vector<tensorflow::Tensor> output_tensor;
tensorflow::Session::Run({}, {"name_of_param_mtx_1", "name_of_param_mtx_2",}, {}, &output_tensor);
上面的name_of_param_mtx_1 和name_of_param_mtx_2 应该是tensorflow::Variable 中的参数矩阵的名称,例如
auto name_of_param_mtx_1 = tensorflow::ops::Variable (root.WithOpName("name_of_param_mtx_1"), {7, 17}, tensorflow::DT_FLOAT);
那么你需要为tensorflow::checkpoint::TensorSliceWriter准备以下内容:
tensorflow::Tensor.tensor_data().data()获取参数原始数据的基地址
tensorflow::Tensor 的形状,通过调用tensorflow::Tensor::dim_size(NUM_DIMENSION)。例如 7x17 2D 参数矩阵,NUM_DIMENSION 可以是 0 和 1,其中 tensorflow::Tensor::dim_size(0) 为 7,tensorflow::Tensor::dim_size(1) 为 17。tensorflow::TensorSlice::ParseOrDie("-:-")创建tensorflow::TensorSlice,似乎tensorflow::TensorSlice::ParseOrDie的唯一参数将在内部分析,例如-:- 表示获取矩阵的所有项。如果用户只想要训练过的参数矩阵的一部分,例如只取所有行的第二列,那么字符串参数可能是 -:2 ,我还没有想出 tensorflow::TensorSlice::ParseOrDie 的这种高级用法。希望对您有所帮助。
【讨论】: