【问题标题】:How to use multiple models in Keras model subclass API如何在 Keras 模型子类 API 中使用多个模型
【发布时间】:2020-12-08 23:27:39
【问题描述】:

我正在尝试训练一个相当复杂的模型,该模型使用多个冻结的预训练模型,并具有一个具有相当复杂的多任务损失函数的自定义训练循环。由于这些复杂性,我的计划是在子类模型中定义多个单独的 Keras 模型。我的设置一直存在问题,我已经能够将其简化为一个简单的示例来演示该问题。

下面的代码训练了一个名为MainModel 的简单模型,它使用Keras 模型子类化API,但它基本上只是一个Sequential([Conv1d(), Conv1d()]) 模型。当我在同一类中定义另一个模型self.aux_model 时,原始模型不再正确训练。在示例中,self.aux_model 在训练中没有任何作用,它只是被定义,从未使用过。具体来说,在每次训练迭代之后,权重值与迭代开始时相同。因此,即使梯度具有非零值,模型权重也不会更新。

import numpy as np
import tensorflow as tf
from tensorflow.python.keras.callbacks import Callback

num_epochs = 5
steps_per_epoch = 100
audio_len = 16000


class WeightChecker:
    """Automated health checks for training Keras models."""
    def __init__(self, model):
        self.initial_model = model
        self.var_names = [var.name for var in model.trainable_variables]
        self.prev_weights = model.get_weights()

    def check_epoch(self, model):
        """Checks to run at the end of an epoch"""
        self.check_untrained_params(model)

    def check_untrained_params(self, model):
        """Compare self.model.trainable_variables to self.prev_weights"""
        passed = True
        curr_weights = model.get_weights()
        for curr_var, prev_var, var_name in zip(curr_weights, self.prev_weights, self.var_names):
            eq = np.equal(curr_var, prev_var).all()
            if eq:
                passed = False
                print(f"\nWarning: Variable {var_name} was not updated with training. "
                      f"Confirm that this layer is correctly "
                      f"connected to the computation graph.")
        self.prev_weights = [w.copy() for w in curr_weights]
        return passed


class WeightCheckerCallback(Callback):
    """Check model initialization and run training checks.
    """
    def __init__(self):
        super().__init__()
        self.weight_check = None

    def setup_weight_checker(
            self,
            model: tf.keras.Model = None):
        """Initialize the callback with an input_batch and targets."""
        self.weight_check = WeightChecker(model)

    def on_train_begin(self, logs=None):
        if self.weight_check is None:
            raise ValueError("setup_weight_checker() must be called to use WeightCheckerCallback.")

    def on_epoch_end(self, epoch, logs=None):
        self.weight_check.check_epoch(self.model)


class MainModel(tf.keras.Model):
    """Main Model."""
    def __init__(self):
        super().__init__()
        self.feature_dim = 128
        self.aux_model = self._set_aux_model()

        self.map_model = tf.keras.Sequential([tf.keras.layers.Conv1D(
                64, 3, padding='same'
            ),
            tf.keras.layers.Conv1D(
                1, 3, padding='same'
            )])

    def call(self, inputs, training=True):
        output = self.map_model(inputs)
        return output

    def train_step(self, data):
        mixed_audio = data[0]
        clean_audio = data[1]

        with tf.GradientTape() as tape:
            decoded_audio = self.map_model(mixed_audio)
            total_loss = tf.reduce_mean(tf.abs(decoded_audio - clean_audio))

        grads = tape.gradient(total_loss, self.trainable_variables)
        self.optimizer.apply_gradients(zip(grads, self.trainable_variables))

        losses = {
            'loss': total_loss,
        }
        return losses

    @staticmethod
    def _set_aux_model():
        """Set an auxiliary model."""
        model = tf.keras.Sequential([tf.keras.layers.Dense(1)])
        model.build(input_shape=(None, 1))

        model.trainable = False
        return model


class TrainingTask:
    """A Keras model training task."""
    def __init__(self):
        self.model, self.stateful_model = self._set_model()
        self.callbacks = [WeightCheckerCallback()]

    @staticmethod
    def _set_model():
        model = MainModel()

        # Build the model with fake data.
        model.compile(optimizer='adam')
        fake_data = np.random.randn(1,
                                    audio_len,
                                    1)
        fake_data = fake_data.astype(np.float32)
        model(fake_data, training=True)

        return model, None

    def fit(self):
        """Custom model fit method."""
        try:
            weight_checker_callback_index = [isinstance(cb, WeightCheckerCallback)
                                             for cb in self.callbacks].index(True)
        except ValueError:
            weight_checker_callback_index = None
        if weight_checker_callback_index is not None:
            self.callbacks[weight_checker_callback_index].setup_weight_checker(
                model=self.model
            )

        for callback in self.callbacks:
            callback.set_model(self.model)

        print("\nBegin training")

        for callback in self.callbacks:
            callback.on_train_begin()

        for epoch in range(num_epochs):

            for callback in self.callbacks:
                callback.on_epoch_begin(epoch)

            for batch in range(steps_per_epoch):
                x, y = next(data_gen_batch())

                for callback in self.callbacks:
                    callback.on_batch_begin(batch)

                metrics = self.model.train_step([x, y])
                batch_loss = np.mean(metrics.pop('loss'))

                print(batch, epoch, batch_loss)

                for callback in self.callbacks:
                    callback.on_batch_end(batch, metrics)

            print(f'Epoch: {epoch}')

            numeric_metrics = dict()
            numeric_metrics['loss'] = batch_loss
            for callback in self.callbacks:
                callback.on_epoch_end(epoch, numeric_metrics)


def data_gen():
    """Generate random data for training."""
    data = (np.random.random((audio_len, 1)), np.random.random((audio_len, 1)))
    while True:
        yield data


def data_gen_batch(batch_size=8):
    """Generate random data in batches for training."""
    data = next(data_gen())
    data_batch = (np.stack([data[0]] * batch_size, axis=0),
                  np.stack([data[1]] * batch_size, axis=0))
    while True:
        yield data_batch


if __name__ == '__main__':
    task = TrainingTask()
    task.fit()

WeightCheckerCallbackWeightChecker 类是我为说明问题而定义的回调,否则会导致静默失败。除了每个训练步骤的一些输出之外,代码还会产生以下关于map_model 层的警告,这些层应该正在更新(aux_model 只有一个Dense 层):

Warning: Variable main_model/sequential_1/conv1d/kernel:0 was not updated with training. Confirm that this layer is correctly connected to the computation graph.
Warning: Variable main_model/sequential_1/conv1d/bias:0 was not updated with training. Confirm that this layer is correctly connected to the computation graph.

但是,如果 aux_model 被注释掉,则不会出现警告,并且模型权重将按预期更新。

        # self.aux_model = self._set_aux_model()

显然,在 tensorflow 中有几种方法可以让这个简单的 Sequential 模型正确训练,所以我不只是在寻找一种解决方法来让这个特定示例正常工作。相反,我希望有人可以根据所涉及的 Tensorflow 会话和图表来解释这个示例的情况,以及在将多个不同的 Keras 模型与子类 API 嵌套时避免冲突的最佳实践是什么.我的最终目标是使用类似的框架训练更复杂的模型系统。

【问题讨论】:

    标签: python tensorflow keras


    【解决方案1】:

    结果证明这是我使用 WeightChecker 类检查图层的方式的问题。 model.get_weights() 返回所有权重,而不仅仅是可训练的权重。因此,当我们在for 循环中使用zip() 时,我们会将不同长度的列表压缩在一起,这会导致未更新的层的名称被误报。该错误可以通过使用以下而不是model.get_weights()来解决:

    self.prev_weights = [var.numpy() for var in model.trainable_variables]
    

    【讨论】:

      猜你喜欢
      • 1970-01-01
      • 2020-08-09
      • 1970-01-01
      • 1970-01-01
      • 2017-06-02
      • 1970-01-01
      • 1970-01-01
      • 1970-01-01
      • 2019-08-30
      相关资源
      最近更新 更多