【Question title】: TF 2.2: How to compute a custom metric when using MirroredStrategy
【Posted】: 2021-06-14 00:18:36
【Question】:

The tf.keras.Model I am training has the following primary performance metrics:

  • Escape rate: (#samples with predicted label 0 AND true label 1) / (#samples with true label 1)
  • False call rate: (#samples with predicted label 1 AND true label 0) / (#samples with true label 0)
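The two definitions above can be checked for a fixed decision threshold with a short NumPy sketch. The helper name, the toy data and the rule "predict label 1 when the label-1 probability is strictly greater than the threshold" are assumptions made for illustration.

```python
import numpy as np


def escape_and_false_call_rate(y_true, y_prob, threshold):
    """Hypothetical helper: escape rate and false call rate at a fixed threshold.

    Assumption: label 1 is predicted when the label-1 probability is > threshold.
    """
    y_true = np.asarray(y_true)
    y_pred = (np.asarray(y_prob) > threshold).astype(int)

    # escapes: true label 1 (NOK) but predicted label 0
    escapes = np.count_nonzero((y_pred == 0) & (y_true == 1))
    # false calls: true label 0 (OK) but predicted label 1
    false_calls = np.count_nonzero((y_pred == 1) & (y_true == 0))

    escape_rate = escapes / np.count_nonzero(y_true == 1)
    false_call_rate = false_calls / np.count_nonzero(y_true == 0)
    return float(escape_rate), float(false_call_rate)


y_true = np.array([0, 0, 0, 0, 1, 1, 1, 1])
y_prob = np.array([0.1, 0.2, 0.6, 0.7, 0.3, 0.8, 0.9, 0.95])
print(escape_and_false_call_rate(y_true, y_prob, 0.5))  # (0.25, 0.5)
```

At threshold 0.5, one of the four NOK samples (0.3) escapes and two of the four OK samples (0.6, 0.7) are false calls.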

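The target escape rate is predefined, which means the decision threshold has to be set accordingly. To compute the resulting false call rate, I want to implement a custom metric along the lines of the following code: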

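import numpy as np

# separate predicted probabilities (NumPy array) by their true label
all_ok_probabilities = all_probabilities[true_label == 0]
all_nok_probabilities = all_probabilities[true_label == 1]

# sort NOK samples in ascending order
sorted_nok_probabilities = np.sort(all_nok_probabilities)

# determine decision threshold: the NOK samples at or below it count as escapes
# (the escape rate is relative to the number of NOK samples, not of all samples)
num_nok_samples = len(sorted_nok_probabilities)
threshold_idx = max(int(round(target_escape_rate * num_nok_samples)) - 1, 0)
threshold = sorted_nok_probabilities[threshold_idx]

# calculate false call rate (label 1 is predicted iff probability > threshold)
false_calls = np.count_nonzero(all_ok_probabilities > threshold)
false_call_rate = false_calls / len(all_ok_probabilities)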

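My problem is that under MirroredStrategy, tf.keras automatically distributes the metric computation across all replicas. On each update, every replica receives (batch_size / n_replicas) samples, and the results are summed at the end. However, my algorithm only works correctly when all labels and predictions are combined. (The final summation could perhaps be compensated for by dividing by the number of replicas.)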
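The problem can be reproduced without TensorFlow. The NumPy sketch below implements the threshold logic above as a function. It then evaluates that function once on a whole batch and once on two halves, which stand in for a hypothetical two-replica split, and averages the per-half results. The function name, the data and the split are assumptions for illustration only. Because each half picks its own threshold, the averaged value has nothing to do with the global one.

```python
import numpy as np


def false_call_rate_at_escape_rate(y_true, y_prob, target_escape_rate):
    """Threshold-based false call rate, following the question's logic."""
    nok = np.sort(y_prob[y_true == 1])
    ok = y_prob[y_true == 0]
    # index of the highest NOK probability that is still allowed to escape
    idx = max(int(round(target_escape_rate * len(nok))) - 1, 0)
    threshold = nok[idx]
    return float(np.count_nonzero(ok > threshold) / len(ok))


y_true = np.array([0, 0, 1, 1, 0, 0, 1, 1])
y_prob = np.array([0.1, 0.2, 0.3, 0.4, 0.6, 0.7, 0.8, 0.9])

# computed on the whole batch at once
global_fcr = false_call_rate_at_escape_rate(y_true, y_prob, 0.5)

# hypothetical split: replica 0 gets the first half, replica 1 the second half
shards = (slice(0, 4), slice(4, 8))
per_replica = [false_call_rate_at_escape_rate(y_true[s], y_prob[s], 0.5) for s in shards]
averaged_fcr = sum(per_replica) / len(per_replica)

print(global_fcr, averaged_fcr)  # 0.5 0.0
```

Neither summing nor dividing by the number of replicas recovers 0.5 here, because the sorting step needs all NOK probabilities at once.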

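My idea was to concatenate all y_true and y_pred values into sequences in my metric's update_state() method and run the evaluation in result(). However, even the first step seems impossible: tf.Variable only provides suitable aggregation methods for numeric scalars, not for sequences. tf.VariableAggregation.ONLY_FIRST_REPLICA makes me lose all data from the 2nd to the nth replica, SUM silently hangs the fit() call, and MEAN makes no sense for my application (and may hang as well).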
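One alternative is sketched below in NumPy. It is a different technique from the one the question attempts. Instead of concatenating sequences of varying length, it accumulates a fixed-size histogram of the predicted label-1 probability for each class. Histogram counts are additive, so summing per-replica histograms reproduces the global histogram exactly. Such fixed-shape count variables are presumably a better fit for SUM aggregation than growing sequences. Threshold-bucketed built-ins such as tf.keras.metrics.AUC (see its num_thresholds parameter) follow a similar accumulate-counts idea.

Several choices here are assumptions: the bin count, the helper names, the two-replica split, and the rule that places the threshold at a bin edge. The threshold is therefore only resolved to bin resolution.

```python
import numpy as np

NUM_BINS = 10  # assumption: resolution of the approximated threshold


def class_histograms(y_true, y_prob, num_bins=NUM_BINS):
    """Hypothetical helper: fixed-shape per-class counts of label-1 probabilities."""
    bins = np.linspace(0.0, 1.0, num_bins + 1)
    ok_counts, _ = np.histogram(y_prob[y_true == 0], bins=bins)
    nok_counts, _ = np.histogram(y_prob[y_true == 1], bins=bins)
    return ok_counts, nok_counts


def false_call_rate_from_histograms(ok_counts, nok_counts, target_escape_rate):
    """Approximate FCR with the threshold placed at the upper edge of bin k."""
    nok_cdf = np.cumsum(nok_counts) / nok_counts.sum()
    # first bin whose cumulative NOK fraction reaches the target escape rate
    k = int(np.searchsorted(nok_cdf, target_escape_rate))
    # OK samples in bins above bin k are predicted as label 1, i.e. false calls
    return float(ok_counts[k + 1:].sum() / ok_counts.sum())


y_true = np.array([0, 0, 1, 1, 0, 0, 1, 1])
y_prob = np.array([0.15, 0.25, 0.35, 0.45, 0.65, 0.75, 0.85, 0.95])

# per-replica histograms (hypothetical two-replica split), then summed
shards = (slice(0, 4), slice(4, 8))
partials = [class_histograms(y_true[s], y_prob[s]) for s in shards]
ok_sum = sum(p[0] for p in partials)
nok_sum = sum(p[1] for p in partials)

ok_all, nok_all = class_histograms(y_true, y_prob)
print(np.array_equal(ok_sum, ok_all), np.array_equal(nok_sum, nok_all))  # True True
print(false_call_rate_from_histograms(ok_sum, nok_sum, 0.5))  # 0.5
```

This trades exactness for mergeability: with more bins, the approximated threshold gets closer to the sorted-array threshold.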

I have already tried instantiating the metric outside the MirroredStrategy scope, but tf.keras.Model.compile() does not accept it.

Any hints/ideas?

P.S.: If you need a minimal code example, let me know; I'm working on one. :)

【Question discussion】:

    Tags: python tensorflow keras


    【Answer 1】:

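    I solved my own problem by implementing the computation as a callback instead of a metric. I run fit() without 'validation_data' and compute all validation set metrics in the callback instead. This avoids running predictions on the validation set twice.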

    To inject the resulting metric values into the training process, I used the rather hacky approach from "Access variables of caller function in Python".
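The caller-frame trick can be shown in plain Python. The sketch below uses hypothetical function names and values. A helper walks inspect.stack(), finds the nearest frame that has a local dict named `logs`, and mutates that dict in place. Mutating the object is what makes the change visible to the caller; rebinding the name would not. In the answer's callback, the first frame returned by inspect.stack() is on_epoch_end itself, whose `logs` parameter is already a local there, so the loop stops at that frame.

```python
import inspect


def report_metrics():
    """Hypothetical helper: update the nearest caller's local `logs` dict."""
    for frame_info in inspect.stack():
        candidate = frame_info.frame.f_locals.get("logs")
        if isinstance(candidate, dict):
            # mutate the dict object itself so the caller sees the new entries
            candidate.update({"val_false_call_rate": 0.25})
            return True
    return False


def on_epoch_end():
    logs = {"loss": 1.0}
    report_metrics()
    return logs


print(on_epoch_end())  # {'loss': 1.0, 'val_false_call_rate': 0.25}
```

This relies on CPython frame introspection rather than on any documented Keras mechanism, which is why the answer calls it hacky.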

    import inspect

    import numpy as np
    import tensorflow as tf


    class ValidationCallback(tf.keras.callbacks.Callback):
        """helper class to calculate validation set metrics after each epoch"""
    
        def __init__(self, val_data, escape_rate, **kwargs):
            # call parent constructor
            super(ValidationCallback, self).__init__(**kwargs)
    
            # save parameters
            self.val_data = val_data        # dataset yielding (features, labels) batches
            self.escape_rate = escape_rate  # predefined target escape rate
    
        def on_epoch_end(self, epoch, logs=None):
            # initialize empty arrays (two output classes)
            y_pred = np.empty((0, 2))
            y_true = np.empty(0)
    
            # iterate over validation set batches
            for batch in self.val_data:
                # concat all batch labels & predictions
                # need to do predict()[0] due to several model outputs
                y_pred = np.concatenate([y_pred, self.model.predict(batch[0])[0]], axis=0)
                y_true = np.concatenate([y_true, np.asarray(batch[1])], axis=0)
    
            # calculate classical accuracy for threshold 0.5
            acc = ((y_pred[:, 1] >= 0.5) == y_true).sum() / y_true.shape[0]
    
            # calculate mean cross-entropy loss over all validation samples
            # (summed loss divided by the sample count, not by the batch size)
            cce = tf.keras.losses.SparseCategoricalCrossentropy(reduction=tf.keras.losses.Reduction.SUM)
            loss = cce(y_true, y_pred).numpy() / y_true.shape[0]
    
            # calculate false call rate: label 1 is predicted for probabilities >= threshold,
            # so the idx lowest-scored NOK samples escape (idx clamped to a valid index)
            y_pred_nok = np.sort(y_pred[y_true == 1, 1])
            idx = min(int(np.round(self.escape_rate * y_pred_nok.shape[0])), y_pred_nok.shape[0] - 1)
            threshold = y_pred_nok[idx]
            false_calls = np.count_nonzero((y_true == 0) & (y_pred[:, 1] >= threshold))
            fcr = false_calls / np.count_nonzero(y_true == 0)
    
            # add metrics to 'logs' dict of our caller (tf.keras.callbacks.CallbackList.on_epoch_end()),
            # so that they become available to following callbacks
            for f in inspect.stack():
                if 'logs' in f[0].f_locals:
                    f[0].f_locals['logs'].update({'val_accuracy': acc,
                                                  'val_loss': loss,
                                                  'val_false_call_rate': fcr})
                    return
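The question's code and the answer's callback pick the threshold slightly differently:

- Question: threshold `sorted_nok[m - 1]`, predict label 1 when `p > threshold`.
- Answer: threshold `sorted_nok[m]`, predict label 1 when `p >= threshold`.

In both cases m = round(escape_rate * n_nok). With distinct NOK probabilities, both rules let exactly m NOK samples escape. However, the answer's threshold sits at the upper end of the admissible range, so it can report a lower false call rate. The toy data and the helper below are hypothetical and only illustrate this difference.

```python
import numpy as np


def rates(y_true, y_prob, threshold, inclusive):
    """Hypothetical helper: (escape rate, false call rate).

    Label 1 is predicted for p > threshold (inclusive=False, question)
    or p >= threshold (inclusive=True, answer).
    """
    pred = y_prob >= threshold if inclusive else y_prob > threshold
    escape = np.count_nonzero(~pred & (y_true == 1)) / np.count_nonzero(y_true == 1)
    fcr = np.count_nonzero(pred & (y_true == 0)) / np.count_nonzero(y_true == 0)
    return float(escape), float(fcr)


y_true = np.array([1, 1, 0])
y_prob = np.array([0.3, 0.8, 0.5])
nok_sorted = np.sort(y_prob[y_true == 1])
m = int(round(0.5 * len(nok_sorted)))  # target escape rate 0.5 -> 1 escape

print(rates(y_true, y_prob, nok_sorted[m - 1], inclusive=False))  # (0.5, 1.0)
print(rates(y_true, y_prob, nok_sorted[m], inclusive=True))       # (0.5, 0.0)
```

Both conventions meet the target escape rate. Only the OK sample at 0.5, which lies between the two candidate thresholds, is counted differently.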
    
