【问题标题】:Connecting two graphs连接两个图
【发布时间】:2016-05-12 20:57:14
【问题描述】:

假设我有两个不同的图表: 第一个包含 x 和 y: x = tf.placeholder(tf.float32, shape=(1)), y = 2*x, 第二个包含 a 和 b: a = tf.placeholder(tf.float32, shape=(1)), b = 2*x。

现在,我想通过在 y 和 a 之间添加一些“身份链接”来连接这两个图。换句话说,我想告诉第二个图从第一个图(y)中的某个节点获取其输入(a)。在您没有重新创建第二个图形的代码的情况下,它很方便,您只是从某个地方反序列化它。一种方法是使用 Session.run 计算第一个图的输出,然后将其提供给计算第二个图的输出的 Session.run 调用,但必须有一些干净的方法来做到这一点。

谢谢!

【问题讨论】:

  • 这有什么成功吗?

标签: tensorflow


【解决方案1】:

如果我理解正确,这应该对你有用吗?

它使用tf.import_graph_def 来完成这项工作

我们有x,然后输入第一个图得到y = 2 *x, 然后我们将y 提供给第二个图表以获取b = 2 * y,对于x = 1.0,以下代码将生成4.0

import tensorflow as tf
FLOAT = tf.float32
tf.reset_default_graph()

def graph_1():
    g = tf.Graph()
    with g.as_default():
        x = tf.placeholder(FLOAT, [], name='x')
        y = tf.multiply(2.0, x, name='y')
    return g

def graph_2():
    g = tf.Graph()
    with g.as_default():
        a = tf.placeholder(FLOAT, [], name='a')
        b = tf.multiply(2.0, a, name='b')
    return g

# x = 1.0
x = tf.constant(1.0, FLOAT, [])
# feed x to graph_1 -> y = 2.0
g1 = graph_1()
[g1_y] = tf.import_graph_def(g1.as_graph_def(), input_map={'x': x}, return_elements=['y:0'])
# feed y to graph_2 -> b = 4.0
g2 = graph_2()
[g2_b] = tf.import_graph_def(g2.as_graph_def(), input_map={'a': g1_y}, return_elements=['b:0'])

with tf.Session() as sess:
    print(sess.run([g2_b]))

笔记本:https://gist.github.com/phizaz/21a5454ddc6c2a15c5c0eae91c96cda5

顺便说一句,如果graph_1graph_2 包含“变量”,这将不起作用......到目前为止,我不知道如何初始化这些底层变量,有什么建议吗?

【讨论】:

    猜你喜欢
    • 2018-01-14
    • 1970-01-01
    • 1970-01-01
    • 1970-01-01
    • 2023-04-09
    • 2013-12-23
    • 2014-07-23
    • 2013-01-11
    • 2018-08-01
    相关资源
    最近更新 更多