【发布时间】:2021-06-04 08:56:33
【问题描述】:
我有下面这段数据增强(augmentation)代码:
class CustomAugment(object):
    """Random image augmentation: horizontal flip, colour jitter, colour drop.

    Each transform is applied independently with its own probability
    (flip 0.5, jitter 0.8, colour drop 0.2), always in that order.

    Inputs are presumably float images with values in ``[0, 1]``: the jitter
    step clips its output to that range, so integer or ``[0, 255]`` inputs
    should be rescaled first. The last axis must hold 3 (RGB) channels. Both
    a single image ``[H, W, 3]`` and a batch ``[B, H, W, 3]`` are supported.

    Typical use with ``tf.data``::

        train_ds = train_ds.map(CustomAugment())
    """

    def __call__(self, sample, *rest):
        """Randomly augment ``sample``.

        Args:
            sample: RGB image tensor ``[H, W, 3]`` or batch ``[B, H, W, 3]``.
            *rest: Any further components of a dataset element (e.g. labels).
                ``Dataset.map`` passes tuple elements as separate positional
                arguments, so accepting them here lets the object be mapped
                directly over ``(image, label)`` datasets. They are returned
                unchanged.

        Returns:
            The augmented image when called with one argument, otherwise the
            tuple ``(augmented_image, *rest)``.
        """
        sample = self._random_apply(tf.image.flip_left_right, sample, p=0.5)
        sample = self._random_apply(self._color_jitter, sample, p=0.8)
        sample = self._random_apply(self._color_drop, sample, p=0.2)
        if rest:
            # Preserve the element structure so labels stay paired with images.
            return (sample,) + rest
        return sample

    def _color_jitter(self, x, s=1):
        """Randomly perturb brightness, contrast, saturation and hue.

        Args:
            x: RGB image tensor (single image or batch).
            s: Strength multiplier applied to every jitter range.

        Returns:
            The jittered image, clipped to ``[0, 1]``.
        """
        x = tf.image.random_brightness(x, max_delta=0.8 * s)
        x = tf.image.random_contrast(x, lower=1 - 0.8 * s, upper=1 + 0.8 * s)
        x = tf.image.random_saturation(x, lower=1 - 0.8 * s, upper=1 + 0.8 * s)
        x = tf.image.random_hue(x, max_delta=0.2 * s)
        # Brightness/contrast shifts can push values out of range; clamp back.
        x = tf.clip_by_value(x, 0, 1)
        return x

    def _color_drop(self, x):
        """Convert to grayscale while keeping 3 channels.

        ``tf.image.grayscale_to_rgb`` repeats the single gray channel along
        the last axis for any input rank. This handles both unbatched images
        (rank 3, as seen when mapping before ``.batch()``) and batches
        (rank 4). The output shape matches the input shape, as ``tf.cond`` in
        ``_random_apply`` requires.
        """
        x = tf.image.rgb_to_grayscale(x)
        return tf.image.grayscale_to_rgb(x)

    def _random_apply(self, func, x, p):
        """Return ``func(x)`` with probability ``p``, otherwise ``x``.

        ``tf.cond`` makes the random choice at run time inside a graph
        (``tf.data`` / ``tf.function``) instead of once at trace time. A
        single scalar draw is used, so for a batched ``x`` the whole batch is
        either transformed or left as-is together.
        """
        return tf.cond(
            tf.less(
                tf.random.uniform([], minval=0, maxval=1, dtype=tf.float32),
                tf.cast(p, tf.float32),
            ),
            lambda: func(x),
            lambda: x,
        )
这就是我导入图像数据集的方式:
# `tf.data.Dataset.from_generator` expects a callable that returns an
# iterator (together with `output_signature`), not a path, so passing `path`
# to it fails. Assuming `path` is a directory with one sub-folder per class
# (TODO confirm), load the images directly instead. The resulting dataset
# yields batches of (images, labels) with float32 images in [0, 255]; rescale
# them (e.g. divide by 255) before applying CustomAugment, whose colour
# jitter clips to [0, 1].
train_ds = tf.keras.preprocessing.image_dataset_from_directory(path)
我想把这个数据增强应用到我的 train_ds 上,请问该怎么做?
【问题讨论】:
标签: python tensorflow2.0 tensorflow2.x