Let's look at your model implemented two ways: with the Sequential API and with the functional API.
Here are the imports:
import tensorflow as tf
from tensorflow.keras.layers import Lambda, Conv2D, Activation, Input
from tensorflow.keras import Model, Sequential
Here is your implementation using a Sequential model:
model = Sequential()
model.add(Conv2D(16, (5, 5), input_shape=(256, 256, 1)))
x = model.layers[0].output  # not used below; the lambda's `x` is its own argument
model.add(Lambda(lambda x: tf.abs(x)))  # element-wise absolute value
model.add(Activation(activation='tanh'))
model.summary()
Summary output:
Model: "sequential_2"
_________________________________________________________________
Layer (type)                 Output Shape              Param #
=================================================================
conv2d_6 (Conv2D)            (None, 252, 252, 16)      416
_________________________________________________________________
lambda_6 (Lambda)            (None, 252, 252, 16)      0
_________________________________________________________________
activation_5 (Activation)    (None, 252, 252, 16)      0
=================================================================
Total params: 416
Trainable params: 416
Non-trainable params: 0
_________________________________________________________________
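As a quick check on the summary above, the Conv2D row can be reproduced by hand. This sketch assumes the Conv2D defaults (`padding='valid'`, stride 1, bias enabled); the helper names are hypothetical. The Lambda and Activation layers have no weights, which is why the total equals the Conv2D count.

```python
# Minimal sketch: reproduce the Conv2D row of the summary by hand.
# Assumes padding='valid' (the Conv2D default), stride 1 and a bias term.

def conv2d_output_size(input_size, kernel_size, stride=1):
    """Spatial output size of a 'valid' (unpadded) convolution."""
    return (input_size - kernel_size) // stride + 1


def conv2d_param_count(kernel_h, kernel_w, in_channels, filters, use_bias=True):
    """kernel_h * kernel_w * in_channels weights per filter, plus one bias per filter."""
    weights = kernel_h * kernel_w * in_channels * filters
    biases = filters if use_bias else 0
    return weights + biases


out_h = conv2d_output_size(256, 5)
out_w = conv2d_output_size(256, 5)
params = conv2d_param_count(5, 5, 1, 16)
print((None, out_h, out_w, 16), params)  # (None, 252, 252, 16) 416
```

So 256 - 5 + 1 = 252 per spatial dimension, and 5 * 5 * 1 * 16 + 16 = 416 parameters.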
Now the same model with the functional API.
First, define your function:
def arbitrary_functionality(tensor):
    return tf.abs(tensor)
Then build the model:
input_layer = Input(shape=(256, 256, 1))
conv1 = Conv2D(16, (5, 5))(input_layer)
lambda_layer = Lambda(arbitrary_functionality)(conv1)
output_layer = Activation(activation='tanh')(lambda_layer)
model_2 = Model(inputs=input_layer, outputs=output_layer)
model_2.summary()
Summary output:
Model: "model_4"
_________________________________________________________________
Layer (type)                 Output Shape              Param #
=================================================================
input_7 (InputLayer)         [(None, 256, 256, 1)]     0
_________________________________________________________________
conv2d_9 (Conv2D)            (None, 252, 252, 16)      416
_________________________________________________________________
lambda_9 (Lambda)            (None, 252, 252, 16)      0
_________________________________________________________________
activation_8 (Activation)    (None, 252, 252, 16)      0
=================================================================
Total params: 416
Trainable params: 416
Non-trainable params: 0
_________________________________________________________________
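The two summaries describe the same architecture. A minimal sketch to confirm the two versions also compute the same function, assuming TensorFlow 2.x: it builds both, copies the Sequential model's weights into the functional one with `set_weights`, and compares their outputs on a random batch. The names `seq`, `func` and `x_batch` are illustrative only. Because tanh is applied to |x|, every output also lies in [0, 1].

```python
import numpy as np
import tensorflow as tf
from tensorflow.keras import Model, Sequential
from tensorflow.keras.layers import Activation, Conv2D, Input, Lambda

# Sequential version: Conv2D -> abs -> tanh
seq = Sequential([
    Input(shape=(256, 256, 1)),
    Conv2D(16, (5, 5)),
    Lambda(lambda t: tf.abs(t)),
    Activation('tanh'),
])

# Functional version with the same layers
inp = Input(shape=(256, 256, 1))
h = Conv2D(16, (5, 5))(inp)
h = Lambda(lambda t: tf.abs(t))(h)
out = Activation('tanh')(h)
func = Model(inputs=inp, outputs=out)

# Both models hold [kernel (5, 5, 1, 16), bias (16,)], so the weights can be copied over
func.set_weights(seq.get_weights())

rng = np.random.default_rng(0)
x_batch = rng.random((2, 256, 256, 1)).astype('float32')
y_seq = np.asarray(seq(x_batch))
y_func = np.asarray(func(x_batch))
print(np.allclose(y_seq, y_func))  # expected: True
```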
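Note: According to the TensorFlow documentation, Lambda layers are best suited for simple operations or quick experimentation; for anything more involved, a better approach is to subclass the Layer class. See the example here.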
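Following that note, here is a minimal sketch of the same functional model with a small `Layer` subclass in place of the `Lambda`, assuming TensorFlow 2.x. The class name `AbsLayer` is hypothetical. The layer has no weights, so overriding `call` is enough; a named subclass like this can be reused across models and is saved by class name rather than as a serialized Python lambda (when loading, you would typically pass it via `custom_objects`).

```python
import numpy as np
import tensorflow as tf
from tensorflow.keras import Model
from tensorflow.keras.layers import Activation, Conv2D, Input, Layer


class AbsLayer(Layer):
    """Hypothetical replacement for Lambda(tf.abs): element-wise absolute value, no weights."""

    def call(self, inputs):
        return tf.abs(inputs)


input_layer = Input(shape=(256, 256, 1))
conv1 = Conv2D(16, (5, 5))(input_layer)
abs_out = AbsLayer()(conv1)
output_layer = Activation(activation='tanh')(abs_out)
model_3 = Model(inputs=input_layer, outputs=output_layer)
model_3.summary()

# Usage: run a dummy batch through the model
y = np.asarray(model_3(np.zeros((1, 256, 256, 1), dtype='float32')))
print(y.shape)  # (1, 252, 252, 16)
```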