【问题标题】:How to get the horizontal and vertical gradient of the difference between y_true and y_pred?如何得到y_true和y_pred之差的水平和垂直梯度?
【发布时间】:2019-03-06 19:01:17
【问题描述】:

我想使用 Keras 定义一个自定义损失函数,其中包含 y_true 和 y_pred 之间差异的梯度。 我发现numpy.gradient 可以帮助我得到一个数组的梯度。 所以我的损失函数部分代码如下所示:

def loss(y_true, y_pred):
    d   = y_true - y_pred
    gradient_x = np.gradient(d, axis=0)
    gradient_y = np.gradient(d, axis=1)

但事实证明 d 是一个 Tensorflow 张量类,numpy.gradient 无法处理它。 我对 Keras 和 Tensorflow 有点陌生。

有没有其他功能可以帮助我做到这一点?还是我必须自己计算梯度?

【问题讨论】:

    标签: python tensorflow keras deep-learning loss-function


    【解决方案1】:

    Tensorflow 张量在执行时根本不是数组,它们只是对正在构建的计算图的引用。您可能想查看tutorial on how Tensorflow builds graphs

    您的损失函数有两个问题:首先,在任一轴上折叠都不会产生标量,因此无法进行导数;其次,Tensorflow 中似乎不存在np.gradient

    对于第一个问题,您可以通过沿gradient_ygradient_x 的剩余轴减少来解决。我不知道您可能想要使用哪个功能,因为我不了解您的应用程序。

    第二个问题可以通过两种方式解决:

    1. 您可以使用py_func 包装np.gradient,但您计划将其用作损失函数,因此您需要获取该函数的梯度,并且定义py_func 调用的梯度为@987654323 @。
    2. 使用纯 Tensorflow 编写您自己的 np.gradient 版本。

    例如,这里是张量流中的一维np.gradient未测试):

    def gradient(x):
        d = x[1:]-x[:-1]
        fd = tf.concat([x,x[-1]], 0).expand_dims(1)
        bd = tf.concat([x[0],x], 0).expand_dims(1)
        d = tf.concat([fd,bd], 1)
        return tf.reduce_mean(d,1)
    

    【讨论】:

      【解决方案2】:

      我遇到了想要定义损失函数的同样问题 np.gradient。我写了一个纯张量流版本的函数来解决这个问题。

      这是我的版本:(它的行为与 np.gradientaxis=-1 的行为相同。)如果你想让它适用于任意轴,你需要多尝试一下:

      def my_gradient_tf(a):
          rght = tf.concat((a[..., 1:], tf.expand_dims(a[..., -1], -1)), -1)
          left = tf.concat((tf.expand_dims(a[...,0], -1), a[..., :-1]), -1)
          ones = tf.ones_like(rght[..., 2:], tf.float64)
          one = tf.expand_dims(ones[...,0], -1)
          divi = tf.concat((one, ones*2, one), -1)
          return (rght-left) / divi
      
      

      【讨论】:

        猜你喜欢
        • 2012-09-30
        • 1970-01-01
        • 1970-01-01
        • 1970-01-01
        • 2021-08-16
        • 2011-08-04
        • 2021-01-22
        • 2013-12-10
        • 1970-01-01
        相关资源
        最近更新 更多