【问题标题】:How does batching interact with the loss function in TensorFlow?批处理如何与 TensorFlow 中的损失函数交互?
【发布时间】:2017-02-11 16:38:20
【问题描述】:

我正在使用自己的损失函数在 TensorFlow 中训练多目标神经网络,但找不到有关批处理如何与该功能交互的文档。

例如,我的损失函数在下面有 sn-p,它采用张量/预测列表并确保它们的绝对值总和不超过 1:

def fitness(predictions,actual):

    absTensor = tf.abs(predictions)
    sumTensor = tf.reduce_sum(absTensor)
    oneTensor = tf.constant(1.0)

    isGTOne = tf.greater(sumTensor,oneTensor)

    def norm(): return predictions/sumTensor
    def unchanged(): return predictions

    predictions = tf.cond(isGTOne,norm,unchanged)

    etc...

但是当我传递一批估计值时,我觉得这个损失函数正在将整个输入集归一化为此时总和为 1,而不是每个单独的集合总和为 1。即
[[.8,.8],[.8,.8]] -> [[.25,.25],[.25,25]]
而不是想要的
[[.8,.8],[.8,.8]] -> [[.5,.5],[.5,.5]]

谁能澄清或消除我的怀疑?如果这是我的功能当前的工作方式,我该如何更改?

【问题讨论】:

    标签: machine-learning tensorflow neural-network


    【解决方案1】:

    您必须为缩减操作指定缩减轴,否则所有轴都将被缩减。传统上,这是张量的第一个维度。所以,第 2 行应该是这样的:

    sumTensor = tf.reduce_sum(absTensor, 0)
    

    进行更改后,您将遇到另一个问题。 sumTensor 将不再是标量,因此作为 tf.cond 的条件将不再有意义(即,每个批次的条目分支意味着什么?)。你真正想要的是tf.select,因为你真的不想为每个批处理条目分支逻辑。像这样:

    isGTOne = tf.greater(sumTensor,oneTensor)
    
    norm = predictions/sumTensor
    
    predictions = tf.select(isGTOne,norm,predictions)
    

    但是,现在看这个,我什至不会费心对条目进行有条件的规范化。由于您现在以批处理的粒度进行操作,因此我认为您无法通过一次规范化批处理的条目来获得性能。特别是,因为除法并不是真正昂贵的副作用。还不如这样做:

    def fitness(predictions,actual):
    
      absTensor = tf.abs(predictions)
      sumTensor = tf.reduce_sum(absTensor, 0)
    
      predictions = predictions/sumTensor
    
      etc...
    

    希望有帮助!

    【讨论】:

    • 太完美了。谢谢你。文档中是否有某处谈论这种行为?我想通读一遍,以确保没有其他意外发生
    • 您在寻找关于什么行为的文档? tf.select docs 很有用。
    猜你喜欢
    • 1970-01-01
    • 2019-08-17
    • 1970-01-01
    • 2019-01-20
    • 2020-09-27
    • 2015-08-18
    • 1970-01-01
    • 2018-01-29
    • 2018-05-31
    相关资源
    最近更新 更多