【问题标题】:How to slice according to batch in the tensorflow array?tensorflow数组中如何按batch进行切片?
【发布时间】:2022-01-02 15:27:52
【问题描述】:

我有一个数组 output 和一个 ID subject_ids

output = [[[1, 2, 3]], [[4, 5, 6]], [[7, 8, 9]]]

subject_ids = [[0, 1], [1, 2], [0, 2]]

ID中的数字分别代表开始和结束位置,然后我想根据开始和结束位置得到它们之间的向量。

例如,在这种情况下,我应该得到[[1, 2, 3], [4, 5, 6]][[4, 5, 6], [7, 8, 9]][[1, 2, 3], [4, 5, 6], [7, 8, 9]]

我该怎么办?我试过tf.slicetf.gather,但似乎没有用。

【问题讨论】:

    标签: python tensorflow tensor ragged ragged-tensors


    【解决方案1】:

    如果您只想使用 Tensorflow,请尝试将 tf.gathertf.rangetf.ragged.stack 结合使用:

    import tensorflow as tf
    
    output = tf.constant([
                          [[1, 2, 3]], 
                          [[4, 5, 6]], 
                          [[7, 8, 9]]
                          ])
    
    subject_ids = tf.constant([[0, 1], [1, 2], [0, 2]])
    
    ragged_ouput = tf.ragged.stack([tf.gather(output, tf.range(subject_ids[i, 0], subject_ids[i, 1] + 1)) for i in tf.range(0, tf.shape(subject_ids)[0])], axis=0)
    ragged_ouput = tf.squeeze(ragged_ouput, axis=2)
    print(ragged_ouput)
    
    [[[1, 2, 3], [4, 5, 6]], [[4, 5, 6], [7, 8, 9]], [[1, 2, 3], [4, 5, 6], [7, 8, 9]]]
    

    更新 1:

    import tensorflow as tf
    tf.config.run_functions_eagerly(True)
    
    output = tf.constant([
                          [[1, 2, 3]], 
                          [[4, 5, 6]], 
                          [[7, 8, 9]]
                          ])
    
    subject_ids = tf.constant([[0, 1], [1, 2], [0, 2]])
    
    def slice_tensor(x):
      return tf.ragged.stack([tf.gather(output, tf.range(x[0], x[1] + 1))], axis=0)
    
    ragged_ouput = tf.map_fn(slice_tensor, subject_ids, fn_output_signature=tf.RaggedTensorSpec(shape=[1, None, None, 3],
                                                                        dtype=tf.int32,
                                                                        ragged_rank=2,
                                                                        row_splits_dtype=tf.int64))
    ragged_ouput = tf.squeeze(ragged_ouput, axis=1)
    tf.print(ragged_ouput, summarize=-1)
    
    [[[[1, 2, 3]], [[4, 5, 6]]], [[[4, 5, 6]], [[7, 8, 9]]], [[[1, 2, 3]], [[4, 5, 6]], [[7, 8, 9]]]]
    

    【讨论】:

    • 抱歉,我运行了这个但得到了TypeError: Tensor objects are only iterable when eager execution is enabled. To iterate over this tensor use tf.map_fn. Can't loops be used in tensorflow's model?
    • 当然可以使用,但要看在哪里。问题是您在问题中显示了两行代码而没有提及其他任何内容,并且您期望一个完整的工作示例适合您现有代码的其余部分(我们都看不到)。显然,它可能不起作用,因为您未能提供更多详细信息。无论如何,我用 tf.map_fn 的例子更新了我的答案
    • 是的,一开始是想把问题简单化,所以没有贴出所有的代码。非常感谢 !这对我很有帮助。
    【解决方案2】:

    怎么样

    >>> [output[np.arange(x, y+1)] for x, y in subject_ids] 
    [array([[[1, 2, 3]],
            [[4, 5, 6]]]),
            
     array([[[4, 5, 6]],
            [[7, 8, 9]]]),
            
     array([[[1, 2, 3]],
            [[4, 5, 6]],
            [[7, 8, 9]]])]
    

    【讨论】:

    • 对不起,我运行了这个但也得到了TypeError: Tensor objects are only iterable when eager execution is enabled. To iterate over this tensor use tf.map_fn. Can't loops be used in tensorflow's model?
    猜你喜欢
    • 2020-12-05
    • 2017-01-02
    • 1970-01-01
    • 1970-01-01
    • 2010-11-23
    • 1970-01-01
    • 1970-01-01
    • 1970-01-01
    相关资源
    最近更新 更多