【发布时间】:2018-10-11 22:40:14
【问题描述】:
我正在为 deep-q 网络实现优先体验回放,规范的一部分是将梯度乘以所谓的重要性采样 (IS) 权重。梯度修改在以下论文的第 3.4 节中讨论:https://arxiv.org/pdf/1511.05952.pdf 我正在努力创建一个自定义损失函数,该函数除了y_true 和y_pred 之外还接受一系列 IS 权重。
这是我的模型的简化版本:
import numpy as np
import tensorflow as tf
# Input is RAM, each byte in the range of [0, 255].
in_obs = tf.keras.layers.Input(shape=(4,))
# Normalize the observation to the range of [0, 1].
norm = tf.keras.layers.Lambda(lambda x: x / 255.0)(in_obs)
# Hidden layers.
dense1 = tf.keras.layers.Dense(128, activation="relu")(norm)
dense2 = tf.keras.layers.Dense(128, activation="relu")(dense1)
dense3 = tf.keras.layers.Dense(128, activation="relu")(dense2)
dense4 = tf.keras.layers.Dense(128, activation="relu")(dense3)
# Output prediction, which is an action to take.
out_pred = tf.keras.layers.Dense(2, activation="linear")(dense4)
opt = tf.keras.optimizers.Adam(lr=5e-5)
network = tf.keras.models.Model(inputs=in_obs, outputs=out_pred)
network.compile(optimizer=opt, loss=huber_loss_mean_weighted)
这是我的自定义损失函数,它只是 Huber 损失乘以 IS 权重的实现:
'''
' Huber loss: https://en.wikipedia.org/wiki/Huber_loss
'''
def huber_loss(y_true, y_pred):
error = y_true - y_pred
cond = tf.keras.backend.abs(error) < 1.0
squared_loss = 0.5 * tf.keras.backend.square(error)
linear_loss = tf.keras.backend.abs(error) - 0.5
return tf.where(cond, squared_loss, linear_loss)
'''
' Importance Sampling weighted huber loss.
'''
def huber_loss_mean_weighted(y_true, y_pred, is_weights):
error = huber_loss(y_true, y_pred)
return tf.keras.backend.mean(error * is_weights)
重要的是is_weights 是动态的,即每次调用fit() 时都不同。因此,我不能简单地关闭is_weights,如下所述:Make a custom loss function in keras
我在网上找到了这段代码,它似乎使用Lambda 层来计算损失:https://github.com/keras-team/keras/blob/master/examples/image_ocr.py#L475 看起来很有希望,但我很难理解它/使其适应我的特定问题。任何帮助表示赞赏。
【问题讨论】:
-
可以将
is_weights视为网络的输入变量吗?如果是这样,您可以通过model.add_loss( huber_loss_mean_weightd( y_true, y_pred, is_weight) ) -
@user36624 当然,
is_weights可以被视为输入变量。使用add_loss似乎是一个干净的解决方案,但我不知道如何使用它。例如,在您的代码 sn-p 中,y_true和y_pred来自哪里?y_true是否对应于我的代码中的out_pred?在我add_loss之后,我使用什么作为loss编译参数?
标签: tensorflow keras