【问题标题】:TypeError when using tf.keras.optimizers.apply_gradients method in TensorFlow 2.0在 TensorFlow 2.0 中使用 tf.keras.optimizers.apply_gradients 方法时出现 TypeError
【发布时间】:2019-04-25 17:40:06
【问题描述】:

当我执行如下代码时,错误消息 TypeError: zip argument #2 must support iteration 弹出到屏幕上。

theta = tf.Variable(tf.zeros(100), dtype=tf.float32, name='theta')

@tf.function
def p(x):
    N = tf.cast(tf.shape(x)[0], tf.int64)
    softmax = tf.ones([N, 1]) * tf.math.softmax(theta)
    idx_x = tf.stack([tf.range(N, dtype=tf.int64), x-1], axis=1)
    return tf.gather_nd(softmax, idx_x)


@tf.function
def softmaxLoss(x):
    return tf.reduce_mean(-tf.math.log(p(x)))


train_dset = tf.data.Dataset.from_tensor_slices(data_train).\
                                repeat(1).batch(BATCH_SIZE)


# Create the metrics
loss_metric = tf.keras.metrics.Mean(name='train_loss')
val_loss_metric = tf.keras.metrics.Mean(name='val_loss')
optimizer = tf.keras.optimizers.Adam(0.001)

@tf.function
def train_step(inputs):
    with tf.GradientTape() as tape:
        log_loss = softmaxLoss(inputs)
    gradients = tape.gradient(log_loss,theta)
    optimizer.apply_gradients(zip(gradients, theta))
    # Update the metrics
    loss_metric.update_state(log_loss)


for epoch in range(NUM_EPOCHS):
    # Reset the metrics
    loss_metric.reset_states()

    # Shuffle dataset before each training epoch
    train_dset = train_dset.shuffle(buffer_size=10000)
    for inputs in train_dset:
        train_step(inputs)


经过检查,发现问题出在这行代码上:

optimizer.apply_gradients(zip(gradients, theta))

我该如何解决这个问题?

【问题讨论】:

    标签: python-3.x tensorflow tensorflow2.0 tf.keras


    【解决方案1】:

    您可以通过将 theta 设为列表来解决此问题,因为 zip 要求参数是可迭代的(并且单个 tf.Variable 是不可迭代的)。

    因此:

    optimizer.apply_gradients(zip(gradients, [theta]))
    

    【讨论】:

    • 然后出现了另一个错误消息:TypeError: Tensor 对象仅在启用急切执行时才可迭代。要迭代此张量,请使用 tf.map_fn。 如何使用 tf.map_fn 来解决此错误?
    • 我终于通过gradients = tape.gradient(log_loss,[theta])解决了这个问题,谢谢。
    猜你喜欢
    • 1970-01-01
    • 2016-12-06
    • 2014-03-14
    • 1970-01-01
    • 2020-04-22
    • 1970-01-01
    • 2017-12-18
    • 1970-01-01
    • 1970-01-01
    相关资源
    最近更新 更多