【问题标题】:Keras - custom loss function - chamfer distanceKeras - 自定义损失函数 - 倒角距离
【发布时间】:2018-07-14 18:40:30
【问题描述】:

我正在尝试使用如下定义的自定义损失函数进行对象分割:

def chamfer_loss_value(y_true, y_pred):           

    # flatten the batch 
    y_true_f = K.batch_flatten(y_true)
    y_pred_f = K.batch_flatten(y_pred)

    # ==========
    # get chamfer distance sum

    // error here
    y_pred_mask_f = K.cast(K.greater_equal(y_pred_f,0.5), dtype='float32')

    finalChamferDistanceSum = K.sum(y_pred_mask_f * y_true_f, axis=1, keepdims=True)  

    return K.mean(finalChamferDistanceSum)

def chamfer_loss(y_true, y_pred):   
    return chamfer_loss_value(y_true, y_pred)

y_pred_f 是我的 U-net 的结果。 y_true_f 是对 ground truth 标签掩码 x 进行欧几里得距离变换的结果,如下所示:

distTrans = ndimage.distance_transform_edt(1 - x)

要计算倒角距离,您需要将预测图像(理想情况下是具有 1 和 0 的掩码)乘以地面实况距离变换,然后简单地对所有像素求和。为此,我需要通过对y_pred_f 进行阈值处理来获得一个蒙版y_pred_mask_f,然后乘以y_true_f,并对所有像素求和。

y_pred_f 在 [0,1] 中提供连续范围的值,我在评估 y_true_mask_f 时收到错误 None type not supported。我知道损失函数必须是可微的,greater_equalcast 不是。但是,有没有办法在 Keras 中规避这个问题?也许在 Tensorflow 中使用一些解决方法?

【问题讨论】:

  • 它是不可微的,但是如果你使用梯度下降,你需要梯度来优化。你有手动定义的渐变吗?如果是这样,您可以手动计算梯度并将其放入。请参阅此处stackoverflow.com/questions/43839431/…

标签: tensorflow machine-learning neural-network deep-learning keras


【解决方案1】:

嗯,这很棘手。您的错误背后的原因是您的损失与您的网络之间没有持续的依赖关系。为了计算损失 w.r.t 的梯度。对于网络,如果您的输出大于0.5,您的损失必须计算指标的梯度(因为这是您的最终损失值和网络输出y_pred 之间的唯一联系)。这是不可能的,因为该指标是部分恒定的而不是连续的。

可能的解决方案 - 平滑您的指标:

def chamfer_loss_value(y_true, y_pred):           

    # flatten the batch 
    y_true_f = K.batch_flatten(y_true)
    y_pred_f = K.batch_flatten(y_pred)

    y_pred_mask_f = K.sigmoid(y_pred_f - 0.5)

    finalChamferDistanceSum = K.sum(y_pred_mask_f * y_true_f, axis=1, keepdims=True)  

    return K.mean(finalChamferDistanceSum)

因为sigmoid 是步进函数的连续版本。如果您的输出来自sigmoid - 您可以简单地使用y_pred_f 而不是y_pred_mask_f

【讨论】:

    猜你喜欢
    • 2020-02-09
    • 2021-01-14
    • 2020-02-25
    • 2020-12-19
    • 2017-12-18
    • 2020-03-27
    • 1970-01-01
    • 2018-10-28
    • 2017-12-29
    相关资源
    最近更新 更多