【问题标题】:TensorFlow 2 tf.function decoratorTensorFlow 2 tf.function 装饰器
【发布时间】:2020-02-04 14:23:05
【问题描述】:

我有 TensorFlow 2.0 和 Python 3.7.5。

我编写了以下代码来执行小批量梯度下降:

@tf.function
def train_one_step(model, mask_model, optimizer, x, y):
    '''
    Function to compute one step of gradient descent optimization
    '''
    with tf.GradientTape() as tape:
        # Make predictions using defined model-
        y_pred = model(x)

        # Compute loss-
        loss = loss_fn(y, y_pred)

    # Compute gradients wrt defined loss and weights and biases-
    grads = tape.gradient(loss, model.trainable_variables)

    # type(grads)
    # list

    # List to hold element-wise multiplication between-
    # computed gradient and masks-
    grad_mask_mul = []

    # Perform element-wise multiplication between computed gradients and masks-
    for grad_layer, mask in zip(grads, mask_model.trainable_weights):
        grad_mask_mul.append(tf.math.multiply(grad_layer, mask))

    # Apply computed gradients to model's weights and biases-
    optimizer.apply_gradients(zip(grad_mask_mul, model.trainable_variables))

    # Compute accuracy-
    train_loss(loss)
    train_accuracy(y, y_pred)

    return None

在代码中,“mask_model”是一个掩码,可以为 0 或 1。“mask_model”的用途是控制训练哪些参数(因为,0 * 梯度下降 = 0)。

我的问题是,我在“train_one_step()”TensorFlow 装饰函数中使用“grad_mask_mul”列表变量。这会导致任何问题,例如:

ValueError: tf.function-decorated 函数试图创建变量 非第一次通话。

或者你们看到在 tensorflow 修饰函数中使用列表变量有什么问题吗?

谢谢!

【问题讨论】:

    标签: python python-3.x tensorflow tensorflow2.0


    【解决方案1】:

    这是 TensorFlow 2 中的一个错误。您可以在此处阅读更多信息 TF2 bug

    【讨论】:

      【解决方案2】:

      以防人们仍然收到错误

      ValueError: tf.function-decorated 函数试图创建变量 非第一次通话。

      但不确定发生了什么。 TensorFlow 团队在 2021 年 2 月左右更新了“函数”指南(见 https://github.com/tensorflow/tensorflow/issues/36574):

      查看更新的指南,尤其是“创建 tf 变量”部分:https://www.tensorflow.org/guide/function#creating_tfvariables

      基本上,OP 需要确保的是:

      • 在对函数化 train_step 的第一次调用中,所有 tf.Variable 都被创建一次,并且不会在任何模型或优化器中创建新的 tf.Variables在随后(即非第一次)调用train_one_step

      很可能,您已将 modelmask_model 的新未构建版本发送到 train_one_step,并且 tensorflow 正在尝试构建它(即新建 tf.Variable),但 train_one_step 已被调用之前是tf.function

      当前(更新的)指南解释了如何解决此类问题。

      【讨论】:

        猜你喜欢
        • 2021-12-01
        • 2019-09-17
        • 1970-01-01
        • 1970-01-01
        • 2019-08-04
        • 2020-08-07
        • 1970-01-01
        • 2020-09-16
        • 1970-01-01
        相关资源
        最近更新 更多