【问题标题】:Data pipeline in tf.keras with tfrecords or numpytf.keras 中带有 tfrecords 或 numpy 的数据管道
【发布时间】:2019-08-23 22:27:30
【问题描述】:

我想在 Tensorflow 2.0 的 tf.keras 中使用比我的 ram 大的数据训练模型,但教程只显示了带有预定义数据集的示例。

我遵循了这个教程:

Load Images with tf.data,我无法对 numpy 数组或 tfrecords 上的数据进行这项工作。

这是一个将数组转换为 tensorflow 数据集的示例。我想要的是使这项工作适用于多个 numpy 数组文件或多个 tfrecords 文件。

train_dataset = tf.data.Dataset.from_tensor_slices((x_train, y_train))
# Shuffle and slice the dataset.
train_dataset = train_dataset.shuffle(buffer_size=1024).batch(64)

# Since the dataset already takes care of batching,
# we don't pass a `batch_size` argument.
model.fit(train_dataset, epochs=3)

【问题讨论】:

    标签: tensorflow tensorflow-datasets tf.keras tensorflow2.0


    【解决方案1】:

    如果您有tfrecords 文件:

    path = ['file1.tfrecords', 'file2.tfrecords', ..., 'fileN.tfrecords']
    dataset = tf.data.Dataset.list_files(path, shuffle=True).repeat()
    dataset = dataset.interleave(lambda filename: tf.data.TFRecordDataset(filename), cycle_length=len(path))
    dataset = dataset.map(parse_function).batch()
    

    parse_function 处理解码和任何类型的扩充。

    如果使用 numpy 数组,您可以从文件名列表或数组列表构建数据集。标签只是一个列表。或者可以在解析单个示例时从文件中获取它们。

    path = #list of numpy arrays
    

    path = os.listdir(path_to files)
    
    dataset = tf.data.Dataset.from_tensor_slices((path, labels))
    dataset = dataset.map(parse_function).batch()
    

    parse_function 处理解码:

    def parse_function(filename, label):  #Both filename and label will be passed if you provided both to from_tensor_slices
        f = tf.read_file(filename)
        image = tf.image.decode_image(f)) 
        image = tf.reshape(image, [H, W, C])
        label = label #or it could be extracted from, for example, filename, or from file itself 
        #do any augmentations here
        return image, label
    

    要解码.npy文件,最好的方法是使用reshape而不使用read_filedecode_raw,但首先加载带有np.load的numpys:

    paths = [np.load(i) for i in ["x1.npy", "x2.npy"]]
    image = tf.reshape(filename, [2])
    

    或尝试使用decode_raw

    f = tf.io.read_file(filename)
    image = tf.io.decode_raw(f, tf.float32)
    

    然后只需将批处理数据集传递给model.fit(dataset)。 TensorFlow 2.0 允许对数据集进行简单的迭代。无需使用迭代器。即使在更高版本的 1.x API 中,您也可以将数据集传递给 .fit 方法

    for example in dataset:
        func(example)
    

    【讨论】:

    • 假设我传递了 numpy 数组文件的文件名。我的 parse_function 应该有什么来加载这些文件?
    • 添加信息以回答
    • 我使用了 parse_function 但使用了 numpy 文件,并且在遍历数据集时抛出了错误。 InvalidArgumentError:断言失败:[无法将字节解码为 JPEG、PNG、GIF 或 BMP] [[{{node Assert/Assert}}]] [[cond_gif]] [[cond_png]] [[decode_image/cond_jpeg]] [操作:IteratorGetNextSync] colab link
    • 如果是 .npy 文件,而不是图像文件,你可以尝试使用 decode_raw。但它可能会产生不一致的结果。另一种方法是使用重塑。更新答案
    • 将 [np.load(i) for i in ["x1.npy", "x2.npy" ...] 与大量 numpy 文件一起使用可能会使用太多内存?
    猜你喜欢
    • 1970-01-01
    • 1970-01-01
    • 2017-02-02
    • 1970-01-01
    • 2018-08-23
    • 2011-08-14
    • 1970-01-01
    • 1970-01-01
    • 2018-10-07
    相关资源
    最近更新 更多