【发布时间】:2017-12-16 05:21:25
【问题描述】:
我在所有层中设置trainable=False,通过Model API 实现,但我想验证这是否有效。 model.count_params() 返回参数的总数,但是除了查看model.summary() 的最后几行之外,有什么方法可以获得可训练参数的总数?
【问题讨论】:
-
AFAIK,没有比
model.summary()更好的方法了
我在所有层中设置trainable=False,通过Model API 实现,但我想验证这是否有效。 model.count_params() 返回参数的总数,但是除了查看model.summary() 的最后几行之外,有什么方法可以获得可训练参数的总数?
【问题讨论】:
model.summary()更好的方法了
from keras import backend as K
trainable_count = int(
np.sum([K.count_params(p) for p in set(model.trainable_weights)]))
non_trainable_count = int(
np.sum([K.count_params(p) for p in set(model.non_trainable_weights)]))
print('Total params: {:,}'.format(trainable_count + non_trainable_count))
print('Trainable params: {:,}'.format(trainable_count))
print('Non-trainable params: {:,}'.format(non_trainable_count))
上面的sn-p可以在layer_utils.print_summary()定义的末尾发现,summary()正在调用它。
编辑:更新版本的 Keras 有一个辅助函数 count_params() 用于此目的:
from keras.utils.layer_utils import count_params
trainable_count = count_params(model.trainable_weights)
non_trainable_count = count_params(model.non_trainable_weights)
【讨论】:
trainable,即使它们是在训练期间发生变化的变量。
count_params 函数现在位于 from tensorflow.python.keras.utils.layer_utils import count_params
sum(count_params(layer) for layer in model.trainable_weights) 以获得更多可读性...我不清楚对 set 的调用完成了什么。
对于 TensorFlow 2.0:
import tensorflow.keras.backend as K
trainable_count = np.sum([K.count_params(w) for w in model.trainable_weights])
non_trainable_count = np.sum([K.count_params(w) for w in model.non_trainable_weights])
print('Total params: {:,}'.format(trainable_count + non_trainable_count))
print('Trainable params: {:,}'.format(trainable_count))
print('Non-trainable params: {:,}'.format(non_trainable_count))
【讨论】:
对于 tensorflow.keras,这对我有用。它来自 layer_utils.py 中函数 print_layer_summary_with_connections() 的 tensorflow github 代码
import numpy as np
from tensorflow.python.util import object_identity
def count_params(weights):
return int(sum(np.prod(p.shape.as_list())
for p in object_identity.ObjectIdentitySet(weights)))
if hasattr(model, '_collected_trainable_weights'):
trainable_count = count_params(model._collected_trainable_weights)
else:
trainable_count = count_params(model.trainable_weights)
print (trainable_count)
【讨论】:
另一种计算可训练参数的方法是:
model.count_params()
【讨论】: