【问题标题】:Tensorflow, how to concatenate multiple datasets with varying batch sizesTensorflow,如何连接具有不同批量大小的多个数据集
【发布时间】:2019-03-24 15:06:05
【问题描述】:

想象一下我有:

  • 数据集 1 和数据 [5, 5, 5, 5, 5]
  • 数据集 2 和数据 [4, 4]

我想从两个数据集中获取批次并将它们连接起来,以便获得大小为 3 的批次,其中:

  • 我读取数据集 1,批量大小为 2
  • 我读取数据集 2,批量大小为 1。

如果某些数据集先被清空,我还想读取最后一批。 在这种情况下,我会得到 [5, 5, 4], [5, 5, 4], [5] 作为我的最终结果。

我该怎么做? 我在这里看到了答案:Tensorflow how to generate unbalanced combined data sets

这是一个很好的尝试,但如果其中一个数据集在其他数据集之前被清空,则它不起作用(因为当您尝试从首先清空的数据集中获取元素时,tf.errors.OutOfRangeError 会抢先输出,而我不要得到最后一批)。因此我只得到 [5, 5, 4], [5, 5, 4]

我想过用tf.contrib.data.choose_from_datasets

ds1 = tf.data.Dataset.from_tensor_slices([5, 5, 5, 5, 5]).batch(2)
ds2 = tf.data.Dataset.from_tensor_slices([4, 4, 4, 4]).batch(1)
choice_dataset = [1, 2, 1, 2, 1]
ds = tf.contrib.data.choose_from_datasets([ds1, ds2], choice_dataset)
ds = ds.apply(tf.contrib.data.unbatch())
ds = ds.batch(3, drop_remainder=False)

这种工作但相当不优雅(有unbatch和batch);另外,我对批次中的确切内容并没有很好的控制权。 (例如,如果 ds1 是 [7] * 7,批量大小为 2,ds2 是 [2, 2],批量大小为 1,我会得到 [7, 7, 1], [7, 7, 1], [7 , 7, 7]。但是如果我真的想要 [7, 7, 1], [7, 7, 1], [7, 7], [7] 怎么办?即保持每个数据集中的元素数量固定.

还有其他更好的解决方案吗?

我的另一个想法是尝试使用tf.data.Dataset.flat_map

ds1 = tf.data.Dataset.from_tensor_slices([5, 5, 5, 5, 5])
ds2 = tf.data.Dataset.from_tensor_slices([4, 4, 4, 4])
batch_sizes = [2, 1]
def concat(*inputs):
  concat = partial(functools.reduce, lambda x, y: x.concatenate(y))
  datasets = [tf.data.Dataset.from_tensors(input) for input in inputs]
  datasets = [dataset.batch(batch_size) for batch_size, dataset in zip(batch_sizes, datasets)]
  return concat(datasets)
dataset = (tf.data.Dataset
           .zip((ds1, ds2))
           .flat_map(_concat_and_batch)
           .batch(sum(batch_sizes)))

但它似乎不起作用..

【问题讨论】:

    标签: python tensorflow tensorflow-datasets


    【解决方案1】:

    这里有一个解决方案。它有一些问题,但我希望它能满足您的需求。

    想法如下:将两个数据集分别批处理,将它们压缩在一起,并执行一个映射函数将每个压缩元组组合成一个批处理(到目前为止,这类似于this 中的建议和this 回答。)

    正如您所注意到的,问题在于压缩仅适用于长度相同的两个数据集。否则,一个数据集在另一个数据集之前被消耗,剩余未消耗的元素不被使用。

    我的(有点老套的)解决方案是将另一个无限的虚拟数据集连接到两个数据集。这个虚拟数据集仅包含您知道不会出现在真实数据集中的值。这消除了拉链的问题。但是,您需要摆脱所有虚拟元素。这可以通过过滤和映射轻松完成。

    import tensorflow as tf
    
    ds1 = tf.data.Dataset.from_tensor_slices([5,5,5,5,5])
    ds2 = tf.data.Dataset.from_tensor_slices([4,4])
    
    # we assume that this value will never occur in `ds1` and `ds2`:
    UNUSED_VALUE = -1 
    
    # an infinite dummy dataset:
    dummy_ds = tf.data.Dataset.from_tensors(UNUSED_VALUE).repeat() 
    
    # make `ds1` and `ds2` infinite:
    ds1 = ds1.concatenate(dummy_ds)
    ds2 = ds2.concatenate(dummy_ds)
    
    ds1 = ds1.batch(2)
    ds2 = ds2.batch(1)
    
    # this is the solution mentioned in the links above
    ds = tf.data.Dataset.zip((ds1,ds2))
    ds = ds.map(lambda x1, x2: tf.concat((x1,x2),0))
    
    # filter the infinite dummy tail:
    ds = ds.filter(lambda x: tf.reduce_any(tf.not_equal(x,UNUSED_VALUE)))
    
    # filter from batches the dummy elements:
    ds = ds.map(lambda x: tf.boolean_mask(x,tf.not_equal(x,UNUSED_VALUE)))
    

    这个解决方案有两个主要问题:

    (1) 我们需要有一个UNUSED_VALUE 的值,我们确定它不会出现在数据集中。我怀疑有一种解决方法,可能是通过使虚拟数据集由空张量(而不是具有恒定值的张量)组成,但我还不知道如何做到这一点。

    (2) 尽管此数据集的元素数量有限,但以下循环永远不会终止:

    iter = ds.make_one_shot_iterator()
    batch = iter.get_next()
    sess = tf.Session()
    while True:
        print(sess.run(batch))
    

    原因是迭代器不断过滤掉虚拟示例,不知道何时停止。这可以通过将上面的repeat() 调用更改为repeat(n) 来解决,其中n 是一个您知道的比两个数据集的长度差长的数字。

    【讨论】:

    • 这是一个很好的尝试,感谢您的想法。我觉得这带来了太多开销,而且必须确切知道何时停止迭代器很麻烦
    【解决方案2】:

    如果您不介意在构建新数据集期间运行会话,可以执行以下操作:

    import tensorflow as tf
    import numpy as np
    
    ds1 = tf.data.Dataset.from_tensor_slices([5,5,5,5,5])
    ds2 = tf.data.Dataset.from_tensor_slices([4,4])
    
    ds1 = ds1.batch(2)
    ds2 = ds2.batch(1)
    
    iter1 = ds1.make_one_shot_iterator()
    iter2 = ds2.make_one_shot_iterator()
    
    batch1 = iter1.get_next()
    batch2 = iter2.get_next()
    
    sess = tf.Session()
    
    # define a generator that will sess.run both datasets, and will return the concatenation of both
    def GetBatch():
        while True:
            try:
                b1 = sess.run(batch1)
            except tf.errors.OutOfRangeError:
                b1 = None
            try:
                b2 = sess.run(batch2)
            except tf.errors.OutOfRangeError:
                b2 = None
            if (b1 is None) and (b2 is None):
                break
            elif b1 is None:
                yield b2
            elif b2 is None:
                yield b1
            else:
                yield np.concatenate((b1,b2))
    
    # create a dataset from the above generator
    ds = tf.data.Dataset.from_generator(GetBatch,tf.int32)
    

    请注意,如果您愿意,可以将上述会话隐藏\封装(例如,在函数内部),例如:

    iter = ds.make_one_shot_iterator()
    batch = iter.get_next()
    
    sess2 = tf.Session()
    
    while True:
        print(sess2.run(batch))
    

    【讨论】:

    • 好吧,也许我会这样做。但是有一个问题:您知道在数据集创建过程中使用会话是否会产生开销吗?使用 tf.data.Dataset 函数是否比仅使用 numpy 更快?因为否则我可以直接在 numpy 中执行我想要的生成器,然后在最后使用 tf.data.Dataset.from_generator ,不是吗?你怎么看?
    • 我不确定,但我相信这会产生开销,至少在您使用 GPU 时是这样(因为这会强制原本可能已经放置在 GPU 上的数据通过 CPU 内存) .但我确实认为这个解决方案比仅使用 numpy 更好,因为它将在队列\多线程中执行数据集预处理,这是对资源的更好利用。
    【解决方案3】:

    这是一个解决方案,要求您使用“控制输入”来选择要使用的批次,然后您根据首先使用哪个数据集来决定这一点。这可以使用抛出的异常来检测。

    为了解释这个解决方案,我将首先提出一个不起作用的尝试。

    尝试的解决方案 #1

    import tensorflow as tf
    
    ds1 = tf.data.Dataset.from_tensor_slices([5,5,5,5,5])
    ds2 = tf.data.Dataset.from_tensor_slices([4,4])
    
    ds1 = ds1.batch(2)
    ds2 = ds2.batch(1)
    
    iter1 = ds1.make_one_shot_iterator()
    iter2 = ds2.make_one_shot_iterator()
    
    batch1 = iter1.get_next(name='batch1')
    batch2 = iter2.get_next(name='batch2')
    batch12 = tf.concat((batch1, batch2), 0)
    
    # this is a "control" placeholder. Its value determines whether to use `batch1`,`batch2` or `batch12`
    which_batch = tf.placeholder(tf.int32)
    
    batch = tf.cond(
                   tf.equal(which_batch,0), # if `which_batch`==0, use `batch12`
                           lambda:batch12,
            lambda:tf.cond(tf.equal(which_batch,1), # elif `which_batch`==1, use `batch1`
                           lambda:batch1,
            lambda:batch2)) # else, use `batch2`
    
    sess = tf.Session()
    
    which = 0 # this value will be fed into the control placeholder `which_batch`
    while True:
        try:
            print(sess.run(batch,feed_dict={which_batch:which}))
        except tf.errors.OutOfRangeError as e:
            # use the error to detect which dataset was consumed, and update `which` accordingly
            if which==0:
                if 'batch2' in e.op.name:
                    which = 1
                else:
                    which = 2
            else:
                break
    

    这个解决方案不起作用,因为对于which_batch 的任何值,tf.cond() 命令将评估其分支的所有前身(请参阅this answer)。因此,即使 which_batch 的值为 1,batch2 也会被计算并抛出 OutOfRangeError

    尝试的解决方案 #2

    这个问题可以通过将batch1batch2batch12的定义移动到函数中来解决。

    import tensorflow as tf
    
    ds1 = tf.data.Dataset.from_tensor_slices([5,5,5,5,5])
    ds2 = tf.data.Dataset.from_tensor_slices([4,4])
    
    ds1 = ds1.batch(2)
    ds2 = ds2.batch(1)
    
    iter1 = ds1.make_one_shot_iterator()
    iter2 = ds2.make_one_shot_iterator()
    
    def get_batch1():
        batch1 = iter1.get_next(name='batch1')
        return batch1
    
    def get_batch2():
        batch2 = iter2.get_next(name='batch2')
        return batch2
    
    def get_batch12():
        batch1 = iter1.get_next(name='batch1_')
        batch2 = iter2.get_next(name='batch2_')
        batch12 = tf.concat((batch1, batch2), 0)
        return batch12
    
    # this is a "control" placeholder. It's value determines whether to ues `batch1`,`batch2` or `batch12`
    which_batch = tf.placeholder(tf.int32)
    
    batch = tf.cond(
                   tf.equal(which_batch,0), # if `which_batch`==0, use `batch12`
                           get_batch12,
            lambda:tf.cond(tf.equal(which_batch,1), # elif `which_batch`==1, use `batch1`
                           get_batch1,
            get_batch2)) # elif `which_batch`==2, use `batch2`
    
    sess = tf.Session()
    
    which = 0 # this value will be fed into the control placeholder `which_batch`
    while True:
        try:
            print(sess.run(batch,feed_dict={which_batch:which}))
        except tf.errors.OutOfRangeError as e:
            # use the error to detect which dataset was consumed, and update `which` accordingly
            if which==0:
                if 'batch2' in e.op.name:
                    which = 1
                else:
                    which = 2
            else:
                break
    

    但是,这也不起作用。原因是在形成batch12 并消耗数据集ds2 的那一步,然后我们从数据集ds1 中取出批处理并“丢弃”它而不使用它。

    解决方案

    我们需要一种机制来确保在使用其他数据集的情况下不会“丢弃”任何批次。我们可以通过定义一个变量来做到这一点,该变量将被分配当前批次的ds1,但仅在尝试获得batch12之前立即。否则,此变量将保留其先前的值。然后,如果 batch12 由于 ds1 被消耗而失败,那么这个分配将失败并且 batch2 没有被丢弃,我们下次可以使用它。否则,如果batch12 由于ds2 被消耗而失败,那么我们在我们定义的变量中拥有batch1 的备份,使用此备份后我们可以继续获取batch1

    import tensorflow as tf
    
    ds1 = tf.data.Dataset.from_tensor_slices([5,5,5,5,5])
    ds2 = tf.data.Dataset.from_tensor_slices([4,4])
    
    ds1 = ds1.batch(2)
    ds2 = ds2.batch(1)
    
    iter1 = ds1.make_one_shot_iterator()
    iter2 = ds2.make_one_shot_iterator()
    
    # this variable will store a backup of `batch1`, in case it is dropped
    batch1_backup = tf.Variable(0, trainable=False, validate_shape=False)
    
    def get_batch12():
        batch1 = iter1.get_next(name='batch1')
    
        # form the combined batch `batch12` only after backing-up `batch1`
        with tf.control_dependencies([tf.assign(batch1_backup, batch1, validate_shape=False)]):
            batch2 = iter2.get_next(name='batch2')
            batch12 = tf.concat((batch1, batch2), 0)
        return batch12
    
    def get_batch1():
        batch1 = iter1.get_next()
        return batch1
    
    def get_batch2():
        batch2 = iter2.get_next()
        return batch2
    
    # this is a "control" placeholder. Its value determines whether to use `batch12`, `batch1_backup`, `batch1`, or `batch2`
    which_batch = tf.Variable(0,trainable=False)
    
    batch = tf.cond(
                   tf.equal(which_batch,0), # if `which_batch`==0, use `batch12`
                           get_batch12,
            lambda:tf.cond(tf.equal(which_batch,1), # elif `which_batch`==1, use `batch1_backup`
                           lambda:batch1_backup,
            lambda:tf.cond(tf.equal(which_batch,2), # elif `which_batch`==2, use `batch1`
                           get_batch1,
           get_batch2))) # else, use `batch2`
    
    sess = tf.Session()
    sess.run(tf.global_variables_initializer())
    
    which = 0  # this value will be fed into the control placeholder
    while True:
        try:
            print(sess.run(batch,feed_dict={which_batch:which}))
    
            # if just used `batch1_backup`, proceed with `batch1`
            if which==1:
                which = 2
        except tf.errors.OutOfRangeError as e:
            # use the error to detect which dataset was consumed, and update `which` accordingly
            if which == 0:
                if 'batch2' in e.op.name:
                    which = 1
                else:
                    which = 3
            else:
                break
    

    【讨论】:

      猜你喜欢
      • 2019-02-06
      • 1970-01-01
      • 1970-01-01
      • 2020-08-09
      • 2019-10-16
      • 2020-04-09
      • 1970-01-01
      • 1970-01-01
      • 2020-06-04
      相关资源
      最近更新 更多