【发布时间】:2020-07-31 19:02:13
【问题描述】:
我正在尝试为分布式学习实现自定义正则化函数,以实现等式中的惩罚函数
我将上述函数实现为逐层正则化器,但它会引发错误。期待社区的帮助
@tf.keras.utils.register_keras_serializable(package='Custom', name='esgd')
def esgd(w, wt, mu):
delta = tf.math.square(tf.norm(w-wt))
rl = (mu/2)*delta
return rl
def model(w, wt, mu):
model = tf.keras.models.Sequential([tf.keras.layers.Conv2D(32,(3,3), padding='same', activation='relu',input_shape=(28, 28, 1)),
tf.keras.layers.MaxPooling2D(2,2),
tf.keras.layers.Conv2D(64,(3,3), padding='same', activation='relu'),
tf.keras.layers.MaxPooling2D(2,2),
tf.keras.layers.Dropout(0.25),
tf.keras.layers.Flatten(),
tf.keras.layers.Dense(128,activation='relu', kernel_initializer='ones',kernel_regularizer=esgd(w[0][7],wt[0][7],mu)
),
tf.keras.layers.Dropout(0.25),
tf.keras.layers.Dense(10, activation='softmax')
])
return model
----- 错误-------
---> 59 model = init_model(w, wt, mu)
60
61 # model.set_weights(wei[0])
<ipython-input-5-e0796dd9fa55> in init_model(w, wt, mu)
11 tf.keras.layers.Dropout(0.25),
12 tf.keras.layers.Flatten(),
---> 13 tf.keras.layers.Dense(128,activation='relu', kernel_initializer='ones',kernel_regularizer=esgd(w[0][7],wt[0][7],mu)
14 ),
15 tf.keras.layers.Dropout(0.25),
/usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/layers/core.py in __init__(self, units, activation, use_bias, kernel_initializer, bias_initializer, kernel_regularizer, bias_regularizer, activity_regularizer, kernel_constraint, bias_constraint, **kwargs)
1137 self.kernel_initializer = initializers.get(kernel_initializer)
1138 self.bias_initializer = initializers.get(bias_initializer)
-> 1139 self.kernel_regularizer = regularizers.get(kernel_regularizer)
1140 self.bias_regularizer = regularizers.get(bias_regularizer)
1141 self.kernel_constraint = constraints.get(kernel_constraint)
/usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/regularizers.py in get(identifier)
313 return identifier
314 else:
--> 315 raise ValueError('Could not interpret regularizer identifier:', identifier)
ValueError: ('Could not interpret regularizer identifier:', <tf.Tensor: shape=(), dtype=float32, numpy=0.00068962533>)
【问题讨论】:
标签: keras tensorflow2.0