【问题标题】:How to restore dangling tf.py_func within the tf.data.Dataset() with tf.saved_model API?如何使用 tf.saved_model API 在 tf.data.Dataset() 中恢复悬空的 tf.py_func?
【发布时间】:2019-09-06 17:51:58
【问题描述】:

在对使用 saved_model API 时恢复tf.py_func() 进行研究后,我找不到tensorflow 中记录的其他信息:

该操作必须与调用 tf.py_func() 的 Python 程序在同一地址空间中运行。如果您使用分布式 TensorFlow,则必须在与调用 tf.py_func() 的程序相同的进程中运行 tf.train.Server,并且必须将创建的操作固定到该服务器中的设备(例如与 tf.device() 一起使用:)

两个 save/load sn-ps 有助于说明情况。

保存部分:

def wrapper(x, y):
    with tf.name_scope('wrapper'):
        return tf.py_func(Copy, [x, y], [tf.float32, tf.float32])

def Copy(x, y):
    return x, y

x_ph = tf.placeholder(tf.float32, [None], 'x_ph')
y_ph = tf.placeholder(tf.float32, [None], 'y_ph')

with tf.name_scope('input'):
    ds = tf.data.Dataset.from_tensor_slices((x_ph, y_ph))
    ds = ds.map(wrapper)
    ds = ds.batch(1)
    it = tf.data.Iterator.from_structure(ds.output_types, ds.output_shapes)
    it_init_op = it.make_initializer(ds, name='it_init_op')

x_it, y_it = it.get_next()

# Simple operation
with tf.name_scope('add'):
    res = tf.add(x_it, y_it)

with tf.Session() as sess:
    sess.run([tf.global_variables_initializer(), it_init_op], feed_dict={y_ph: [10] * 10, x_ph: [i for i in range(10)]})
    sess.run([res])
    tf.saved_model.simple_save(sess, './dummy/test', {'x_ph': x_ph, 'y_ph': y_ph}, {'res': res})

加载部分:

graph = tf.Graph()
graph.as_default()
with tf.Session(graph=graph) as sess:
    tf.saved_model.loader.load(sess, [tf.saved_model.tag_constants.SERVING], './dummy/test')

    res = graph.get_tensor_by_name('add/Add:0')
    it_init_op = graph.get_operation_by_name('input/it_init_op')
    x_ph = graph.get_tensor_by_name('x_ph:0')
    y_ph = graph.get_tensor_by_name('y_ph:0')
    sess.run([it_init_op], feed_dict={x_ph: [5] * 5, y_ph: [i for i in range(5)]})

    for _ in range(5):
        sess.run([res])

错误:

ValueError: 找不到回调 pyfunc_0

众所周知,tf.py_func() 包裹的函数不会与模型一起保存。有没有人有解决方案来使用 tf doc 应用tf.train.Server 给出的小提示来恢复它

【问题讨论】:

  • Tensorflow 似乎没有解决方案,因为当我们保存模型时,pythonic 部分没有保存。要绕过这个问题,应该在执行import_meta_graph() 时修剪输入管道中的py_func(),并使用参数input_map='input_pipeline/IteratorGetNext':new_inputpipeline_with_pyfunc 传递/剪切输入门节点仍然等待另一个更好的解决方案

标签: python-3.x tensorflow tensorflow-serving tensorflow-datasets


【解决方案1】:

只要没有答案,我会建议我的,它轮廓 pb 而不是解决它。苦苦挣扎了半天,终于把它剪掉了。然后用更简单的占位符将新的输入/输出移植到它。此外,此 py_func 在 TF2.0 中已弃用

【讨论】:

    猜你喜欢
    • 2012-03-22
    • 2019-08-20
    • 1970-01-01
    • 2020-08-28
    • 2020-06-24
    • 1970-01-01
    • 2020-03-04
    • 2014-01-17
    • 2020-03-27
    相关资源
    最近更新 更多