如果您稍微有创意地思考,则根本不需要噪声输入层。
我们可以使用来自 tf.keras 的 GaussianNoise(),但它需要一个输入张量,从技术上讲,我们应该传递一个 ones 向量。我们可以让中间 VGG 输出特征乘以密集层的零核,然后我们可以向它添加 ones 偏差,这样我们就得到了占位符 ones 向量,它将被传递给 GaussianNoise。
现在,您可以忘记任何复杂的数据生成器,只需使用常规数据生成器或直接使用带有 fit 的 numpy 数组。
from tensorflow.keras.applications import VGG16
from tensorflow.keras.layers import *
from tensorflow.keras.models import *
import tensorflow as tf
ip = Input((224,224,3))
base = VGG16((224,224,3))(ip)
# passing vgg features to a zero vector, making everything zero and then adding bias ones to make the output is always 1
dense_ones = Dense(1000, activation='linear', kernel_initializer = tf.keras.initializers.Zeros(), bias_initializer = tf.keras.initializers.Ones())(base)
gaussian = GaussianNoise(0.4)(dense_ones)
concat = Concatenate()([base, gaussian])
learn_feature = Dense(128, activation = 'relu')(concat) # change this part based on your needs
classification = Dense(2, activation = 'sigmoid')(learn_feature)
model = Model(ip, classification)
Model: "model_5"
__________________________________________________________________________________________________
Layer (type) Output Shape Param # Connected to
==================================================================================================
input_19 (InputLayer) [(None, 224, 224, 3) 0
__________________________________________________________________________________________________
vgg16 (Model) (None, 1000) 138357544 input_19[0][0]
__________________________________________________________________________________________________
dense_2 (Dense) (None, 1000) 1001000 vgg16[1][0]
__________________________________________________________________________________________________
gaussian_noise_5 (GaussianNoise (None, 1000) 0 dense_2[0][0]
__________________________________________________________________________________________________
concatenate_4 (Concatenate) (None, 2000) 0 vgg16[1][0]
gaussian_noise_5[0][0]
==================================================================================================
Total params: 139,358,544
Trainable params: 139,358,544
Non-trainable params: 0
___________________________________