【问题标题】:C++ equivalent of train.import_meta_graph clear_devices argument?C++ 等效于 train.import_meta_graph clear_devices 参数?
【发布时间】:2017-02-16 04:19:52
【问题描述】:

我正在尝试使用 python 在 GPU 上训练图形,以从 C++ 进程加载图形。

status = ReadBinaryProto(Env::Default(), "model.pb", &graph_def);
session->Create(graph_def);

然后我收到错误消息

“无法将设备分配给节点...因为在此过程中没有注册符合该规范的设备;可用设备:/job:localhost/replica:0/task:0/cpu:0”

对于 python train.import_meta_graph API 有 clear_devices 参数,但它在 C++ API 上的等价物是什么?

为了加载图表,我在 Windows 上使用 Tensorflor,使用 CMake 和 -Dtensorflow_ENABLE_GPU=ON 构建,所以我的 vcxproj 有 GOOGLE_CUDA 定义。

我已阅读 Tensorflow, restore variables in a specific device,但它仅适用于 python API。

【问题讨论】:

    标签: c++ machine-learning tensorflow


    【解决方案1】:

    鉴于您还是要从 Python 导出图表,也许您可​​以在清除设备的情况下导出图表?比如:

    meta_graph = tf.train.export_meta_graph()
    with tf.Graph().as_default():
        tf.train.import_meta_graph(meta_graph, clear_devices=True)
        # Export the GraphDef now
        with open('/tmp/model.pb', 'w') as f:
            f.write(tf.get_default_graph().as_graph_def().SerializeToString())
    

    或者,您可以通过清除图中每个节点的 device 字段来复制 C++ 中 clear_devices=True 的行为。比如:

    status = ReadBinaryProto(Env::Default(), "model.pb", &graph_def);
    for (int n = 0; n < graph_def.node_size(); ++n) {
      graph_def.mutable_node(n)->clear_device();
    }
    session->Create(graph_def);
    

    但我建议不要这样做,因为它依赖于框架如何使用 GraphDefs 的内部细节,这可能很脆弱。

    【讨论】:

    • 因为我在 python 上使用的是 train.Saver(),而不是 export_meta_graph(),我在 C++ 上清除了设备,并且 Session::Create() 使用该图成功!谢谢!
    猜你喜欢
    • 1970-01-01
    • 2017-08-31
    • 1970-01-01
    • 2013-12-28
    • 2022-08-22
    • 1970-01-01
    • 2017-08-02
    • 1970-01-01
    • 1970-01-01
    相关资源
    最近更新 更多