【问题标题】:How to implement multi class dice loss function without using argmax function(argmax is not differentiable)?如何在不使用 argmax 函数的情况下实现多类骰子损失函数(argmax 不可微)?
【发布时间】:2019-09-30 15:12:19
【问题描述】:

我正在尝试在 tensorflow 中实现多类骰子损失函数。由于它是多类骰子,我需要将每个类的概率转换为它的 one-hot 形式。例如,如果我的网络输出这些概率:
[0.2, 0.6, 0.1, 0.1](假设 4 个类)
我需要将其转换为:
[0 1 0 0]
这可以通过使用 tf.argmax 后跟 tf.one_hot

来完成
def generalized_dice_loss(labels, logits):
 #labels shape [batch_size,128,128,64,1] dtype=float32
 #logits shape [batch_size,128,128,64,7] dtype=float32
 labels=tf.cast(labels,tf.int32)
 smooth = tf.constant(1e-17)
 shape = tf.TensorShape(logits.shape).as_list()
 depth = int(shape[-1])
 labels = tf.one_hot(labels, depth, dtype=tf.int32,axis=4)
 labels = tf.squeeze(labels, axis=5)
 logits = tf.argmax(logits,axis=4)
 logits = tf.one_hot(logits, depth, dtype=tf.int32,axis=4)
 numerator = tf.reduce_sum(labels * logits, axis=[1, 2, 3])
 denominator = tf.reduce_sum(labels + logits, axis=[1, 2, 3])
 numerator=tf.cast(numerator,tf.float32)
 denominator=tf.cast(denominator,tf.float32)
 loss = tf.reduce_mean(1.0 - 2.0*(numerator + smooth)/(denominator + smooth))
 return loss

问题是,tf.argmax 不可微,会抛出错误:

ValueError: An operation has `None` for gradient. Please make sure that all of your ops have a gradient defined (i.e. are differentiable). Common ops without gradient: K.argmax, K.round, K.eval.

如何解决这个问题?我们可以不使用 tf.argmax 做同样的事情吗?

【问题讨论】:

    标签: python-3.x tensorflow keras deep-learning


    【解决方案1】:

    看看How is the smooth dice loss differentiable?。您无需进行转换(将[0.2, 0.6, 0.1, 0.1] 转换为[0 1 0 0])。只需将它们保留为 0 到 1 之间的连续值即可。

    如果我理解正确,损失函数只是实现您预期目标的替代品。即使不一样,只要是好的代理就可以了(否则不可微)。

    在评估时,请随时使用tf.argmax 获取真实指标。

    【讨论】:

      猜你喜欢
      • 2021-03-15
      • 2022-06-12
      • 1970-01-01
      • 1970-01-01
      • 1970-01-01
      • 2019-09-08
      • 2016-10-05
      • 1970-01-01
      • 2023-03-16
      相关资源
      最近更新 更多