【发布时间】:2020-01-01 18:19:12
【问题描述】:
我正在为稀疏输入数据编写二进制分类器,我想将输入 0 视为数据不存在的指示,而不是值肯定为 0 的指示。我最初使用的是tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(...)) ,但它对误报的处罚过于严厉。
我成功编写了一个损失函数,如下所示,它提供了我想要的行为,但它慢了几个数量级,我需要找到一种方法来窃取一些性能。
def loss(labels, logits):
labels = tf.reshape(labels, shape=(-1,))
logits = tf.reshape(logits, shape=(-1,))
pairs = tf.stack([labels, logits], axis=1)
return tf.reduce_mean(tf.map_fn(
lambda x: tf.cond(
x[0] < x[1], # x[0] is in {0,1} and x[1] is in (0,1)
lambda: 0.0, # thus the condition is true iff x[0] == 0
lambda: tf.nn.sigmoid_cross_entropy_with_logits(
labels=x[0],
logits=x[1])),
pairs))
【问题讨论】:
标签: python tensorflow optimization loss-function