【问题标题】:Keras: calculating derivatives of model output wrt input returns [None]Keras:计算模型输出与输入返回的导数[无]
【发布时间】:2018-03-16 03:59:50
【问题描述】:

我需要帮助计算 Keras 中模型输出 wrt 输入的导数。

我想为损失函数添加一个正则化函数。正则化器包含分类器函数的导数。所以我试图取模型输出的导数。该模型是具有一个隐藏层的 MLP。数据集是 MNIST。当我编译模型并取导数时,我得到 [None] 作为结果而不是导数函数。

我看过一个类似的帖子,但也没有得到答案: Taking derivative of Keras model wrt to inputs is returning all zeros

这是我的代码。请帮我解决问题。

import keras
from keras.datasets import mnist
from keras.models import Sequential
from keras.layers import Dense
from keras import backend as K

num_hiddenNodes = 1024
num_classes = 10

(X_train, y_train), (X_test, y_test) = mnist.load_data()
X_train = X_train.reshape(-1, 28 * 28)
X_train = X_train.astype('float32')
X_train /= 255
y_train = keras.utils.to_categorical(y_train, num_classes)

model = Sequential()
model.add(Dense(num_hiddenNodes, activation='softplus', input_shape=(784,)))
model.add(Dense(num_classes, activation='softmax'))

# Compile the model
model.compile(loss='categorical_crossentropy', optimizer='adam', metrics=['accuracy'])
logits = model.output
# logits = model.layers[-1].output
print(logits)
X = K.identity(X_train)
# X = tf.placeholder(dtype=tf.float32, shape=(None, 784))
print(X)
print(K.gradients(logits, X))

这是代码的输出。这两个参数是张量。 gradients 函数返回 None。

Tensor("dense_2/Softmax:0", shape=(?, 10), dtype=float32)
Tensor("Identity:0", shape=(60000, 784), dtype=float32)
[None]

【问题讨论】:

    标签: python tensorflow keras derivative


    【解决方案1】:

    您正在计算关于 X_train 的梯度,它不是计算图的输入变量。相反,您需要获取模型的符号输入张量,因此请尝试以下操作:

    grads = K.gradients(model.output, model.input)
    

    【讨论】:

    • 感谢您的回答。首先是一个问题,K.gradient(model.input, model.output) 在你的答案中应该是 K.gradient(model.output, model.input) 吗?然后,我尝试使用 model.input 作为参数。现在它可以返回 []。
    • 我想知道model.input、K.identity(X_train)和tf.placeholder(dtype=tf.float32, shape=(None, 784)有什么区别。它们都是张量与相同的形状:Tensor("dense_1_input:0", shape=(?, 784), dtype=float32), Tensor("Identity:0", shape=(60000, 784), dtype=float32), Tensor("Placeholder :0", shape=(?, 784), dtype=float32)。但是只有第一个可以用来获取渐变。
    • @user7367951 是的,参数已交换。 K.identity 没有连接到计算图,实际上你不能用它来做任何事情。
    • K.gradients 不是K.gradient
    • @quant 不,根本没有任何信息,细节,版本等,所以你的评论毫无用处。
    猜你喜欢
    • 1970-01-01
    • 2018-09-24
    • 1970-01-01
    • 1970-01-01
    • 2019-06-16
    • 2021-02-24
    • 1970-01-01
    • 2019-08-19
    • 2019-04-08
    相关资源
    最近更新 更多