【Question title】: Model with multiple outputs and custom loss function
【Posted】: 2020-04-20 19:17:41
【Question】:

I am trying to train a model with multiple outputs and a custom loss function using Keras, but I get the error tensorflow.python.framework.errors_impl.OperatorNotAllowedInGraphError: iterating over `tf.Tensor` is not allowed in Graph execution. Use Eager execution or decorate this function with @tf.function.

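It is hard to debug because everything goes through model.compile and model.fit. I think it is related to how a model should be defined when it has multiple outputs, but I can't find good documentation on this. The guide explains how to build a multi-output model with the functional API and gives an example, but it does not say how a custom loss function should work when subclassing the Model API. My code is as follows: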

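import tensorflow as tf
from tensorflow.keras import Model, layers
from tensorflow.keras.losses import Loss

class DeepEnsembles(Model):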

    def __init__(self, **kwargs):
        super(DeepEnsembles, self).__init__()

        self.num_models = kwargs.get('num_models')
        model = kwargs.get('model')

        self.mean = [model(**dict(**kwargs)) for _ in range(self.num_models)]

        self.variance = [model(**dict(**kwargs)) for _ in range(self.num_models)]

    def call(self, inputs, training=None, mask=None):
        mean_predictions = []
        variance_predictions = []
        for idx in range(self.num_models):
            mean_predictions.append(self.mean[idx](inputs, training=training))
            variance_predictions.append(self.variance[idx](inputs, training=training))
        mean_stack = tf.stack(mean_predictions)
        variance_stack = tf.stack(variance_predictions)

        return mean_stack, variance_stack

The MLP is as follows:

class MLP(Model):
    def __init__(self, **kwargs):
        super(MLP, self).__init__()

        # Initialization parameters
        self.num_inputs = kwargs.get('num_inputs', 779)
        self.num_outputs = kwargs.get('num_outputs', 1)
        self.hidden_size = kwargs.get('hidden_size', 256)
        self.activation = kwargs.get('activation', 'relu')

        # Optional parameters
        self.p = kwargs.get('p', 0.05)

        self.model = tf.keras.Sequential([
            layers.Dense(self.hidden_size, activation=self.activation, input_shape=(self.num_inputs,)),
            layers.Dropout(self.p),
            layers.Dense(self.hidden_size, activation=self.activation),
            layers.Dropout(self.p),
            layers.Dense(self.num_outputs)
        ])

    def call(self, inputs, training=None, mask=None):
        output = self.model(inputs, training=training)
        return output

I am trying to minimize a custom loss function:

class GaussianNLL(Loss):

    def __init__(self):
        super(GaussianNLL, self).__init__()

    def call(self, y_true, y_pred):

        mean, variance = y_pred
        variance = variance + 0.0001
        nll = (tf.math.log(variance) / 2 + ((y_true - mean) ** 2) / (2 * variance))
        nll = tf.math.reduce_mean(nll)
        return nll

Finally, this is how I try to train it:

    ensembles_params = {'num_models': 5, 'model': MLP, 'p': 0}
    model = DeepEnsembles(**ensembles_params)
    loss_fn = GaussianNLL()
    optimizer = tf.keras.optimizers.Adam(learning_rate=1e-4)
    epochs = 10000

    model.compile(optimizer='adam',
                  loss=loss_fn,
                  metrics=['mse', 'mae'])
    history = model.fit(x_train, y_train,
                        batch_size=2048,
                        epochs=10000,
                        verbose=0,
                        validation_data=(x_val, y_val))

This produces the error above. Any pointers? For reference, the full stack trace is:

Traceback (most recent call last):
  File "/home/emilio/anaconda3/lib/python3.7/contextlib.py", line 130, in __exit__
    self.gen.throw(type, value, traceback)
  File "/home/emilio/anaconda3/lib/python3.7/site-packages/tensorflow_core/python/ops/variable_scope.py", line 2803, in variable_creator_scope
    yield
  File "/home/emilio/anaconda3/lib/python3.7/site-packages/tensorflow_core/python/keras/engine/training_v2.py", line 235, in fit
    use_multiprocessing=use_multiprocessing)
  File "/home/emilio/anaconda3/lib/python3.7/site-packages/tensorflow_core/python/keras/engine/training_v2.py", line 593, in _process_training_inputs
    use_multiprocessing=use_multiprocessing)
  File "/home/emilio/anaconda3/lib/python3.7/site-packages/tensorflow_core/python/keras/engine/training_v2.py", line 646, in _process_inputs
    x, y, sample_weight=sample_weights)
  File "/home/emilio/anaconda3/lib/python3.7/site-packages/tensorflow_core/python/keras/engine/training.py", line 2360, in _standardize_user_data
    self._compile_from_inputs(all_inputs, y_input, x, y)
  File "/home/emilio/anaconda3/lib/python3.7/site-packages/tensorflow_core/python/keras/engine/training.py", line 2618, in _compile_from_inputs
    experimental_run_tf_function=self._experimental_run_tf_function)
  File "/home/emilio/anaconda3/lib/python3.7/site-packages/tensorflow_core/python/training/tracking/base.py", line 457, in _method_wrapper
    result = method(self, *args, **kwargs)
  File "/home/emilio/anaconda3/lib/python3.7/site-packages/tensorflow_core/python/keras/engine/training.py", line 446, in compile
    self._compile_weights_loss_and_weighted_metrics()
  File "/home/emilio/anaconda3/lib/python3.7/site-packages/tensorflow_core/python/training/tracking/base.py", line 457, in _method_wrapper
    result = method(self, *args, **kwargs)
  File "/home/emilio/anaconda3/lib/python3.7/site-packages/tensorflow_core/python/keras/engine/training.py", line 1592, in _compile_weights_loss_and_weighted_metrics
    self.total_loss = self._prepare_total_loss(masks)
  File "/home/emilio/anaconda3/lib/python3.7/site-packages/tensorflow_core/python/keras/engine/training.py", line 1652, in _prepare_total_loss
    per_sample_losses = loss_fn.call(y_true, y_pred)
  File "/home/emilio/fault_detection/tensorflow_code/tf_utils/loss.py", line 13, in call
    mean, variance = y_pred
  File "/home/emilio/anaconda3/lib/python3.7/site-packages/tensorflow_core/python/framework/ops.py", line 539, in __iter__
    self._disallow_iteration()
  File "/home/emilio/anaconda3/lib/python3.7/site-packages/tensorflow_core/python/framework/ops.py", line 535, in _disallow_iteration
    self._disallow_in_graph_mode("iterating over `tf.Tensor`")
  File "/home/emilio/anaconda3/lib/python3.7/site-packages/tensorflow_core/python/framework/ops.py", line 515, in _disallow_in_graph_mode
    " this function with @tf.function.".format(task))
tensorflow.python.framework.errors_impl.OperatorNotAllowedInGraphError: iterating over `tf.Tensor` is not allowed in Graph execution. Use Eager execution or decorate this function with @tf.function.


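So it is clearly related to the loss function. But the model's forward pass outputs a tuple, which I unpack in the loss function, so I don't see why this is a problem.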

【Question discussion】:

    Tags: python tensorflow keras


    【Solution 1】:

    From a quick test, I think I solved the problem by replacing:

            mean, variance = y_pred
            variance = variance + 0.0001

    with:

            mean = y_pred[0]
            variance = y_pred[1] + 0.0001
    

    Unpacking y_pred (which is a tensor) calls Tensor.__iter__, which evidently raises the error, whereas I believe Tensor.__getitem__ does not...
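The difference between unpacking and indexing can be reproduced outside Keras. The following is a minimal sketch, assuming TensorFlow 2.x; the function names `unpack` and `index` and the toy tensor are made up for illustration and are not part of the original code. Wrapping the functions in `tf.function` makes `y_pred` a symbolic (graph) tensor, which is the situation the loss function runs in during `model.fit`.

```python
import tensorflow as tf

@tf.function
def unpack(y_pred):
    # Tuple unpacking calls Tensor.__iter__, which is refused for symbolic tensors.
    mean, variance = y_pred
    return mean, variance

@tf.function
def index(y_pred):
    # Integer indexing calls Tensor.__getitem__, which just adds slice ops to the graph.
    return y_pred[0], y_pred[1]

# Hypothetical toy prediction: row 0 holds means, row 1 holds variances.
y_pred = tf.constant([[1.0, 2.0], [0.5, 0.25]])

try:
    unpack(y_pred)
    unpack_failed = False
except tf.errors.OperatorNotAllowedInGraphError:
    unpack_failed = True

# index() returns an ordinary Python tuple, so unpacking its result is fine.
mean, variance = index(y_pred)
print(unpack_failed, mean.numpy(), variance.numpy())
```

If the leading dimension is known when the graph is built, `tf.unstack(y_pred, num=2)` should be another option, since it returns a Python list that can be unpacked normally.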

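    I did not get as far as the actual training, because I think the shapes of my current dummy x_train and y_train are not quite right. If you find that the problem shows up again later, I will try to look into it.

    Edit:

    I managed to get your code running by using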

    x_train = np.random.random((10000, 779))
    y_train = np.random.random((10000, 1))
    

    by changing the last line of DeepEnsembles.call to

            return tf.stack([mean_stack, variance_stack])
    

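    and by commenting out the metrics (this is necessary because y_true and y_pred are expected to have different shapes, so you will probably need to define your own versions of mse and mae to use as metrics):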

    model.compile(optimizer='adam',
                  loss=loss_fn,
                  # metrics=['mse', 'mae']
    )
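As a rough sketch of what such custom mse and mae versions could look like: the functions below assume that `y_pred` is the stacked tensor returned by the modified `DeepEnsembles.call`, with shape (2, num_models, batch, num_outputs), and that the ensemble prediction is the plain average of the sub-model means. The names `ensemble_mse` and `ensemble_mae` and the averaging choice are assumptions, not something from the original code. They are only checked eagerly on toy data here; whether Keras accepts them in `metrics=[...]` without shape complaints may depend on the TensorFlow version.

```python
import tensorflow as tf

def ensemble_mse(y_true, y_pred):
    # y_pred[0] is assumed to be the stack of means: (num_models, batch, num_outputs).
    ensemble_mean = tf.reduce_mean(y_pred[0], axis=0)  # -> (batch, num_outputs)
    return tf.reduce_mean(tf.square(y_true - ensemble_mean))

def ensemble_mae(y_true, y_pred):
    ensemble_mean = tf.reduce_mean(y_pred[0], axis=0)
    return tf.reduce_mean(tf.abs(y_true - ensemble_mean))

# Toy check: 2 sub-models, batch of 2, 1 output.
y_true = tf.constant([[1.0], [2.0]])
mean_stack = tf.constant([[[0.0], [2.0]],
                          [[2.0], [4.0]]])  # averages to [[1.0], [3.0]]
variance_stack = tf.ones_like(mean_stack)
y_pred = tf.stack([mean_stack, variance_stack])

mse_value = float(ensemble_mse(y_true, y_pred))
mae_value = float(ensemble_mae(y_true, y_pred))
print(mse_value, mae_value)
```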
    

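    I believe this is very close to what you intended.

    The reason for not returning a tuple is that TensorFlow interprets each element of the tuple as a separate output of the network and applies the loss independently to each element.

    You can test this by keeping the old version of DeepEnsembles.call and instead using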

    y_train_1 = np.random.random((10000, 1))
    y_train_2 = np.random.random((10000, 1))
    y_train = [y_train_1, y_train_2]
    

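    It will run, and there will be 10 MLPs, but MLP_1/2 will learn the mean and variance of y_train_1, MLP_6/7 will learn the mean and variance of y_train_2, and none of the other MLPs will learn anything.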
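One way to see why only those sub-models are trained is to follow the indexing by hand. The sketch below uses NumPy as a stand-in for the TensorFlow stacking and assumes the loss reads `y_pred[0]` and `y_pred[1]` as in the fix above, and that the MLPs are numbered in creation order (mean models 1 to 5, variance models 6 to 10). The helper `slices_read_by_loss` and the tagging values are made up for illustration. Each sub-model's output is filled with a distinct number so the slices can be traced.

```python
import numpy as np

num_models, batch = 5, 3

# Tag predictions: mean model i outputs i, variance model i outputs 10 + i.
mean_stack = np.stack([np.full((batch, 1), float(i)) for i in range(num_models)])
variance_stack = np.stack([np.full((batch, 1), 10.0 + i) for i in range(num_models)])

def slices_read_by_loss(y_pred):
    # Mirrors the indexing in the loss: "mean" is y_pred[0], "variance" is y_pred[1].
    return y_pred[0], y_pred[1]

# Tuple return: the loss is applied to each output on its own.
as_mean_1, as_var_1 = slices_read_by_loss(mean_stack)      # mean models 0 and 1
as_mean_2, as_var_2 = slices_read_by_loss(variance_stack)  # variance models 0 and 1

# Single stacked return: the loss sees every sub-model.
stacked = np.stack([mean_stack, variance_stack])
full_mean, full_var = slices_read_by_loss(stacked)

print(as_mean_1[0, 0], as_var_1[0, 0], as_mean_2[0, 0], as_var_2[0, 0])
print(stacked.shape, full_mean.shape, full_var.shape)
```

With the tuple return, each output only ever contributes its first two sub-models to the loss, so the remaining six receive no gradient. With the single stacked tensor, `y_pred[0]` and `y_pred[1]` are the full mean and variance stacks.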

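    【Discussion】:

    • Thanks for your answer. The input should have 779 columns/features and the output 1 column/feature. Your solution works, but now I get the following error: Error when checking model target: the list of Numpy arrays that you are passing to your model is not the size the model expected. Expected to see 2 array(s), for inputs ['output_1', 'output_2'] but instead got the following list of 1 arrays:...
    • Can you give me the exact values of x_train.shape and y_train.shape?
    • The first dimension doesn't matter, so it's [?, 779] and [?, 1], where ? is anything you want.
    • Sure, thanks. I managed to get it working (taking 10000 data points, with several batches); see my edit above.
    • How do you know the rest of the MLPs won't learn anything? Why is that? I want them to learn.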