【Title】: Dataset.batch doesn't work as expected with a zipped dataset
【Posted】: 2022-09-02 22:53:10
【Description】:

I have a dataset like this:

import tensorflow as tf

a = tf.data.Dataset.range(16)  # 0..15, matching the output below
b = tf.data.Dataset.range(16, 32)
zipped = tf.data.Dataset.zip((a, b))
list(zipped.as_numpy_iterator())

# output: 
[(0, 16),
 (1, 17),
 (2, 18),
 (3, 19),
 (4, 20),
 (5, 21),
 (6, 22),
 (7, 23),
 (8, 24),
 (9, 25),
 (10, 26),
 (11, 27),
 (12, 28),
 (13, 29),
 (14, 30),
 (15, 31)]

When I apply batch(4) to it, I expect the result to be an array of batches, where each batch contains four tuples:

[[(0, 16), (1, 17), (2, 18), (3, 19)],
 [(4, 20), (5, 21), (6, 22), (7, 23)],
 [(8, 24), (9, 25), (10, 26), (11, 27)],
 [(12, 28), (13, 29), (14, 30), (15, 31)]]

But this is what I get:

batched = zipped.batch(4)
list(batched.as_numpy_iterator())

# Output:
[(array([0, 1, 2, 3]), array([16, 17, 18, 19])), 
 (array([4, 5, 6, 7]), array([20, 21, 22, 23])), 
 (array([ 8,  9, 10, 11]), array([24, 25, 26, 27])), 
 (array([12, 13, 14, 15]), array([28, 29, 30, 31]))]

I am following this tutorial, in which the author performs the same steps but somehow gets the output I expected.


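Update: according to the documentation, this is the expected behavior:

The components of the resulting element will have an additional outer dimension, which will be batch_size

But this doesn't make sense to me. As I understand it, a dataset is a list of pieces of data. The shape of those pieces shouldn't matter: when we batch, we group the elements [whatever their shape] into batches, so batching should always insert the new dimension at the second position ((length, a, b, c) -> (length', batch_size, a, b, c)).

So my question is: what is the purpose of implementing batch() this way? And what is the alternative that does what I described?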
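
One way to read the documented behavior: a zipped element is a structure with two components, and batching stacks each component separately, so the structure (here a 2-tuple) is preserved. The following is a minimal pure-Python model of that idea, not TensorFlow's implementation. `batch_components` is a hypothetical helper written only for illustration, and it assumes every element is a flat tuple of scalars.

```python
from itertools import islice


def batch_components(elements, batch_size):
    """Hypothetical model of per-component batching (not TensorFlow code).

    Groups `batch_size` consecutive tuples and transposes each group, so the
    result keeps the tuple structure and each component becomes a list.
    """
    it = iter(elements)
    while True:
        chunk = list(islice(it, batch_size))
        if not chunk:
            return
        # zip(*chunk) turns [(a0, b0), (a1, b1), ...] into (a0, a1, ...), (b0, b1, ...)
        yield tuple(list(component) for component in zip(*chunk))


zipped = list(zip(range(16), range(16, 32)))
batches = list(batch_components(zipped, 4))
# Each batch mirrors the (array, array) pairs shown in the question.
first = batches[0]

# Transposing each batch again gives the row-wise layout the question expected.
rows = [list(zip(*batch)) for batch in batches]
```

Keeping the components separate is likely the point of the design: a dataset of (features, labels) pairs comes out of batching as one features batch and one labels batch, which is the form training code usually consumes. The row-wise layout is then one transpose away.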

【Comments】:

    Tags: python tensorflow tensorflow-datasets


    【Solution 1】:

    One thing you can try is this:

    import tensorflow as tf
    
    a = tf.data.Dataset.range(16)
    b = tf.data.Dataset.range(16, 32)
    zipped = tf.data.Dataset.zip((a, b)).batch(4).map(lambda x, y: tf.transpose([x, y]))
    
    list(zipped.as_numpy_iterator())
    
    [array([[ 0, 16],
            [ 1, 17],
            [ 2, 18],
            [ 3, 19]]), 
     array([[ 4, 20],
            [ 5, 21],
            [ 6, 22],
            [ 7, 23]]), 
     array([[ 8, 24],
            [ 9, 25],
            [10, 26],
            [11, 27]]), 
     array([[12, 28],
            [13, 29],
            [14, 30],
            [15, 31]])]
    

    But the rows are still not tuples. Alternatively:

    zipped = tf.data.Dataset.zip((a, b)).batch(4).map(lambda x, y: tf.unstack(tf.transpose([x, y]), num=4))
    
    [(array([ 0, 16]), array([ 1, 17]), array([ 2, 18]), array([ 3, 19])), (array([ 4, 20]), array([ 5, 21]), array([ 6, 22]), array([ 7, 23])), (array([ 8, 24]), array([ 9, 25]), array([10, 26]), array([11, 27])), (array([12, 28]), array([13, 29]), array([14, 30]), array([15, 31]))]
    

    【Discussion】:

      【Solution 2】:

      You can apply batch more than once:

      a = tf.data.Dataset.range(16)
      b = tf.data.Dataset.range(16, 32)
      zipped = tf.data.Dataset.zip((a, b))
      batched = zipped.batch(1).batch(4).map(lambda x, y: tf.concat([x, y], 1))
      list(batched.as_numpy_iterator())
      # [array([[ 0, 16],
      #         [ 1, 17],
      #         [ 2, 18],
      #         [ 3, 19]]),
      #  array([[ 4, 20],
      #         [ 5, 21],
      #         [ 6, 22],
      #         [ 7, 23]]),
      #  array([[ 8, 24],
      #         [ 9, 25],
      #         [10, 26],
      #         [11, 27]]),
      #  array([[12, 28],
      #         [13, 29],
      #         [14, 30],
      #         [15, 31]])]
      

      To convert the result to a 2D list in which each item is a tuple:

      # .tolist() converts NumPy scalars to plain Python ints before building the tuples
      result = [list(map(tuple, item.tolist())) for item in batched.as_numpy_iterator()]
      print(result)
      # [
      #     [(0, 16), (1, 17), (2, 18), (3, 19)], 
      #     [(4, 20), (5, 21), (6, 22), (7, 23)], 
      #     [(8, 24), (9, 25), (10, 26), (11, 27)], 
      #     [(12, 28), (13, 29), (14, 30), (15, 31)]
      # ]
      

      Explanation:

      >>> list(zipped.batch(1).as_numpy_iterator())
      [(array([0]), array([16])),
       (array([1]), array([17])),
       (array([2]), array([18])),
       (array([3]), array([19])),
       ...
       (array([12]), array([28])),
       (array([13]), array([29])),
       (array([14]), array([30])),
       (array([15]), array([31]))]
      
      # next, apply '.batch(4)' on top of '.batch(1)'
      >>> list(zipped.batch(1).batch(4).as_numpy_iterator())
      [(array([[0],
               [1],
               [2],
               [3]]),
        array([[16],
               [17],
               [18],
               [19]])),
      ...
       (array([[12],
               [13],
               [14],
               [15]]),
        array([[28],
               [29],
               [30],
               [31]]))]
       
      # tf.concat each batch with axis=1
      >>> list(zipped.batch(1).batch(4).map(lambda x, y: tf.concat([x, y], 1)).as_numpy_iterator())
      
      [array([[ 0, 16],
              [ 1, 17],
              [ 2, 18],
              [ 3, 19]]),
       ...
       array([[12, 28],
              [13, 29],
              [14, 30],
              [15, 31]])]
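
Both answers arrive at the same (4, 2) layout per batch by different routes. The sketch below retraces the shapes with NumPy arrays standing in for one batched element. It assumes that `tf.transpose`, `tf.concat`, and `tf.stack` follow the same shape rules as their NumPy counterparts in this simple case; the variable names are illustrative only.

```python
import numpy as np

# Stand-ins for one element of zipped.batch(4): two vectors of shape (4,).
x = np.arange(0, 4)
y = np.arange(16, 20)

# Solution 1: packing the two vectors gives shape (2, 4); transposing gives (4, 2).
via_transpose = np.transpose(np.stack([x, y]))

# Solution 2: batch(1) first gives each component a trailing axis, shape (4, 1),
# so concatenating along axis 1 also yields shape (4, 2).
via_concat = np.concatenate([x[:, None], y[:, None]], axis=1)

# Stacking along axis 1 reaches the same layout in a single step.
via_stack = np.stack([x, y], axis=1)

print(via_stack.shape)
```

If the shape rules carry over as assumed, `tf.stack([x, y], axis=1)` inside `map` would be a third way to write the same transformation without the extra `batch(1)`.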
      

      【Discussion】:
