【问题标题】:Using GradientTape to compute gradients of predictions with respect to some tensors使用 GradientTape 计算关于某些张量的预测梯度
【发布时间】:2020-03-20 23:52:31
【问题描述】:

我正在尝试在 TensorFlow 2.0 中使用 GP 实现 WGAN。要计算梯度惩罚,您需要计算预测相对于输入图像的梯度。

现在,为了使其更易于处理,它不是计算关于所有输入图像的预测梯度,而是沿原始数据点和假数据点的线计算插值数据点,并将其用作输入。

为了实现这一点,我首先开发了compute_gradients 函数,该函数将进行一些预测并返回相对于某些输入图像的梯度。首先,我想用tf.keras.backend.gradients 来做这件事,但它不会在急切模式下工作。所以,我现在正在尝试使用GradientTape 来执行此操作。

这是我用来测试的代码:

from tensorflow.keras import backend as K
from tensorflow.keras.layers import *
from tensorflow.keras.models import *
import tensorflow as tf
import numpy as np

# Comes from Generative Deep Learning by David Foster
class RandomWeightedAverage(tf.keras.layers.Layer):
    def __init__(self, batch_size):
        super().__init__()
        self.batch_size = batch_size
    """Provides a (random) weighted average between real and generated image samples"""
    def call(self, inputs):
        alpha = K.random_uniform((self.batch_size, 1, 1, 1))
        return (alpha * inputs[0]) + ((1 - alpha) * inputs[1])

# Dummy critic
def make_critic():
    critic = Sequential()
    inputShape = (28, 28, 1)

    critic.add(Conv2D(32, (5, 5), padding="same", strides=(2, 2),
        input_shape=inputShape))
    critic.add(LeakyReLU(alpha=0.2))

    critic.add(Conv2D(64, (5, 5), padding="same", strides=(2, 2)))
    critic.add(LeakyReLU(alpha=0.2))

    critic.add(Flatten())
    critic.add(Dense(512))
    critic.add(LeakyReLU(alpha=0.2))
    critic.add(Dropout(0.3))
    critic.add(Dense(1))

    return critic

# Gather dataset
((X_train, _), (X_test, _)) = tf.keras.datasets.fashion_mnist.load_data()
X_train = X_train.reshape(-1, 28, 28, 1)
X_test = X_test.reshape(-1, 28, 28, 1)

# Note that I am using test images as fake images for testing purposes
interpolated_img = RandomWeightedAverage(32)([X_train[0:32].astype("float"), X_test[32:64].astype("float")])

# Compute gradients of the predictions with respect to the interpolated images
critic = make_critic()
with tf.GradientTape() as tape:
    y_pred = critic(interpolated_img)
gradients = tape.gradient(y_pred, interpolated_img)

渐变将变为None。我在这里错过了什么吗?

【问题讨论】:

  • 你能检查一下 interpolated_img 的值吗?都是0吗?
  • 它是一种类型还是您实际上是在上下文管理器之外调用tape.gradient

标签: python tensorflow machine-learning keras deep-learning


【解决方案1】:

关于一些张量的预测梯度......我在这里遗漏了什么吗?

是的。你需要一个tape.watch(interpolated_img):

with tf.GradientTape() as tape:
    tape.watch(interpolated_img)
    y_pred = critic(interpolated_img)

GradientTape 需要存储前向传播的中间值来计算梯度。通常,您需要渐变 WRT 变量。所以它不会保留从张量开始的计算痕迹,可能是为了节省内存。

如果你想要一个梯度WRT一个张量,你需要明确告诉tape

【讨论】:

    猜你喜欢
    • 2020-09-19
    • 1970-01-01
    • 2021-08-30
    • 1970-01-01
    • 2020-01-09
    • 1970-01-01
    • 1970-01-01
    • 2022-01-17
    • 1970-01-01
    相关资源
    最近更新 更多