【问题标题】:Loading data from generator using tf.data.Dataset.from_generator()使用 tf.data.Dataset.from_generator() 从生成器加载数据
【发布时间】:2021-05-06 12:35:09
【问题描述】:

我想为我的度量学习模型加载数据,数据生成函数是get_data()函数

def get_data():
    def my_generator():
        for i in range(10):
            anchor = list(np.expand_dims(cv2.imread('img1'), axis=0))
            positive = list(np.expand_dims(cv2.imread('img2'), axis=0)
            true = 0
            a = (true, anchor, positive)
            yield a

    return tf.data.Dataset.from_generator(
        my_generator,
        output_types=(tf.int64, tf.Tensor, tf.Tensor),
        output_shapes=(1, (1, 256, 256, 3), (1, 256, 256, 3))
    )

dataset = get_data()

当我运行此代码时,我收到以下错误。我尝试将其他一些参数传递给output_types,例如 tf.float64,但它也不起作用。我想我在形状上做错了什么,但我不知道是什么。

TypeError:无法将值 转换为 TensorFlow DType。

任何帮助表示赞赏:)

【问题讨论】:

    标签: python tensorflow deep-learning generator


    【解决方案1】:

    正如我所想,问题出在形状上,这对我有用

        return tf.data.Dataset.from_generator(
            my_generator,
            output_types=(tf.float64, tf.float64, tf.float64),
            output_shapes=(tf.TensorShape(None), tf.TensorShape((1, 256, 256, 3)), 
                tf.TensorShape((1, 256, 256, 3))))
    

    【讨论】:

      猜你喜欢
      • 2019-02-25
      • 1970-01-01
      • 1970-01-01
      • 2017-07-23
      • 1970-01-01
      • 1970-01-01
      • 2020-11-30
      • 1970-01-01
      • 2021-11-18
      相关资源
      最近更新 更多