【问题标题】:RNN Cell not present in tf.get_collectiontf.get_collection 中不存在 RNN 单元
【发布时间】:2016-12-20 21:20:06
【问题描述】:

使用tf.get_collection() 时,RNN 单元不显示。我错过了什么?

import tensorflow as tf
print(tf.__version__)

rnn_cell = tf.nn.rnn_cell.LSTMCell(16)
print(tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES))

other_var = tf.Variable(0)
print(tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES))

打印出来

0.12.0
[]
[<tensorflow.python.ops.variables.Variable object at 0x0000027961250B70>]

Windows 10、Python 3.5

【问题讨论】:

  • 你需要__call__ LSTMCell 来创建它的变量

标签: python tensorflow recurrent-neural-network


【解决方案1】:

您没有在LSTMCell 上运行__call__,这就是您看不到变量的原因。试试这个(我假设batch_size=10rnn_size=16

import tensorflow as tf
print(tf.__version__)

rnn_cell = tf.nn.rnn_cell.LSTMCell(16)
a = tf.placeholder(tf.float32, [10, 16])
zero = rnn_cell.zero_state(10,tf.float32)
# The variables are created in the following __call__
b = rnn_cell(a, zero)
print(tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES))

other_var = tf.Variable(0)
print(tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES))

【讨论】:

    猜你喜欢
    • 1970-01-01
    • 1970-01-01
    • 2016-12-05
    • 1970-01-01
    • 2019-07-21
    • 2019-01-13
    • 1970-01-01
    • 2022-11-30
    • 1970-01-01
    相关资源
    最近更新 更多