【发布时间】:2019-04-08 12:07:52
【问题描述】:
我使用 tf 估计器框架创建了一个 CNN 分类器模型。但是,我无法访问模型中定义的变量。 tf.trainable_variables() 总是返回 0。 如何使用 tf 估计器访问变量?特别是,我怎样才能得到参数总数的计数(将所有变量的维度相加。
谢谢, 哈罗德
【问题讨论】:
标签: python tensorflow
我使用 tf 估计器框架创建了一个 CNN 分类器模型。但是,我无法访问模型中定义的变量。 tf.trainable_variables() 总是返回 0。 如何使用 tf 估计器访问变量?特别是,我怎样才能得到参数总数的计数(将所有变量的维度相加。
谢谢, 哈罗德
【问题讨论】:
标签: python tensorflow
如上所述,您应该使用:
tf.estimator.Estimator.get_variable_names() 以获取所有估算器变量tf.estimator.Estimator.get_variable_value(name) 以获取变量值获得变量后,您可以使用以下方法之一来获取估计器参数的总数。
将每个变量的形状暗淡乘以numpy.prod,然后求和:
sum([np.prod(est.get_variable_value(var).shape) for var in est.get_variable_names()])
或者将变量的大小与numpy.ndarray.size相加,然后相加:
sum([est.get_variable_value(var).size for var in est.get_variable_names()])
【讨论】:
你可以使用get_variable_names()来获取所有的变量名,
并使用get_variable_value(name)按名称获取变量值。
请使用您的代码:
estimator = tf.estimator.Estimator(...)
params = estimator.get_variable_names()
for p in params:
print(p, estimator.get_variable_value(p).shape)
更多信息是https://www.tensorflow.org/api_docs/python/tf/estimator/Estimator#get_variable_names 和 https://www.tensorflow.org/api_docs/python/tf/estimator/Estimator#get_variable_value.
注意:必须先创建图表,然后才能获取变量。
【讨论】: