【问题标题】:Batch Norm - Extract Running Mean & Running Variance in TensorFlowBatch Norm - 在 TensorFlow 中提取运行均值和运行方差
【发布时间】:2018-08-24 22:18:00
【问题描述】:

我正在尝试查看通过 GCMLE(saved_model.pbassets/*variables/*)导出的经过训练的 tensorflow 模型的运行均值和运行方差。这些值保存在图中的什么位置?我可以从tf.GraphKeys.TRAINABLE_VARIABLES 访问 gamma/beta 值,但我无法在任何tf.GraphKeys.MODEL_VARIABLES 中找到运行均值和运行方差。运行均值和运行方差是否存储在其他位置?

我知道在测试时(即Modes.EVAL),运行均值和运行方差用于对传入数据进行归一化,然后使用 gamma 和 beta 对归一化数据进行缩放和移位。我试图查看推理时需要的所有变量,但找不到运行均值和运行方差。这些是否仅在测试时使用而不是在推理时使用(Modes.PREDICT)?如果是这样,这就解释了为什么我在导出的模型中找不到它们,但我希望它们在那里。

基于tf.GraphKeys,我尝试过其他类似tf.GraphKeys.MOVING_AVERAGE_VARIABLES 的东西,但它们也是空的。我还在batch_normalization文档中看到了这一行“注意:训练时,需要更新moving_mean和moving_variance。默认情况下,更新操作放在tf.GraphKeys.UPDATE_OPS中,因此需要将它们作为依赖项添加到train_op。”所以我然后尝试从我保存的模型中查看tf.GraphKeys.UPDATE_OPS,它们包含一个分配操作batch_normalization/AssignMovingAvg:0,但仍然不清楚我将从哪里获得值。

【问题讨论】:

    标签: python tensorflow google-cloud-ml batch-normalization


    【解决方案1】:

    移动均值和移动方差似乎存储在tf.GraphKeys.GLOBAL_VARIABLES 中,看起来MODEL_VARIABLES 中没有显示的原因是因为您需要使用tf.contrib.framework.local_variable

    【讨论】:

      【解决方案2】:

      除了#reese0106 的回答之外,
      如果您想取出 BatchNorm 的moving_mean、moving_variance,
      您可以使用如下名称对它们进行索引。

      vars = tf.global_variables() # shows every variable being used.
      vars_moving_mean_variance = []
      for var in vars:
          if ("moving_mean" in var.name) or ("moving_variance" in var.name):
              vars_moving_mean_variance.append(var)
      
      print(vars_moving_mean_variance)
      


      p.s.感谢您的问题和答案。我也解决了自己的问题。

      【讨论】:

        猜你喜欢
        • 1970-01-01
        • 1970-01-01
        • 2016-02-08
        • 1970-01-01
        • 1970-01-01
        • 1970-01-01
        • 2016-09-27
        • 1970-01-01
        • 1970-01-01
        相关资源
        最近更新 更多