【问题标题】:Variable scopes in TensorflowTensorflow 中的变量作用域
【发布时间】:2016-06-07 13:03:46
【问题描述】:

我在有效使用变量范围时遇到问题。我想为简单循环网络的权重、偏差和内部状态定义一些变量。在定义默认图形后,我调用了一次get_saver()。然后我使用tf.scan 迭代一批样本。

import tensorflow as tf
import math
import numpy as np

INPUTS = 10
HIDDEN_1 = 2
BATCH_SIZE = 3

def batch_vm2(m, x):
  [input_size, output_size] = m.get_shape().as_list()

  input_shape = tf.shape(x)
  batch_rank = input_shape.get_shape()[0].value - 1
  batch_shape = input_shape[:batch_rank]
  output_shape = tf.concat(0, [batch_shape, [output_size]])

  x = tf.reshape(x, [-1, input_size])
  y = tf.matmul(x, m)

  y = tf.reshape(y, output_shape)

  return y

def get_saver():
    with tf.variable_scope('h1') as scope:
        weights = tf.get_variable('W', shape=[INPUTS, HIDDEN_1], initializer=tf.truncated_normal_initializer(stddev=1.0 / math.sqrt(float(INPUTS))))
        biases = tf.get_variable('bias', shape=[HIDDEN_1], initializer=tf.constant_initializer(0.0))
        state = tf.get_variable('state', shape=[HIDDEN_1], initializer=tf.constant_initializer(0.0), trainable=False)
        saver = tf.train.Saver([weights, biases, state])
    return saver


def load(sess, saver, checkpoint_dir = None):

        print("loading a session")
        ckpt = tf.train.get_checkpoint_state(checkpoint_dir)
        if ckpt and ckpt.model_checkpoint_path:
            saver.restore(sess, ckpt.model_checkpoint_path)
        else:
            raise Exception("no checkpoint found")
        return

def iterate_state(prev_state_tuple, input):
    with tf.variable_scope('h1') as scope:
        scope.reuse_variables()
        weights = tf.get_variable('W', shape=[INPUTS, HIDDEN_1], initializer=tf.truncated_normal_initializer(stddev=1.0 / math.sqrt(float(INPUTS))))
        biases = tf.get_variable('bias', shape=[HIDDEN_1], initializer=tf.constant_initializer(0.0))
        state = tf.get_variable('state', shape=[HIDDEN_1], initializer=tf.constant_initializer(0.0), trainable=False)
        print("input: ",input.get_shape())
        matmuladd = batch_vm2(weights, input) + biases
        matmulpri = tf.Print(matmuladd,[matmuladd], message=" malmul -> ")
        #matmulvec = tf.reshape(matmuladd, [HIDDEN_1])
        #state = tf.get_variable('state', shape=[HIDDEN_1], initializer=tf.constant_initializer(0.0))
        print("prev state: ",prev_state_tuple.get_shape())
        unpacked_state, unpacked_out = tf.split(0,2,prev_state_tuple)
        prev_state = unpacked_state
        state = state.assign( 4.2*(0.9* prev_state + 0.1*matmuladd) )
        #output = tf.nn.relu(state)
        output = tf.nn.tanh(state)
        state = tf.Print(state, [state], message=" state -> ")
        output = tf.Print(output, [output], message=" output -> ")
        #output = matmulpri
        print(" state: ", state.get_shape())
        print(" output: ", output.get_shape())
        concat_result = tf.concat(0,[state, output])
        print (" concat return: ", concat_result.get_shape())
        return concat_result

def data_iter():
    while True:
        idxs = np.random.rand(BATCH_SIZE, INPUTS)
        yield idxs

with tf.Graph().as_default():
    inputs = tf.placeholder(tf.float32, shape=(BATCH_SIZE, INPUTS))

    saver = get_saver()
    initial_state = tf.zeros([HIDDEN_1],
                             name='initial_state')
    initial_out = tf.zeros([HIDDEN_1],
                             name='initial_out')
    #concat_tensor = tf.concat(0,[initial_state, initial_out])
    concat_tensor = tf.concat(0,[initial_state, initial_out])
    print(" init state: ",initial_state.get_shape())
    print(" init out: ",initial_out.get_shape())
    print(" concat: ",concat_tensor.get_shape())
    scanout = tf.scan(iterate_state, inputs, initializer=concat_tensor, name='state_scan')
    print ("scanout shape: ", scanout.get_shape())
    state, output = tf.split(1,2,scanout, name='split_scan_output')
    print(" end state: ",state.get_shape())
    print(" end out: ",output.get_shape())

    #output,state,diagnostic = create_graph(inputs, state, prev_state)

    sess = tf.Session()
    # Run the Op to initialize the variables.
    sess.run(tf.initialize_all_variables())
    if False:
        load(sess, saver)
    iter_ = data_iter()
    for i in xrange(0, 5):
        print ("iteration: ",i)
        input_data = iter_.next()
        out,st,so = sess.run([output,state,scanout], feed_dict={ inputs: input_data})
        saver.save(sess, 'my-model', global_step=1+i)
        print("input vec: ", input_data)
        print("state vec: ", st)
        print("output vec: ", out)
        print(" end state (runtime): ",st.shape)
        print(" end out (runtime): ",out.shape)
        print(" end scanout (runtime): ",so.shape)

我希望在scan 操作中从get_variable 检索到的变量与在get_saver 调用中定义的变量相同。但是,如果我运行此示例代码,我会得到以下错误输出:

(' init state: ', TensorShape([Dimension(2)]))
(' init out: ', TensorShape([Dimension(2)]))
(' concat: ', TensorShape([Dimension(4)]))
Traceback (most recent call last):
  File "cycles_in_graphs_with_scan.py", line 88, in <module>
    scanout = tf.scan(iterate_state, inputs, initializer=concat_tensor, name='state_scan')
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/ops/functional_ops.py", line 345, in scan
    back_prop=back_prop, swap_memory=swap_memory)
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/ops/control_flow_ops.py", line 1873, in while_loop
    result = context.BuildLoop(cond, body, loop_vars)
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/ops/control_flow_ops.py", line 1749, in BuildLoop
    body_result = body(*vars_for_body_with_tensor_arrays)
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/ops/functional_ops.py", line 339, in compute
    a = fn(a, elems_ta.read(i))
  File "cycles_in_graphs_with_scan.py", line 47, in iterate_state
    weights = tf.get_variable('W', shape=[INPUTS, HIDDEN_1], initializer=tf.truncated_normal_initializer(stddev=1.0 / math.sqrt(float(INPUTS))))
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/ops/variable_scope.py", line 732, in get_variable
    partitioner=partitioner, validate_shape=validate_shape)
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/ops/variable_scope.py", line 596, in get_variable
    partitioner=partitioner, validate_shape=validate_shape)
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/ops/variable_scope.py", line 161, in get_variable
    caching_device=caching_device, validate_shape=validate_shape)
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/ops/variable_scope.py", line 454, in _get_single_variable
    " Did you mean to set reuse=None in VarScope?" % name)
ValueError: Variable state_scan/h1/W does not exist, disallowed. Did you mean to set reuse=None in VarScope?

知道我在这个例子中做错了什么吗?

【问题讨论】:

    标签: python-2.7 tensorflow checkpointing


    【解决方案1】:
    if False:
        load(sess, saver)
    

    这两行导致未初始化的变量。

    【讨论】:

      猜你喜欢
      • 2018-04-11
      • 2010-12-12
      • 1970-01-01
      • 1970-01-01
      • 1970-01-01
      • 2012-02-02
      • 2012-05-02
      • 2011-09-07
      相关资源
      最近更新 更多