【问题标题】:How to optimize this function calculating the categorical crossentropy of two numpy arrays如何优化此函数计算两个 numpy 数组的分类交叉熵
【发布时间】:2018-03-22 14:08:20
【问题描述】:

我想计算两个 numpy 数组的分类交叉熵。两个数组的长度相同。

  1. y_true 包含大约 10000 个二维数组,它们是标签
  2. y_pred 包含 10000 个二维数组,这是我的预测

结果应该是一个 1D numpy 数组,其中包含数组的所有分类交叉熵值。公式为:

这里x_true是一个真实向量的第i个元素,x_pred是预测向量的第i个元素。

我的实现看起来像这样,但速度很慢。完成整形以将 2D 数组转换为 1D 数组以对它们进行简单的迭代。

def categorical_cross_entropy(y_true, y_pred):
    losses = np.zeros(len(y_true))
    for i in range(len(y_true)):
        single_sequence = y_true[i].reshape(y_true.shape[1]*y_true.shape[2])
        single_pred = y_pred[i].reshape(y_pred.shape[1]*y_pred.shape[2])
        sum = 0
        for j in range(len(single_sequence)):
            log = math.log(single_pred[j])
            sum = sum + single_sequence[j] * log
        sum = sum * (-1)
        losses[i] = sum
    return losses

不可能转换为张量,因为tf.constant(y_pred)MemoryError 中失败,因为y_truey_pred 中的每个二维数组的尺寸大致为190 x 190。有什么想法吗?

【问题讨论】:

    标签: python arrays numpy tensorflow


    【解决方案1】:

    您可以使用scipy.special.xlogy。例如,

    In [10]: import numpy as np
    
    In [11]: from scipy.special import xlogy
    

    创建一些数据:

    In [12]: y_true = np.random.randint(1, 10, size=(8, 200, 200))
    
    In [13]: y_pred = np.random.randint(1, 10, size=(8, 200, 200))
    

    使用xlogy计算结果:

    In [14]: -xlogy(y_true, y_pred).sum(axis=(1, 2))
    Out[14]: 
    array([-283574.67634307, -283388.18672431, -284720.65206688,
           -285517.06983709, -286383.26148469, -282200.33634505,
           -285781.78641736, -285862.91148953])
    

    通过使用您的函数计算结果来验证结果:

    In [15]: categorical_cross_entropy(y_true, y_pred)
    Out[15]: 
    array([-283574.67634309, -283388.18672432, -284720.65206689,
           -285517.0698371 , -286383.2614847 , -282200.33634506,
           -285781.78641737, -285862.91148954])
    

    如果你不想依赖scipy,你可以用np.log做同样的事情,但如果y_pred中的任何值是0,你可能会收到警告:

    In [20]: -(y_true*np.log(y_pred)).sum(axis=(1, 2))
    Out[20]: 
    array([-283574.67634307, -283388.18672431, -284720.65206688,
           -285517.06983709, -286383.26148469, -282200.33634505,
           -285781.78641736, -285862.91148953])
    

    【讨论】:

      猜你喜欢
      • 1970-01-01
      • 2017-09-19
      • 1970-01-01
      • 2020-09-28
      • 2022-01-05
      • 1970-01-01
      • 1970-01-01
      • 2021-06-25
      • 2020-01-03
      相关资源
      最近更新 更多