如果我理解正确,这应该对你有用吗?
它使用tf.import_graph_def 来完成这项工作
我们有x,然后输入第一个图得到y = 2 *x,
然后我们将y 提供给第二个图表以获取b = 2 * y,对于x = 1.0,以下代码将生成4.0。
import tensorflow as tf
FLOAT = tf.float32
tf.reset_default_graph()
def graph_1():
g = tf.Graph()
with g.as_default():
x = tf.placeholder(FLOAT, [], name='x')
y = tf.multiply(2.0, x, name='y')
return g
def graph_2():
g = tf.Graph()
with g.as_default():
a = tf.placeholder(FLOAT, [], name='a')
b = tf.multiply(2.0, a, name='b')
return g
# x = 1.0
x = tf.constant(1.0, FLOAT, [])
# feed x to graph_1 -> y = 2.0
g1 = graph_1()
[g1_y] = tf.import_graph_def(g1.as_graph_def(), input_map={'x': x}, return_elements=['y:0'])
# feed y to graph_2 -> b = 4.0
g2 = graph_2()
[g2_b] = tf.import_graph_def(g2.as_graph_def(), input_map={'a': g1_y}, return_elements=['b:0'])
with tf.Session() as sess:
print(sess.run([g2_b]))
笔记本:https://gist.github.com/phizaz/21a5454ddc6c2a15c5c0eae91c96cda5
顺便说一句,如果graph_1 或graph_2 包含“变量”,这将不起作用......到目前为止,我不知道如何初始化这些底层变量,有什么建议吗?