【问题标题】:Error when using tf.data.Dataset.from_generator使用 tf.data.Dataset.from_generator 时出错
【发布时间】:2022-01-24 20:51:46
【问题描述】:

我正在尝试使用 tensorflow from_generator 制作 tensorflow 数据集,我很确定我已经制作了一个运行良好的 python 生成器,但是当我尝试将它传递给 from_generator 时,我总是出错。这是我用来创建数据集的一段代码

def dataset_generator(X, Y):
    for idx in range(X.shape[0]):
        img = X[idx, :, :, :]
        labels = Y[idx, :]
        yield img, labels

import tensorflow as tf
ds_generator = dataset_generator(X_data, Y_data)
ds = tf.data.Dataset.from_generator(ds_generator, output_signature=(tf.TensorSpec(shape=[None, 720, 720, 3], dtype=tf.int32), tf.TensorSpec(shape=[None, 30], dtype=tf.float16)))

但是当我运行它时,它总是会产生错误

---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
<ipython-input-63-af75191f4a28> in <module>
      1 import tensorflow as tf
      2 ds_generator = dataset_generator(X_data, Y_data)
----> 3 ds = tf.data.Dataset.from_generator(ds_generator, output_signature=(tf.TensorSpec(shape=[None, 720, 720, 3], dtype=tf.int32), tf.TensorSpec(shape=[None, 30], dtype=tf.float16)))

~/.local/lib/python3.6/site-packages/tensorflow/python/util/deprecation.py in new_func(*args, **kwargs)

~/.local/lib/python3.6/site-packages/tensorflow/python/data/ops/dataset_ops.py in from_generator(generator, output_types, output_shapes, args, output_signature)

TypeError: `generator` must be callable.

【问题讨论】:

    标签: python tensorflow tensorflow-datasets


    【解决方案1】:

    您好,您的 gen 函数的问题是您必须通过 args 命令来传递它,而不是像这样的函数

    import tensorflow as tf
    import numpy as np
    
    # Gen Function
    def dataset_generator(X, Y):
        for idx in range(X.shape[0]):
            img = X[idx, :, :, :]
            labels = Y[idx, :]
            yield img, labels
    
    # Created random data for testing
    X_data = np.random.randn(100, 720, 720, 3).astype(np.float32)
    Y_data = tf.one_hot(np.random.randint(0, 30, (100, )), 30)
    
    # Testing function
    ds = tf.data.Dataset.from_generator(
        dataset_generator,
        args=(X_data, Y_data), 
        output_types=(tf.float32, tf.uint8)
    )
    
    # Get output
    next(iter(ds.batch(10).take(1)))
    

    【讨论】:

    • 有效!谢谢,但这是否意味着每次我需要调用 from_generator 时,我都必须通过 args 传递我的生成器参数?
    • 是的,你必须将它作为可调用对象传递,除非你在没有 args 的情况下执行此操作并通过没有 arg 函数运行所有这些
    猜你喜欢
    • 1970-01-01
    • 2018-04-15
    • 1970-01-01
    • 2021-05-06
    • 1970-01-01
    • 2020-11-30
    • 1970-01-01
    • 2019-02-25
    • 1970-01-01
    相关资源
    最近更新 更多