【问题标题】:Keras MSE definitionKeras MSE 定义
【发布时间】:2018-02-05 08:37:30
【问题描述】:

我在 Keras 中偶然发现了 mse 的定义,但似乎找不到解释。

def mean_squared_error(y_true, y_pred):
    return K.mean(K.square(y_pred - y_true), axis=-1)

我希望在批次之间取平均值,即axis=0,但实际上是axis=-1

我也尝试了一下,看看K.mean 的行为是否真的像numpy.mean。 我一定是误会了什么。有人可以澄清一下吗?

我实际上无法在运行时查看成本函数的内部,对吧? 据我所知,该函数是在编译时调用的,这使我无法评估具体值。

我的意思是...想象一下进行回归并有一个输出神经元并以 10 的批大小进行训练。

>>> import numpy as np
>>> a = np.ones((10, 1))
>>> a
array([[ 1.],
       [ 1.],
       [ 1.],
       [ 1.],
       [ 1.],
       [ 1.],
       [ 1.],
       [ 1.],
       [ 1.],
       [ 1.]])
>>> np.mean(a, axis=-1)
array([ 1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.])

它所做的只是将数组展平,而不是取所有预测的平均值。

【问题讨论】:

    标签: python machine-learning neural-network deep-learning keras


    【解决方案1】:

    K.mean(a, axis=-1)np.mean(a, axis=-1) 只是在最终维度上取平均值。这里a 是一个形状为(10, 1) 的数组,在这种情况下,在最终维度上取平均值恰好与将其展平为形状为(10,) 的一维数组相同。像这样实现它支持更一般的情况,例如多元线性回归。

    此外,您可以在运行时使用keras.backend.print_tensor 检查计算图中节点的值。见答案:Is there any way to debug a value inside a tensor while training on Keras?

    编辑:您的问题似乎是关于为什么损失不返回单个标量值,而是为批处理中的每个数据点返回一个标量值。为了支持样本加权,Keras 损失预计会为批次中的每个数据点返回一个标量。有关详细信息,请参阅 losses documentationfitsample_weight 参数。特别注意:“实际优化的目标是输出数组在所有数据点上的 [加权] 平均值。”

    【讨论】:

    • 我知道它做了它该做的。我的问题是:为什么要这样做?第一个维度是批量大小......那么为什么它不采用axis = 0的平均值。
    【解决方案2】:

    代码如下:

     def mean_squared_error(y_true, y_pred):
         return K.mean(K.square(y_pred - y_true), axis=-1)
    

    选择轴为-1的一个应用程序是,例如,对于彩色图片,它有3层RGB。每层的大小为 512 乘以 512 像素,它们存储在大小为 512 乘 512 乘 3 的对象中。

    假设您的任务涉及重建图片,并且您存储在另一个大小为 512 乘以 512 乘以 3 的对象中。

    调用 MSE 将使您能够分析每个像素的重建任务有多好。输出的大小为 512 乘以 512,总结了每个像素的表现。

    【讨论】:

      【解决方案3】:

      我和你有同样的问题。在我做了一些实验之后,我认为返回标量或张量作为损失并不重要,Keras(tensorflow)框架可以自动处理它。例如,如果您应用 K.tf.reduce_mean() 来获得标量而不是向量,则框架只需再增加一步来计算 reduce_mean() 的梯度。基于梯度链式法则,结果不会受到影响。

      【讨论】:

        猜你喜欢
        • 2018-02-24
        • 2021-07-16
        • 1970-01-01
        • 2019-06-25
        • 2020-04-05
        • 2021-01-24
        相关资源
        最近更新 更多