【问题标题】:How can I get the number of trainable parameters of a model in Keras?如何在 Keras 中获取模型的可训练参数数量?
【发布时间】:2017-12-16 05:21:25
【问题描述】:

我在所有层中设置trainable=False,通过Model API 实现,但我想验证这是否有效。 model.count_params() 返回参数的总数,但是除了查看model.summary() 的最后几行之外,有什么方法可以获得可训练参数的总数?

【问题讨论】:

  • AFAIK,没有比model.summary()更好的方法了

标签: python keras


【解决方案1】:
from keras import backend as K

trainable_count = int(
    np.sum([K.count_params(p) for p in set(model.trainable_weights)]))
non_trainable_count = int(
    np.sum([K.count_params(p) for p in set(model.non_trainable_weights)]))

print('Total params: {:,}'.format(trainable_count + non_trainable_count))
print('Trainable params: {:,}'.format(trainable_count))
print('Non-trainable params: {:,}'.format(non_trainable_count))

上面的sn-p可以在layer_utils.print_summary()定义的末尾发现,summary()正在调用它。


编辑:更新版本的 Keras 有一个辅助函数 count_params() 用于此目的:

from keras.utils.layer_utils import count_params

trainable_count = count_params(model.trainable_weights)
non_trainable_count = count_params(model.non_trainable_weights)

【讨论】:

  • 请注意,批标准化参数不计入trainable,即使它们是在训练期间发生变化的变量。
  • 如果您使用的是 TensorFlow 2,count_params 函数现在位于 from tensorflow.python.keras.utils.layer_utils import count_params
  • 或者只是 sum(count_params(layer) for layer in model.trainable_weights) 以获得更多可读性...我不清楚对 set 的调用完成了什么。
【解决方案2】:

对于 TensorFlow 2.0

import tensorflow.keras.backend as K

trainable_count = np.sum([K.count_params(w) for w in model.trainable_weights])
non_trainable_count = np.sum([K.count_params(w) for w in model.non_trainable_weights])

print('Total params: {:,}'.format(trainable_count + non_trainable_count))
print('Trainable params: {:,}'.format(trainable_count))
print('Non-trainable params: {:,}'.format(non_trainable_count))

【讨论】:

    【解决方案3】:

    对于 tensorflow.keras,这对我有用。它来自 layer_utils.py 中函数 print_layer_summary_with_connections() 的 tensorflow github 代码

    import numpy as np
    from tensorflow.python.util import object_identity
    
    def count_params(weights):
        return int(sum(np.prod(p.shape.as_list())
          for p in object_identity.ObjectIdentitySet(weights)))
    
    if hasattr(model, '_collected_trainable_weights'):
        trainable_count = count_params(model._collected_trainable_weights)
    else:
        trainable_count = count_params(model.trainable_weights)
    
    print (trainable_count)
    

    【讨论】:

      【解决方案4】:

      另一种计算可训练参数的方法是:

      model.count_params()
      

      【讨论】:

      • 不,这是参数的总数,可训练的和不可训练的。
      猜你喜欢
      • 2019-05-28
      • 1970-01-01
      • 1970-01-01
      • 1970-01-01
      • 2018-11-29
      • 1970-01-01
      • 2019-02-01
      • 2021-03-19
      • 2021-09-27
      相关资源
      最近更新 更多