【问题标题】:Tensorflow data augmentationTensorFlow 数据增强
【发布时间】:2020-08-07 17:10:31
【问题描述】:

我想转换这个 keras 数据增强工作流程:

datagen = ImageDataGenerator( 
    rescale=1./255,
    rotation_range = 10,
    horizontal_flip = True,
    width_shift_range=0.1,
    height_shift_range=0.1,
    fill_mode = 'nearest')

这是一个代码 sn-p 但两个函数都不起作用,因为它不支持批量维度!

import numpy as np
def augment(x, y):
    x = tf.keras.preprocessing.image.random_shift(x, 0.1, 0.1)
    x = tf.keras.preprocessing.image.random_rotation(
    x, 10, row_axis=1, col_axis=2, channel_axis=0, fill_mode='nearest', cval=0.0,
    interpolation_order=1)
    return x, y

X = np.random.random(size=(256, 48, 48, 1))
y = np.random.randint(0, 7, size=(256,))
dataset = tf.data.Dataset.from_tensor_slices((X, y))
dataset = dataset.map(augment)
dataset = dataset.batch(16, drop_remainder=False)
dataset = dataset.prefetch(buffer_size=1)

【问题讨论】:

    标签: keras tensorflow2.0 tensorflow-datasets data-augmentation


    【解决方案1】:

    运行您的代码时出现以下错误:AttributeError: 'Tensor' object has no attribute 'ndim'。似乎不可能用tf.data.Dataset 运行augment 函数,因为它无法处理张量。一种解决方法是将您的扩充功能包装在tf.py_function

    import tensorflow as tf
    import numpy as np
    
    def augment(x, y):
        x = x.numpy()
        x = tf.keras.preprocessing.image.random_shift(x, 0.1, 0.1)
        x = tf.keras.preprocessing.image.random_rotation(
        x, 10, row_axis=1, col_axis=2, channel_axis=0, fill_mode='nearest', cval=0.0,
        interpolation_order=1)
        return x, y
    
    X = np.random.random(size=(256, 48, 48, 1))
    y = np.random.randint(0, 7, size=(256,))
    
    dataset = tf.data.Dataset.from_tensor_slices((X, y))
    dataset = dataset.map(
        lambda x, y: tf.py_function(
            func=augment,
            inp=[x, y],
            Tout=[tf.float32, tf.int64]))
    dataset = dataset.batch(16, drop_remainder=False)
    dataset = dataset.prefetch(buffer_size=1)
    

    上面的代码应该运行没有任何错误。如果你经常需要用tf.py_function 包装你的函数,那么写一个装饰器会很方便(也很干净)。像这样的:

    import tensorflow as tf
    import numpy as np
    
    def map_decorator(func):
        def wrapper(*args):
            return tf.py_function(
                func=func,
                inp=[*args],
                Tout=[a.dtype for a in args])
        return wrapper
    
    @map_decorator
    def augment(x, y):
        x = x.numpy()
        x = tf.keras.preprocessing.image.random_shift(x, 0.1, 0.1)
        x = tf.keras.preprocessing.image.random_rotation(
        x, 10, row_axis=1, col_axis=2, channel_axis=0, fill_mode='nearest', cval=0.0,
        interpolation_order=1)
        return x, y
    
    X = np.random.random(size=(256, 48, 48, 1))
    y = np.random.randint(0, 7, size=(256,))
    
    dataset = tf.data.Dataset.from_tensor_slices((X, y))
    dataset = dataset.map(augment)
    dataset = dataset.batch(16, drop_remainder=False)
    dataset = dataset.prefetch(buffer_size=1)
    

    希望对你有帮助!

    【讨论】:

      猜你喜欢
      • 1970-01-01
      • 2021-01-14
      • 2016-09-28
      • 2020-10-14
      • 1970-01-01
      • 2018-10-21
      • 2017-11-01
      • 2020-07-20
      • 2019-04-22
      相关资源
      最近更新 更多