【问题标题】:Apply augmentation on tf.data.Dataset.from_generator()在 tf.data.Dataset.from_generator() 上应用扩充
【发布时间】:2021-06-04 08:56:33
【问题描述】:

我有这个扩充代码:

class CustomAugment(object):
   def __call__(self, sample):        
       sample = self._random_apply(tf.image.flip_left_right, sample, p=0.5)
       sample = self._random_apply(self._color_jitter, sample, p=0.8)
       sample = self._random_apply(self._color_drop, sample, p=0.2)

       return sample

   def _color_jitter(self, x, s=1):
       
       x = tf.image.random_brightness(x, max_delta=0.8*s)
       x = tf.image.random_contrast(x, lower=1-0.8*s, upper=1+0.8*s)
       x = tf.image.random_saturation(x, lower=1-0.8*s, upper=1+0.8*s)
       x = tf.image.random_hue(x, max_delta=0.2*s)
       x = tf.clip_by_value(x, 0, 1)
       return x
   
   def _color_drop(self, x):
       x = tf.image.rgb_to_grayscale(x)
       x = tf.tile(x, [1, 1, 1, 3])
       return x
   
   def _random_apply(self, func, x, p):
       return tf.cond(
         tf.less(tf.random.uniform([], minval=0, maxval=1, dtype=tf.float32),
                 tf.cast(p, tf.float32)),
         lambda: func(x),
         lambda: x)

这就是我导入图像数据集的方式:

train_ds = tf.data.Dataset.from_generator(path)

我想在我的 train_ds 上应用这种增强功能,请问,我该如何进行?

【问题讨论】:

    标签: python tensorflow2.0 tensorflow2.x


    【解决方案1】:

    首先,您应该使用 tf.keras.sequence 的子类创建一个自定义生成器,然后您可以实现 __getitem____len__ 方法。

    class CustomGenerator(tf.keras.utils.Sequence):
    
        def __init__(self, df, X_col, y_col,
                 batch_size,
                 input_size=(width, height, channels),
                 shuffle=True):
        
            self.df = df.copy()
            self.X_col = X_col
            self.y_col = y_col
            self.batch_size = batch_size
            self.input_size = input_size
        
            self.n = len(self.df)
            self.n_name = df[y_col['label']].nunique()
        
        def on_epoch_end(self):
            pass    
        
        def __getitem__(self, index):
            batches = self.df[index * self.batch_size:(index + 1) * 
                              self.batch_size]
            X, y = self.__get_data(batches)        
            return X, y
        
        def __len__(self):
            return self.n // self.batch_size
        
        def __get_output(self, label, num_classes):
            return tf.keras.utils.to_categorical(label, 
                                                 num_classes=num_classes)    
        
        def __get_input(self, path, target_size):
            # Load Image using PIL
            img = Image.open(self.base_path + path)
            img = np.array(img)
            
            # Your Augmentation
            img = CustomAugment(img)
            return img /255
    
    
        def __get_data(self, batches):
            # Generates data containing batch_size samples
    
            img_path_batch = batches[self.X_col['img']]
            label_batch = batches[self.y_col['label']]
    
            X_batch = np.asarray([self.__get_input(x, self.input_size)
                                  for x in img_path_batch])
            y_batch = np.asarray([self.__get_output(y)
                                  for y in label_batch])
    
            return X_batch, y_batch
    

    如您所见,您将在 __get_input 方法中扩充您的样本。

    要使用这个类:

    traingen = CustomDataGen(df, base_path=IMGS_DIR,
                         X_col={'img':'img'},
                         y_col={'label': 'label'},
                         max_label_len=11,
                         batch_size=16,
                         input_size=IMAGE_SIZE)
    

    注意:如果您需要在tf.data 上使用生成器,您应该像这样使用它:

    train_dataset = tf.data.Dataset.from_generator(lambda: traingen,                                               
                                                   output_types = (tf.float32, tf.int32),                                              
                                                   output_shapes = ([None, width, height, channels], [None, num_classes]))
    

    【讨论】:

    • NameError: name 'df' is not defined,请问为什么我会发现这个错误?
    • 您应该创建一个包含 2 列的 Pandas 数据框:“img”和“label”。 'img':img 的名称 'label':类的标签
    猜你喜欢
    • 2022-01-24
    • 2018-04-15
    • 1970-01-01
    • 1970-01-01
    • 1970-01-01
    • 1970-01-01
    • 2021-05-06
    • 1970-01-01
    • 1970-01-01
    相关资源
    最近更新 更多