【问题标题】:multiply only certain columns of a tensorflow array仅乘以张量流数组的某些列
【发布时间】:2018-09-15 00:18:43
【问题描述】:

我目前正在修改我的一个对象检测神经网络的损失函数。我基本上有两个数组;

y_true:预测标签。形状为 (x, y, z) 的 tf 张量 y_pred:预测值。形状 (x, y, z) 的 tf 张量 - x 维度是批量大小,y 维度是图像中预测对象的数量,z 维度包含类的 one-hot 编码以及边界所述类的盒子。

现在回到真正的问题:我想要做的基本上是将 y_pred 中的前 5 个 z 值与 y_true 中的前 5 个 z 值相乘。所有其他值应保持不受影响。在 numpy 中,它非常简单;

y_pred[:,:,:5] *= y_true[:,:,:5]

我发现在 tensorflow 中很难做到这一点,因为我无法将值分配给原始张量,并且我想保持所有其他值相同。我该如何在 tensorflow 中执行此操作?

【问题讨论】:

    标签: matrix tensorflow slice


    【解决方案1】:

    从 v1.1 开始,Tensorflow 涵盖了这种类似 Numpy 的索引,请参阅 Tensor.getitem

    import tensorflow as tf
    
    with tf.Session() as sess:
        y_pred = tf.constant([[[1,2,3,4,5,6,7,8,9,10], [10,20,30,40,50,60,70,80,90,100]]])
        y_true = tf.constant([[[1,2,3,4,5,6,7,8,9,10], [10,20,30,40,50,60,70,80,90,100]]])
        print((y_pred[:,:,:5] * y_true[:,:,:5]).eval()) 
        # [[[   1    4    9   16   25]
        #   [ 100  400  900 1600 2500]]]
    

    评论后编辑:

    现在,问题是“*=”部分,即项目分配。这在 Tensorflow 中并不是一个简单的操作。但是,在您的情况下,这可以使用tf.concattf.where 轻松解决(tf.dynamic_partition + tf.dynamic_stitch 可用于更复杂的情况)。

    在下面找到前两个解决方案的快速实现。

    使用 Tensor.getitem 和 tf.concat 的解决方案:

    import tensorflow as tf
    
    with tf.Session() as sess:
        y_pred = tf.constant([[[1,2,3,4,5,6,7,8,9,10]]])
        y_true = tf.constant([[[1,2,3,4,5,6,7,8,9,10]]])
    
        # tf.where can't apply the condition to any axis (see doc).
        # In your case (condition on 2nd axis), we need either to manually broadcast the
        # condition tensor, or transpose the target tensors.
        # Here is a quick demonstration with the 2nd solution:
    
        y_pred_edit = y_pred[:,:,:5] * y_true[:,:,:5]
        y_pred_rest = y_pred[:,:,4:]
    
        y_pred = tf.concat((y_pred_edit, y_pred_rest), axis=2)
        print(y_pred.eval())
        # [[[ 1  4  9 16 25  6  7  8  9 10]]]
    

    使用 tf.where 的解决方案:

    import tensorflow as tf
    
    def select_n_fist_indices(n, batch_size):
        """ Return a list of length batch_size with the n first elements True
            and the rest False, i.e. [*[[True] * n], *[[False] * (batch_size - n)]]. 
        """
        n_ones = tf.ones((n))
        rest_zeros = tf.zeros((batch_size - n))
        indices = tf.cast(tf.concat((n_ones, rest_zeros), axis=0), dtype=tf.bool)
    
        return indices
    
    with tf.Session() as sess:
        y_pred = tf.constant([[[1,2,3,4,5,6,7,8,9,10]]])
        y_true = tf.constant([[[1,2,3,4,5,6,7,8,9,10]]])
    
        # tf.where can't apply the condition to any axis (see doc).
        # In your case (condition on 2nd axis), we need either to manually broadcast the 
        # condition tensor, or transpose the target tensors.
        # Here is a quick demonstration with the 2nd solution:
        y_pred_tranposed = tf.transpose(y_pred, [2, 0, 1])
        y_true_tranposed = tf.transpose(y_true, [2, 0, 1])
    
        edit_indices = select_n_fist_indices(5, tf.shape(y_pred_tranposed)[0])
    
        y_pred_tranposed = tf.where(condition=edit_indices, 
                                    x=y_pred_tranposed * y_true_tranposed, y=y_pred_tranposed)
    
        # Transpose back:  
        y_pred = tf.transpose(y_pred_tranposed, [1, 2, 0])
        print(y_pred.eval())
        # [[[ 1  4  9 16 25  6  7  8  9 10]]]
    

    【讨论】:

    • 很好,但我如何保留其他值?我仍然想要形状(x,y,z)的数组,这会给我一个(x,y,5)的数组?如何将 (x, y, 5) 分配回 y_pred,以便我现在拥有相乘的列和未相乘的边界框?
    • 是的,我的错,后来我注意到我只回答了你的问题的一半。我编辑了一个更完整的解决方案。
    猜你喜欢
    • 1970-01-01
    • 2022-01-06
    • 2019-11-01
    • 1970-01-01
    • 2022-01-24
    • 1970-01-01
    • 2023-03-07
    • 1970-01-01
    • 2018-02-24
    相关资源
    最近更新 更多