【问题标题】:How do I perform a differentiable operation selection in TensorFlow?如何在 TensorFlow 中执行可微分运算选择?
【发布时间】:2017-07-03 09:30:29
【问题描述】:

我正在尝试生成一个基于标量输入的数学运算选择 nn 模型。该操作是根据 nn 产生的 softmax 结果来选择的。然后必须将此操作应用于标量输入以产生最终输出。到目前为止,我已经想出在 softmax 输出上应用 argmax 和 onehot 以生成一个掩码,然后将它应用于所有可能执行的操作的连接值矩阵(如下面的伪代码所示)。问题是 argmax 和 onehot 似乎都不可区分。我是新手,所以任何人都会受到高度赞赏。提前致谢。

    #perform softmax    
    logits  = tf.matmul(current_input, W) + b
    softmax = tf.nn.softmax(logits)

    #perform all possible operations on the input
    op_1_val = tf_op_1(current_input)
    op_2_val = tf_op_2(current_input)
    op_3_val = tf_op_2(current_input)
    values = tf.concat([op_1_val, op_2_val, op_3_val], 1)

    #create a mask
    argmax  = tf.argmax(softmax, 1)
    mask  = tf.one_hot(argmax, num_of_operations)

    #produce the input, by masking out those operation results which have not been selected
    output = values * mask

【问题讨论】:

    标签: machine-learning tensorflow neural-network recurrent-neural-network calculus


    【解决方案1】:

    我相信这是不可能的。这类似于paper 中描述的 Hard Attention。图像字幕中使用了硬注意力,以允许模型在每个步骤中仅关注图像的特定部分。硬注意力是不可区分的,但有两种方法可以解决这个问题:

    1- 使用强化学习 (RL):强化学习用于训练做出决策的模型。即使损失函数不会将任何梯度反向传播到用于决策的 softmax,您也可以使用 RL 技术来优化决策。举个简单的例子,您可以将损失视为惩罚,并在 softmax 层中具有最大值的节点发送与惩罚成比例的策略梯度,以便在决策错误时降低决策的分数(结果损失惨重)。

    2- 使用类似软注意力的东西:不要只选择一个操作,而是将它们与基于 softmax 的权重混合。所以而不是:

    output = values * mask
    

    用途:

    output = values * softmax
    

    现在,根据 softmax 选择它们的程度,这些操作将收敛到零。与 RL 相比,这更容易训练,但如果您必须从最终结果中完全删除未选择的操作(将它们完全设置为零),它将无法工作。

    这是另一个关于硬注意力和软注意力的答案,您可能会觉得有帮助:https://stackoverflow.com/a/35852153/6938290

    【讨论】:

    • 非常感谢您的回答。我已经对 softmax 本身进行了试验,如果 logits 之间的差异非常大,它看起来会产生接近 1 的最大值和 0 的较小值。使用 softmax = tf.nn.softmax(10000*logits) 之类的东西生成我的掩码以立即产生 logits 的巨大差异是否合理,或者我应该训练 nn 通过权重拾取它?
    • 我认为你应该训练 nn 通过权重来获取它。与像 1000 这样的大数相乘的唯一问题是它会在训练过程的早期使 softmax 饱和。对于较大的 logits 值,softmax 的梯度几乎为 0,因此即使在训练过程开始时,softmax 后面的权重也会得到较少的更新。虽然乘以 1000 应该会增加梯度,但它可能与 softmax 中梯度的指数衰减不匹配。
    猜你喜欢
    • 1970-01-01
    • 2018-09-25
    • 2021-10-06
    • 2019-03-09
    • 1970-01-01
    • 2018-06-18
    • 1970-01-01
    • 1970-01-01
    • 2013-04-16
    相关资源
    最近更新 更多