【发布时间】:2021-10-13 12:05:28
【问题描述】:
我正在尝试制作一个计算 MSE 的自定义损失函数,但忽略所有真实值低于某个阈值(接近 0)的点。我可以通过以下方式使用 numpy 数组来实现这一点。
import numpy as np
a = np.random.normal(size=(4,4))
b = np.random.normal(size=(4,4))
temp_a = a[np.where(a>0.5)] # Your threshold condition
temp_b = b[np.where(a>0.5)]
mse = mean_squared_error(temp_a, temp_b)
但我不知道如何使用 keras 后端来做到这一点。我的自定义损失函数不起作用,因为 numpy 无法对张量进行操作。
def customMSE(y_true, y_pred):
'''
Correct predictions of 0 do not affect performance.
'''
y_true_ = y_true[tf.where(y_true>0.1)] # Your threshold condition
y_pred_ = y_pred[tf.where(y_true>0.1)]
mse = K.mean(K.square(y_pred_ - y_true_), axis=1)
return mse
但是当我这样做时,我会返回错误
ValueError: Shape must be rank 1 but is rank 3 for '{{node customMSE/strided_slice}} = StridedSlice[Index=DT_INT64, T=DT_FLOAT, begin_mask=0, ellipsis_mask=0, end_mask=0, new_axis_mask=0, shrink_axis_mask=1](cond_2/Identity_1, customMSE/strided_slice/stack, customMSE/strided_slice/stack_1, customMSE/strided_slice/Cast)' with input shapes: [?,?,?,?], [1,?,4], [1,?,4], [1].```
【问题讨论】:
-
Loss 函数将在图形模式下执行,numpy 函数在那里不可用。请改用
tf.where(import tensorflow as tf)。 -
哦。在第一次调用 tf.where 时,我返回一个值错误
Shape must be rank 1 but is rank 3。不知道该怎么做。它与y_true[tf.where(y_true>01.)]@Kaveh 有关 -
我已经用 tf.where 完全替换了 np.where。那么我是否必须重塑输入张量,使用 tf 为 1D? @Kaveh
-
你想在自定义损失函数中做什么?
-
@Kaveh 我想计算 MSE,但仅适用于真相不是 0 或接近 0 的预测。我想忽略这些。
标签: python tensorflow keras conv-neural-network loss-function