【问题标题】:How to batch compute cross entropy for pointer networks?如何批量计算指针网络的交叉熵?
【发布时间】:2018-03-26 04:47:57
【问题描述】:

在指针网络中,输出 logits 超过输入的长度。使用此类批次意味着将输入填充到批次输入的最大长度。现在,这一切都很好,直到我们必须计算损失。目前我正在做的是:

logits = stabilize(logits(inputs))     #[batch, max_length]. subtract max(logits) to stabilize
masks = masks(inputs)     #[batch, max_length]. 1 for actual inputs, 0 for padded locations
exp_logits = exp(logits)
exp_logits_masked = exp_logits*masks
probs = exp_logits_masked/sum(exp_logits_masked)

现在我使用这些概率来计算交叉熵

cross_entropy = sum_over_batches(probs[correct_class])

我能做得比这更好吗?关于指针网络的人通常如何完成它的任何想法?

如果我没有可变大小的输入,这一切都可以在 logits 和标签上使用可调用的tf.nn.softmax_cross_entropy_with_logits 来实现(这是高度优化的),但是可变长度会产生错误的结果,因为 softmax 计算的分母对于每个填充都大 1在输入中。

【问题讨论】:

    标签: python-3.x tensorflow deep-learning softmax cross-entropy


    【解决方案1】:

    您的方法看起来很到位,据我所知,这也是在 RNN 单元中实现的方式。请注意,1x 的导数 = dx,0x 的导数 = 0。这会产生您想要的结果,因为您在网络末端对梯度进行求和/平均。

    您可能会考虑的唯一事情是根据掩码值的数量重新调整损失。您可能会注意到,当有 0 个掩码值时,您的渐变的幅度与使用许多掩码值时的幅度略有不同。我不清楚这是否会产生重大影响,但可能会产生非常小的影响。

    否则,我自己也使用同样的技术取得了巨大的成功,所以我在这里说你走在正确的轨道上。

    【讨论】:

      猜你喜欢
      • 2022-01-05
      • 2021-03-26
      • 1970-01-01
      • 2017-10-22
      • 1970-01-01
      • 2017-07-20
      • 2021-10-07
      • 2019-01-28
      • 1970-01-01
      相关资源
      最近更新 更多