【问题标题】:tf.py_func() unexpected outputs in loop using lambda functiontf.py_func() 使用 lambda 函数在循环中出现意外输出
【发布时间】:2018-04-17 19:19:18
【问题描述】:

在实现相同的简单功能时,我观察到 Numpy 和纯 Tensorflow 的不同行为,该功能在 for 循环中与 tf.py_func 共享一个变量。

让我们从纯 Numpy 版本开始:

def my_func(x, k):
    return np.tile(x,k)

x = np.ones((1), np.int64)
for i in range(1,3):
    x = my_func(x, i)

print(x)

这会产生预期的输出。最初 x[1]。在第一次迭代中,它被复制一次以产生[1]。然后在下一次迭代中,结果被复制两次,产生最终输出[1 1]

类似的方法也可以在纯 Tensorflow 中产生相同的预期输出:

x = tf.constant([1], tf.int64)
for i in range(1,3):
    x = tf.tile(x, [i])

with tf.Session() as sess:
    xx = sess.run(x)
    print(xx)

输出是[1 1]

现在我正在尝试使用tf.py_func 做同样的事情,但我无法理解为什么我会看到不同的输出。这段代码:

import tensorflow as tf
import numpy as np

def my_func(x, k):
    return np.tile(x,k)

x = tf.constant([1], tf.int64)
for i in range(1,3):
    x = tf.py_func(lambda y: my_func(y, i), [x], tf.int64)

with tf.Session() as sess:
    xx = sess.run(x)
    print(xx)

产生意想不到的结果[1 1 1 1]

为什么会这样? py_func 是否有一些属性不能很好地用于共享(张量)变量名称,在这种情况下,变量 x 在每次循环迭代时都会更新?

请注意,这是一个重现问题的简化示例,其功能很容易在纯 Tensorflow 中重现。在我的实际应用中需要使用tf.py_func,因为功能比较复杂。

【问题讨论】:

    标签: python tensorflow


    【解决方案1】:

    没有 lambda 函数,它可以按预期工作:

    import tensorflow as tf
    import numpy as np
    
    def my_func(x, k):
        return np.tile(x,k)
    
    x = tf.constant([1], tf.int64)
    for i in range(1,3):
        x = tf.py_func(my_func, [x, i], tf.int64)
    
    with tf.Session() as sess:
        xx = sess.run(x)
        print(xx)
    

    返回[1 1]

    编辑

    我发现了原因:lambda y: my_func(y, i) 按引用而不是按值保存 i。因此 for 循环的最后一个 i 值应用于循环中的所有 py_func。这是一个显示问题的更简单的示例:

    import tensorflow as tf
    
    def my_func(x, y):
      return x - y
    
    x1 = tf.constant([0], tf.float32)
    for i in range(2):
        x1 = tf.py_func(lambda y: my_func(y, i), [x1], tf.float32)
    
    x2 = tf.constant([0], tf.float32)
    x2 = tf.py_func(lambda y: my_func(y, 0), [x2], tf.float32)
    x2 = tf.py_func(lambda y: my_func(y, 1), [x2], tf.float32)
    
    with tf.Session() as sess:
        print(sess.run(x1))
        print(sess.run(x2))
    

    【讨论】:

    • 干得好!这个真的让我摸不着头脑。
    猜你喜欢
    • 2018-05-08
    • 2015-07-31
    • 1970-01-01
    • 1970-01-01
    • 1970-01-01
    • 1970-01-01
    • 2021-09-28
    • 2018-02-04
    • 2018-09-13
    相关资源
    最近更新 更多