【发布时间】:2019-09-06 17:51:58
【问题描述】:
在对使用 saved_model API 时恢复tf.py_func() 进行研究后,我找不到tensorflow 中记录的其他信息:
该操作必须与调用
tf.py_func()的 Python 程序在同一地址空间中运行。如果您使用分布式 TensorFlow,则必须在与调用tf.py_func()的程序相同的进程中运行tf.train.Server,并且必须将创建的操作固定到该服务器中的设备(例如与tf.device()一起使用:)
两个 save/load sn-ps 有助于说明情况。
保存部分:
def wrapper(x, y):
with tf.name_scope('wrapper'):
return tf.py_func(Copy, [x, y], [tf.float32, tf.float32])
def Copy(x, y):
return x, y
x_ph = tf.placeholder(tf.float32, [None], 'x_ph')
y_ph = tf.placeholder(tf.float32, [None], 'y_ph')
with tf.name_scope('input'):
ds = tf.data.Dataset.from_tensor_slices((x_ph, y_ph))
ds = ds.map(wrapper)
ds = ds.batch(1)
it = tf.data.Iterator.from_structure(ds.output_types, ds.output_shapes)
it_init_op = it.make_initializer(ds, name='it_init_op')
x_it, y_it = it.get_next()
# Simple operation
with tf.name_scope('add'):
res = tf.add(x_it, y_it)
with tf.Session() as sess:
sess.run([tf.global_variables_initializer(), it_init_op], feed_dict={y_ph: [10] * 10, x_ph: [i for i in range(10)]})
sess.run([res])
tf.saved_model.simple_save(sess, './dummy/test', {'x_ph': x_ph, 'y_ph': y_ph}, {'res': res})
加载部分:
graph = tf.Graph()
graph.as_default()
with tf.Session(graph=graph) as sess:
tf.saved_model.loader.load(sess, [tf.saved_model.tag_constants.SERVING], './dummy/test')
res = graph.get_tensor_by_name('add/Add:0')
it_init_op = graph.get_operation_by_name('input/it_init_op')
x_ph = graph.get_tensor_by_name('x_ph:0')
y_ph = graph.get_tensor_by_name('y_ph:0')
sess.run([it_init_op], feed_dict={x_ph: [5] * 5, y_ph: [i for i in range(5)]})
for _ in range(5):
sess.run([res])
错误:
ValueError: 找不到回调 pyfunc_0
众所周知,tf.py_func() 包裹的函数不会与模型一起保存。有没有人有解决方案来使用 tf doc 应用tf.train.Server 给出的小提示来恢复它
【问题讨论】:
-
Tensorflow 似乎没有解决方案,因为当我们保存模型时,pythonic 部分没有保存。要绕过这个问题,应该在执行
import_meta_graph()时修剪输入管道中的py_func(),并使用参数input_map='input_pipeline/IteratorGetNext':new_inputpipeline_with_pyfunc传递/剪切输入门节点仍然等待另一个更好的解决方案
标签: python-3.x tensorflow tensorflow-serving tensorflow-datasets