【问题标题】:How get trainable variable count when using Tf Estimator?使用 Tf Estimator 时如何获得可训练的变量计数?
【发布时间】:2019-04-08 12:07:52
【问题描述】:

我使用 tf 估计器框架创建了一个 CNN 分类器模型。但是,我无法访问模型中定义的变量。 tf.trainable_variables() 总是返回 0。 如何使用 tf 估计器访问变量?特别是,我怎样才能得到参数总数的计数(将所有变量的维度相加。

谢谢, 哈罗德

【问题讨论】:

    标签: python tensorflow


    【解决方案1】:

    如上所述,您应该使用:

    获得变量后,您可以使用以下方法之一来获取估计器参数的总数。

    • 将每个变量的形状暗淡乘以numpy.prod,然后求和:

      sum([np.prod(est.get_variable_value(var).shape) for var in est.get_variable_names()])

    • 或者将变量的大小与numpy.ndarray.size相加,然后相加:

      sum([est.get_variable_value(var).size for var in est.get_variable_names()])

    【讨论】:

      【解决方案2】:

      你可以使用get_variable_names()来获取所有的变量名, 并使用get_variable_value(name)按名称获取变量值。

      请使用您的代码:

      estimator = tf.estimator.Estimator(...)
      params = estimator.get_variable_names()
      for p in params:
          print(p, estimator.get_variable_value(p).shape)
      

      更多信息是https://www.tensorflow.org/api_docs/python/tf/estimator/Estimator#get_variable_nameshttps://www.tensorflow.org/api_docs/python/tf/estimator/Estimator#get_variable_value.

      注意:必须先创建图表,然后才能获取变量。

      【讨论】:

        猜你喜欢
        • 1970-01-01
        • 2018-09-12
        • 1970-01-01
        • 1970-01-01
        • 2016-09-16
        • 2020-09-11
        • 2018-06-30
        • 1970-01-01
        • 1970-01-01
        相关资源
        最近更新 更多