【问题标题】:keras/tensorflow model: gradient w.r.t. input return the same (wrong?) value for all input datakeras/tensorflow 模型:梯度 w.r.t.输入为所有输入数据返回相同的(错误的?)值
【发布时间】:2019-01-16 18:41:22
【问题描述】:

给定一个训练有素的keras 模型,我正在尝试计算输出相对于输入的梯度。

此示例尝试将函数y=x^2 与由 4 层 relu 激活组成的 keras 模型拟合,并计算模型输出相对于输入的梯度。

from keras.models import Sequential
from keras.layers import Dense
from keras import backend as k
from sklearn.model_selection import train_test_split
import numpy as np
import tensorflow as tf

# random data
x = np.random.random((1000, 1))
y = x**2

# split train/val
x_train, x_val, y_train, y_val = train_test_split(x, y, test_size=0.15)

# model
model = Sequential()
# 1d input
model.add(Dense(10, input_shape=(1, ), activation='relu'))
model.add(Dense(10, activation='relu'))
model.add(Dense(10, activation='relu'))
model.add(Dense(10, activation='relu'))
# 1d output
model.add(Dense(1))

## compile and fit
model.compile(loss='mse', optimizer='rmsprop', metrics=['mae'])
model.fit(x_train, y_train, batch_size=256, epochs=100, validation_data=(x_val, y_val), shuffle=True)

## compute derivative (gradient)
session = tf.Session()
session.run(tf.global_variables_initializer())
y_val_d_evaluated = session.run(tf.gradients(model.output, model.input), feed_dict={model.input: x_val})

print(y_val_d_evaluated)

x_val01之间的150个随机数的向量。

我的期望是y_val_d_evaluated(渐变)应该是:

A. array 包含 150 个不同的数字(因为 x_val 包含 150 个不同的数字);

B.这些值应该接近2*x_valx^2 的导数)。

相反,每次我运行此示例时,y_val_d_evaluated 包含 150 个相等的值(例如 [0.0150494][-0.0150494][0.0150494][-0.0150494],...),而且该值与 @ 非常不同987654337@,每次运行示例时值都会发生变化。

任何人有一些建议可以帮助我理解为什么这段代码没有给出预期的渐变结果?

【问题讨论】:

    标签: python tensorflow neural-network keras gradient


    【解决方案1】:

    好的,我发现了问题,以下几行:

    session = tf.Session()
    session.run(tf.global_variables_initializer())
    

    创建一个覆盖模型参数的新 tf 会话,因此在这些指令之后,模型是具有随机初始参数的模型。这就解释了为什么每次运行的值都不同。

    从 keras 环境中获取 tensorflow 会话的解决方案是使用:

    session = k.get_session()
    

    这个简单的改变结果如我所料。

    【讨论】:

      猜你喜欢
      • 2019-11-19
      • 1970-01-01
      • 1970-01-01
      • 2020-02-27
      • 2020-11-15
      • 2017-01-26
      • 2018-10-31
      • 1970-01-01
      相关资源
      最近更新 更多