【问题标题】:如果使用两个相同的参数调用函数,则 Tensorflow `tf.function` 会失败
【发布时间】:2022-01-22 08:15:44
【问题描述】:

在我的 TF 模型中,我的 call 函数调用外部能量函数,该函数依赖于单个参数被传递两次的函数(参见下面的简化版本):

import tensorflow as tf

@tf.function
def calc_sw3(gamma,gamma2, cutoff_jk):
    E3 = 2.0
    return E3

@tf.function
def calc_sw3_noerr( gamma0, cutoff_jk):
    E3 = 2.0
    return E3

@tf.function # without tf.function this works fine
def energy(coords, gamma):
    xyz_i = coords[0, 0 : 3]
    xyz_j = coords[0, 3 : 6]
    rij = xyz_j - xyz_i
    norm_rij = (rij[0]**2 + rij[1]**2 + rij[2]**2)**0.5
    E3 = calc_sw3( gamma,gamma,norm_rij)    # repeating gamma gives error
    # E3 = calc_sw3_noerr( gamma, norm_rij) # this gives no error
    return E3



class SWLayer(tf.keras.layers.Layer):
    def __init__(self):
        super().__init__()
        self.gamma = tf.Variable(2.51412, dtype=tf.float32)

    def call(self, coords_all):
        total_conf_energy = energy( coords_all, self.gamma)
        return total_conf_energy
# =============================================================================


SWL = SWLayer()
coords2 = tf.constant([[
                        1.9434,  1.0817,  1.0803,  
                        2.6852,  2.7203,  1.0802,  
                        1.3807,  1.3573,  1.3307]])

with tf.GradientTape() as tape:
    tape.watch(coords2)
    E = SWL( coords2)

这里如果 gamma 只传递一次,或者我不使用 tf.function 装饰器。但是使用tf.function 并两次传递相同的变量,我得到以下错误:

Traceback (most recent call last):
  File "temp_tf.py", line 47, in <module>
    E = SWL( coords2)
  File "...venv/lib/python3.7/site-packages/keras/utils/traceback_utils.py", line 67, in error_handler
    raise e.with_traceback(filtered_tb) from None
  File "temp_tf.py", line 34, in call
    total_conf_energy = energy( coords_all, self.gamma)
tensorflow.python.autograph.impl.api.StagingError: Exception encountered when calling layer "sw_layer" (type SWLayer).

in user code:

    File "temp_tf.py", line 22, in energy  *
        E3 = calc_sw3( gamma,gamma,norm_rij)    # repeating gamma gives error

    IndexError: list index out of range


Call arguments received:
  • coords_all=tf.Tensor(shape=(1, 9), dtype=float32)

这是预期的行为吗?

【问题讨论】:

    标签: tensorflow keras tensorflow2.0


    【解决方案1】:

    有趣的问题!我认为错误源于回溯,这导致 tf.function 不止一次评估energy 中的python sn-ps。请参阅此issue。此外,这可能与bug 有关。

    几个观察:

    1.从calc_sw3 中删除 tf.function 装饰器有效并且与docs 一致

    [...] tf.function 适用于一个函数及其调用的所有其他函数。

    因此,如果您再次将tf.function 明确应用于calc_sw3,您可能会触发回溯,但是您可能想知道为什么calc_sw3_noerr 有效?也就是说,它一定和变量gamma有关。

    2。将输入signatures 添加到energy 函数上方的tf.function,同时保持其余代码不变,也可以:

    @tf.function(input_signature=[tf.TensorSpec(shape=None, dtype=tf.float32), tf.TensorSpec(shape=None, dtype=tf.float32)])
    def energy(coords, gamma):
        xyz_i = coords[0, 0 : 3]
        xyz_j = coords[0, 3 : 6]
        rij = xyz_j - xyz_i
        norm_rij = (rij[0]**2 + rij[1]**2 + rij[2]**2)**0.5
    
        E3 = calc_sw3(gamma, gamma, norm_rij) 
        return E3
    

    这个方法:

    [...] 确保只创建一个 ConcreteFunction,并将 GenericFunction 限制为指定的形状和类型。当张量具有动态形状时,这是一种限制回溯的有效方法。

    所以也许假设gamma 每次都以不同的形状被调用,从而触发回溯(只是一个假设)。触发错误的事实实际上是有意或故意设计的,如here 所述。还有一个有趣的comment

    tf.functions 只能处理预定义的输入形状,如果形状发生变化,或者传递了不同的 python 对象,tensorflow 会自动重建函数

    最后,为什么我认为这是一个追踪问题?因为实际的错误来自这部分代码sn -p

    xyz_i = coords[0, 0 : 3]
    xyz_j = coords[0, 3 : 6]
    rij = xyz_j - xyz_i
    norm_rij = (rij[0]**2 + rij[1]**2 + rij[2]**2)**0.5
    

    您可以通过将其注释掉并将norm_rij 替换为某个值然后调用calc_sw3 来确认。它会起作用的。 这意味着这段代码 sn -p 可能被执行了不止一次,可能由于上面提到的原因。这也是有据可查的here

    在第一阶段,称为“跟踪”,Function 创建一个新的 tf.Graph。 Python 代码正常运行,但所有 TensorFlow 操作(例如添加两个 Tensor)都被延迟:它们被 tf.Graph 捕获而不运行。

    在第二阶段,运行一个包含第一阶段延迟的所有内容的 tf.Graph。这个阶段比追踪阶段快很多

    【讨论】:

    • 感谢您的详细回复,在您的回复之后,我注意到只有当gammatf.Variable 时才会发生这种情况,如果我让它tf.constant 它工作正常(可能这有助于指出问题?)。对于我的工作,您建议的解决方法就足够了,尽管我很好奇,如果tf.function 重新评估该函数,它不会使用相同的参数重新评估它吗?如果是,那么为什么会出现索引错误,以及究竟是哪个组件给出的?
    • 这也是一个关于tf.constant 的有趣观察。实际上很难说,但我假设norm_rij = (rij[0]**2 + rij[1]**2 + rij[2]**2)**0.5 导致了索引错误。诚然,这种行为确实很棘手。很高兴它有所帮助。
    • tf.constant 出现在我的脑海中,因为您提供的链接说每个 tf.Variable 对象都分配了一个 ID。所以我假设如果 calc_sw3 函数有两个具有不同 ID 的输入,但传递相同的 ID Variable 对象,则可能会产生冲突。但我在 TF 中相对新手,仍然在阅读文档。我使用 tf.function 代替 PyTorch 的 jit.script
    • 我还在 github 上提交了错误报告,一旦我在那里得到一些东西就会更新:github.com/tensorflow/tensorflow/issues/53494
    • 好的。也许还可以根据您的新见解对其进行编辑。
    猜你喜欢
    • 1970-01-01
    • 2014-03-28
    • 1970-01-01
    • 1970-01-01
    • 1970-01-01
    • 1970-01-01
    • 1970-01-01
    • 2017-04-16
    • 2014-03-08
    相关资源
    最近更新 更多