【问题标题】:Tensorflow how to generate unbalanced combined data setsTensorflow如何生成不平衡的组合数据集
【发布时间】:2018-06-24 14:49:46
【问题描述】:

我对新的数据集 API (tensorflow 1.4) 有疑问。我有两个数据集,我需要创建一个组合的不平衡数据集,即 每个批次应包含来自第一个数据集的一定数量的元素和来自第二个数据集的一定数量的元素。例如,

dataset1 = tf.data.Dataset.from_tensor_slices(tf.constant([1,1,1,1,1,1]
dataset1 = tf.data.Dataset.from_tensor_slices(tf.constant([2,2,2,2,2,2]))

假设批次大小为 4,我希望组合数据集中的批次看起来像 [1,1,1,2]。我知道如何使用 zip 和 flat_map 生成平衡的数据集 但我对这个不知所措。

提前致谢!

【问题讨论】:

    标签: python tensorflow


    【解决方案1】:

    为了解决这个问题,我的解决方案是单独批处理数据集,压缩它们,然后在生成的数据集上映射tf.concat 运算符。

    在您的示例中,它会给出类似(我将第二个数据集重命名为 dataset2):

    def concat(*tensor_list):
        return tf.concat(tensor_list, axis=0)
    
    zipped_ds = tf.data.Dataset.zip((dataset1.batch(3), dataset2))
    unbalanced_ds = zipped_ds.map(concat)
    

    如果数据集是张量的嵌套结构,可以使用以下版本的 concat :

    def concat(*ds_elements):
        #Create one empty list for each component of the dataset
        lists = [[] for _ in ds_elements[0]]
        for element in ds_elements:
            for i, tensor in enumerate(element):
                #For each element, add all its component to the associated list
                lists[i].append(tensor)
    
        #Concatenate each component list
        return tuple(tf.concat(l, axis=0) for l in lists)
    

    如果所有数据集元素(您要组合的数据集的一部分)都是只有最外层维度(相对批量大小)不同的张量,则有效。它为数据集元素的每个组件构建一个列表,并将这些组件相互独立地连接起来。

    处理一层嵌套。如果你需要更多,你可以使用递归来解包嵌套的嵌套,但它可能会给出一个不太干净的计算图......

    【讨论】:

      猜你喜欢
      • 2016-04-02
      • 2014-12-22
      • 2022-07-17
      • 2018-11-12
      • 2016-12-31
      • 2020-06-07
      • 1970-01-01
      • 1970-01-01
      • 1970-01-01
      相关资源
      最近更新 更多