【问题标题】:tensorflow warning - Found untraced functions such as lstm_cell_6_layer_call_and_return_conditional_lossestensorflow 警告 - 发现未跟踪的函数,例如 lstm_cell_6_layer_call_and_return_conditional_losses
【发布时间】:2021-01-13 07:49:22
【问题描述】:

我使用的是 tensorflow2.4,是 tensorflow 的新手

这是代码

model = Sequential()
model.add(LSTM(32, input_shape=(X_train.shape[1:])))
model.add(Dropout(0.2))
model.add(Dense(1, activation='linear'))

model.compile(optimizer='rmsprop', loss='mean_absolute_error', metrics='mae')
model.summary()

save_weights_at = 'basic_lstm_model'
save_best = ModelCheckpoint(save_weights_at, monitor='val_loss', verbose=0,
                        save_best_only=True, save_weights_only=False, mode='min',
                        period=1)
history = model.fit(x=X_train, y=y_train, batch_size=16, epochs=20,
         verbose=1, callbacks=[save_best], validation_data=(X_val, y_val),
         shuffle=True)

在某些时期,收到了以下警告:

你知道我为什么会收到这个警告吗?

【问题讨论】:

  • 有同样的问题,也使用 LSTM 层。你解决了吗?
  • 还没有~在我的情况下并没有影响输出......
  • @Cherry Wu,我尝试在没有ModelCheckpoint 的情况下执行,但没有显示任何警告。看来这是TF 2.4 中的一个未解决问题,可以跟踪Saving model in TF 2.4。谢谢!
  • 感谢您让我知道@TFer2!是的,就我而言,必须使用ModelCheckpoint

标签: tensorflow warnings lstm


【解决方案1】:

我认为可以安全地忽略此警告,因为即使在 tensorflow 给出的 tutorial 中也可以找到相同的警告。在保存图形 NN 等自定义模型时,我经常看到此警告。只要您不想访问那些不可调用的函数,您应该可以继续使用。

但是,如果您对这一大段文本感到恼火,您可以通过在代码顶部添加以下内容来抑制此警告。

import absl.logging
absl.logging.set_verbosity(absl.logging.ERROR)

【讨论】:

  • 如果您确实想访问这些 fns 怎么办?
【解决方案2】:

以 H5 格式保存模型似乎对我有用。

model.save(filepath, save_format="h5")

这里是如何将 H5 与模型检查点一起使用(我没有对此进行过广泛的测试,请注意!)

from tensorflow.keras.callbacks import ModelCheckpoint

class ModelCheckpointH5(ModelCheckpoint):
    # There is a bug saving models in TF 2.4
    # https://github.com/tensorflow/tensorflow/issues/47479
    # This forces the h5 format for saving
    def __init__(self,
               filepath,
               monitor='val_loss',
               verbose=0,
               save_best_only=False,
               save_weights_only=False,
               mode='auto',
               save_freq='epoch',
               options=None,
               **kwargs):
        super(ModelCheckpointH5, self).__init__(filepath,
               monitor='val_loss',
               verbose=0,
               save_best_only=False,
               save_weights_only=False,
               mode='auto',
               save_freq='epoch',
               options=None,
               **kwargs)
    def _save_model(self, epoch, logs):
        from tensorflow.python.keras.utils import tf_utils
   
        logs = logs or {}

        if isinstance(self.save_freq,
                      int) or self.epochs_since_last_save >= self.period:
          # Block only when saving interval is reached.
          logs = tf_utils.to_numpy_or_python_type(logs)
          self.epochs_since_last_save = 0
          filepath = self._get_file_path(epoch, logs)

          try:
            if self.save_best_only:
              current = logs.get(self.monitor)
              if current is None:
                logging.warning('Can save best model only with %s available, '
                                'skipping.', self.monitor)
              else:
                if self.monitor_op(current, self.best):
                  if self.verbose > 0:
                    print('\nEpoch %05d: %s improved from %0.5f to %0.5f,'
                          ' saving model to %s' % (epoch + 1, self.monitor,
                                                   self.best, current, filepath))
                  self.best = current
                  if self.save_weights_only:
                    self.model.save_weights(
                        filepath, overwrite=True, options=self._options)
                  else:
                    self.model.save(filepath, overwrite=True, options=self._options,save_format="h5") # NK edited here
                else:
                  if self.verbose > 0:
                    print('\nEpoch %05d: %s did not improve from %0.5f' %
                          (epoch + 1, self.monitor, self.best))
            else:
              if self.verbose > 0:
                print('\nEpoch %05d: saving model to %s' % (epoch + 1, filepath))
              if self.save_weights_only:
                self.model.save_weights(
                    filepath, overwrite=True, options=self._options)
              else:
                self.model.save(filepath, overwrite=True, options=self._options,save_format="h5") # NK edited here

            self._maybe_remove_file()
          except IOError as e:
            # `e.errno` appears to be `None` so checking the content of `e.args[0]`.
            if 'is a directory' in six.ensure_str(e.args[0]).lower():
              raise IOError('Please specify a non-directory filepath for '
                            'ModelCheckpoint. Filepath used is an existing '
                            'directory: {}'.format(filepath))
            # Re-throw the error for any other causes.
            raise 

【讨论】:

    【解决方案3】:

    尝试将扩展名附加到文件中。

    save_weights_at = 'basic_lstm_model'
    

    为:

    save_weights_at = 'basic_lstm_model.h5'
    

    【讨论】:

      猜你喜欢
      • 1970-01-01
      • 1970-01-01
      • 1970-01-01
      • 2022-12-11
      • 2011-09-18
      • 2021-03-21
      • 1970-01-01
      • 2021-05-31
      • 2021-06-21
      相关资源
      最近更新 更多