【发布时间】:2020-12-08 23:27:39
【问题描述】:
我正在尝试训练一个相当复杂的模型,该模型使用多个冻结的预训练模型,并具有一个具有相当复杂的多任务损失函数的自定义训练循环。由于这些复杂性,我的计划是在子类模型中定义多个单独的 Keras 模型。我的设置一直存在问题,我已经能够将其简化为一个简单的示例来演示该问题。
下面的代码训练了一个名为MainModel 的简单模型,它使用Keras 模型子类化API,但它基本上只是一个Sequential([Conv1d(), Conv1d()]) 模型。当我在同一类中定义另一个模型self.aux_model 时,原始模型不再正确训练。在示例中,self.aux_model 在训练中没有任何作用,它只是被定义,从未使用过。具体来说,在每次训练迭代之后,权重值与迭代开始时相同。因此,即使梯度具有非零值,模型权重也不会更新。
import numpy as np
import tensorflow as tf
from tensorflow.python.keras.callbacks import Callback
num_epochs = 5
steps_per_epoch = 100
audio_len = 16000
class WeightChecker:
"""Automated health checks for training Keras models."""
def __init__(self, model):
self.initial_model = model
self.var_names = [var.name for var in model.trainable_variables]
self.prev_weights = model.get_weights()
def check_epoch(self, model):
"""Checks to run at the end of an epoch"""
self.check_untrained_params(model)
def check_untrained_params(self, model):
"""Compare self.model.trainable_variables to self.prev_weights"""
passed = True
curr_weights = model.get_weights()
for curr_var, prev_var, var_name in zip(curr_weights, self.prev_weights, self.var_names):
eq = np.equal(curr_var, prev_var).all()
if eq:
passed = False
print(f"\nWarning: Variable {var_name} was not updated with training. "
f"Confirm that this layer is correctly "
f"connected to the computation graph.")
self.prev_weights = [w.copy() for w in curr_weights]
return passed
class WeightCheckerCallback(Callback):
"""Check model initialization and run training checks.
"""
def __init__(self):
super().__init__()
self.weight_check = None
def setup_weight_checker(
self,
model: tf.keras.Model = None):
"""Initialize the callback with an input_batch and targets."""
self.weight_check = WeightChecker(model)
def on_train_begin(self, logs=None):
if self.weight_check is None:
raise ValueError("setup_weight_checker() must be called to use WeightCheckerCallback.")
def on_epoch_end(self, epoch, logs=None):
self.weight_check.check_epoch(self.model)
class MainModel(tf.keras.Model):
"""Main Model."""
def __init__(self):
super().__init__()
self.feature_dim = 128
self.aux_model = self._set_aux_model()
self.map_model = tf.keras.Sequential([tf.keras.layers.Conv1D(
64, 3, padding='same'
),
tf.keras.layers.Conv1D(
1, 3, padding='same'
)])
def call(self, inputs, training=True):
output = self.map_model(inputs)
return output
def train_step(self, data):
mixed_audio = data[0]
clean_audio = data[1]
with tf.GradientTape() as tape:
decoded_audio = self.map_model(mixed_audio)
total_loss = tf.reduce_mean(tf.abs(decoded_audio - clean_audio))
grads = tape.gradient(total_loss, self.trainable_variables)
self.optimizer.apply_gradients(zip(grads, self.trainable_variables))
losses = {
'loss': total_loss,
}
return losses
@staticmethod
def _set_aux_model():
"""Set an auxiliary model."""
model = tf.keras.Sequential([tf.keras.layers.Dense(1)])
model.build(input_shape=(None, 1))
model.trainable = False
return model
class TrainingTask:
"""A Keras model training task."""
def __init__(self):
self.model, self.stateful_model = self._set_model()
self.callbacks = [WeightCheckerCallback()]
@staticmethod
def _set_model():
model = MainModel()
# Build the model with fake data.
model.compile(optimizer='adam')
fake_data = np.random.randn(1,
audio_len,
1)
fake_data = fake_data.astype(np.float32)
model(fake_data, training=True)
return model, None
def fit(self):
"""Custom model fit method."""
try:
weight_checker_callback_index = [isinstance(cb, WeightCheckerCallback)
for cb in self.callbacks].index(True)
except ValueError:
weight_checker_callback_index = None
if weight_checker_callback_index is not None:
self.callbacks[weight_checker_callback_index].setup_weight_checker(
model=self.model
)
for callback in self.callbacks:
callback.set_model(self.model)
print("\nBegin training")
for callback in self.callbacks:
callback.on_train_begin()
for epoch in range(num_epochs):
for callback in self.callbacks:
callback.on_epoch_begin(epoch)
for batch in range(steps_per_epoch):
x, y = next(data_gen_batch())
for callback in self.callbacks:
callback.on_batch_begin(batch)
metrics = self.model.train_step([x, y])
batch_loss = np.mean(metrics.pop('loss'))
print(batch, epoch, batch_loss)
for callback in self.callbacks:
callback.on_batch_end(batch, metrics)
print(f'Epoch: {epoch}')
numeric_metrics = dict()
numeric_metrics['loss'] = batch_loss
for callback in self.callbacks:
callback.on_epoch_end(epoch, numeric_metrics)
def data_gen():
"""Generate random data for training."""
data = (np.random.random((audio_len, 1)), np.random.random((audio_len, 1)))
while True:
yield data
def data_gen_batch(batch_size=8):
"""Generate random data in batches for training."""
data = next(data_gen())
data_batch = (np.stack([data[0]] * batch_size, axis=0),
np.stack([data[1]] * batch_size, axis=0))
while True:
yield data_batch
if __name__ == '__main__':
task = TrainingTask()
task.fit()
WeightCheckerCallback 和 WeightChecker 类是我为说明问题而定义的回调,否则会导致静默失败。除了每个训练步骤的一些输出之外,代码还会产生以下关于map_model 层的警告,这些层应该正在更新(aux_model 只有一个Dense 层):
Warning: Variable main_model/sequential_1/conv1d/kernel:0 was not updated with training. Confirm that this layer is correctly connected to the computation graph.
Warning: Variable main_model/sequential_1/conv1d/bias:0 was not updated with training. Confirm that this layer is correctly connected to the computation graph.
但是,如果 aux_model 被注释掉,则不会出现警告,并且模型权重将按预期更新。
# self.aux_model = self._set_aux_model()
显然,在 tensorflow 中有几种方法可以让这个简单的 Sequential 模型正确训练,所以我不只是在寻找一种解决方法来让这个特定示例正常工作。相反,我希望有人可以根据所涉及的 Tensorflow 会话和图表来解释这个示例的情况,以及在将多个不同的 Keras 模型与子类 API 嵌套时避免冲突的最佳实践是什么.我的最终目标是使用类似的框架训练更复杂的模型系统。
【问题讨论】:
标签: python tensorflow keras