【问题标题】:How can I preprocess a tf.data.Dataset using a provided preprocess_input function that expects a tf.Tensor?如何使用提供的需要 tf.Tensor 的 preprocess_input 函数预处理 tf.data.Dataset?
【发布时间】:2021-12-31 01:06:23
【问题描述】:

有点不知所措,我希望使用在 ImageNet 上预训练的 ResNet50 将迁移学习应用于问题。

我已经准备好迁移学习过程,但需要tf.keras.applications.resnet50.preprocess_input 轻松完成的正确格式的数据集。除了它适用于numpy.arraytf.Tensor,我使用image_dataset_from_directory 加载数据,这给了我tf.data.Dataset

有没有一种简单的方法可以使用提供的preprocess_input 函数来预处理我的这种表单中的数据?

或者,函数指定:

图像从 RGB 转换为 BGR,然后每个颜色通道相对于 ImageNet 数据集以零为中心,不进行缩放。

因此,在数据管道中或作为模型的一部分实现此目的的任何其他方式也是可以接受的。

【问题讨论】:

    标签: python tensorflow keras tf.keras


    【解决方案1】:

    您可以使用tf.data.Datasetmap 函数将preprocess_input 函数应用于每批图像:

    import tensorflow as tf
    import pathlib
    import matplotlib.pyplot as plt
    
    dataset_url = "https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz"
    data_dir = tf.keras.utils.get_file('flower_photos', origin=dataset_url, untar=True)
    data_dir = pathlib.Path(data_dir)
    
    batch_size = 32
    
    train_ds = tf.keras.utils.image_dataset_from_directory(
      data_dir,
      validation_split=0.2,
      subset="training",
      seed=123,
      image_size=(180, 180),
      batch_size=batch_size)
    
    def display(ds):
      images, _ = next(iter(ds.take(1)))
      image = images[0].numpy()
      image /= 255.0
      plt.imshow(image)
    
    def preprocess(images, labels):
      return tf.keras.applications.resnet50.preprocess_input(images), labels
    
    train_ds = train_ds.map(preprocess)
    
    display(train_ds)
    

    【讨论】:

      猜你喜欢
      • 1970-01-01
      • 1970-01-01
      • 2018-09-13
      • 2021-01-29
      • 2020-06-04
      • 2011-02-05
      • 2013-01-24
      相关资源
      最近更新 更多