【发布时间】:2020-03-07 23:11:30
【问题描述】:
我有以 TFRecord 格式存储的图像,我希望在 Tensorflow 中高效地执行 progressive sprinkles augmentation。
下面是我想出的实现:
class Cutout:
def __init__(self, num_holes, side_length):
self.n = num_holes
self.length = side_length
def __call__(self, image):
img_shape = tf.shape(image)
i = tf.range(img_shape[0])
j = tf.range(img_shape[1])
masking_fn = Cutout._mask_out(image, img_shape, i, j, self.length)
idx = tf.constant(0, dtype=tf.int32)
image, idx = tf.while_loop(
cond=lambda x, ii: tf.less(ii, self.n),
body=masking_fn,
loop_vars=[image, idx]
)
return image
@staticmethod
def _mask_out(image, img_shape, row_range, col_range, hole_length):
shape = tf.shape(image)
rows = shape[0]
cols = shape[1]
channels = shape[2]
def _create_hole(image, idx):
# Masks rows and columns to be replaced
r = tf.random_uniform([], minval=0, maxval=rows, dtype=tf.int32)
c = tf.random_uniform([], minval=0, maxval=cols, dtype=tf.int32)
r1 = tf.clip_by_value(r - hole_length // 2, 0, rows)
r2 = tf.clip_by_value(r + hole_length // 2, 0, rows)
c1 = tf.clip_by_value(c - hole_length // 2, 0, cols)
c2 = tf.clip_by_value(c + hole_length // 2, 0, cols)
row_mask = (r1 <= row_range) & (row_range < r2)
col_mask = (c1 <= col_range) & (col_range < c2)
zeros = tf.zeros(shape)
# Full mask of replaced elements
mask = row_mask[:, tf.newaxis] & col_mask
# Select elements from flattened arrays
img_flat = tf.reshape(image, [-1, channels])
zeros_flat = tf.reshape(zeros, [-1, channels])
mask_flat = tf.reshape(mask, [-1])
result_flat = tf.where(mask_flat, zeros_flat, img_flat)
# Reshape back
result = tf.reshape(result_flat, img_shape)
return [result, idx + 1]
return _create_hole
它可以工作,但是,这种实现非常低效。在我的机器上获取一批 32 张图像(增强设置为 250 个孔,边长为 5)大约需要 90 秒,而在没有应用任何增强的情况下加载它们时不到一秒。
我尝试使用 numpy 创建蒙版,因为它更有效,但如果事先不知道图像的形状,它就无法工作。对tf.shape(image) 的调用将包含运行时的形状信息(当图形在会话中执行时),但是,numpy 需要预先使用这些值来创建张量。
【问题讨论】:
-
答案是否对您的情况有所帮助?如果是,请将其标记为已接受。如果没有,请告诉我它的不足之处,我会尝试更新它。
标签: python tensorflow deep-learning