【发布时间】:2020-10-14 04:15:41
【问题描述】:
我希望在 Keras 中将遮罩应用到 Conv2D 层的内核。我在理解内核形状方面有点困难。
对于kernel_size = 3,filters = 1,内核的形状是(3, 3, 4, 1) => (kernel_size, kernel_size, ???, filters)
内核中的第 3 维代表什么?
如何获取 NxN 掩码并将其与每个内核过滤器相乘?
这是我到目前为止的代码。我不确定它是否会按预期工作,因为我不完全了解内核形状。
class MaskedConv2D(tf.keras.layers.Layer):
def __init__(self, *args, **kwargs):
super(MaskedConv2D, self).__init__()
self.conv2d = Conv2D(*args, **kwargs)
def build(self, input_shape):
self.conv2d.build(input_shape[0])
self._convolution_op = self.conv2d._convolution_op
def masked_convolution_op(self, filters, kernel, mask):
m = K.expand_dims(K.expand_dims(mask[0, ...], axis=2), axis=3) # (3, 3) => (3, 3, 1, 1)
m = K.tile(m, (1, 1, kernel.shape[2], kernel.shape[3])) # (3, 3, 1, 1) => (3, 3, 4, 1)
return self._convolution_op(filters, tf.math.multiply(kernel, m))
def call(self, inputs):
x, mask = inputs
self.conv2d._convolution_op = functools.partial(self.masked_convolution_op, mask=mask)
return self.conv2d.call(x)
【问题讨论】:
标签: python tensorflow machine-learning keras convolution