【问题标题】:How does the optimizer in tensorflow access the variables created in a separate functiontensorflow中的优化器如何访问单独函数中创建的变量
【发布时间】:2018-03-06 12:36:32
【问题描述】:

代码中感兴趣的行后跟多个哈希 (#) 符号

为了理解目的,我在 tensorflow 中运行了一个简单的线性回归。我使用的代码是:

def generate_dataset():
#y = 2x+e where is the normally distributed error
x_batch = np.linspace(-1,1,101)
y_batch = 2*x_batch +np.random.random(*x_batch.shape)*0.3
return x_batch, y_batch

def linear_regression():   ##################
x = tf.placeholder(tf.float32, shape = (None,), name = 'x')
y = tf.placeholder(tf.float32, shape = (None,), name = 'y')
with tf.variable_scope('lreg') as scope: ################
    w = tf.Variable(np.random.normal()) ##################
    y_pred = tf.multiply(w,x)
    loss = tf.reduce_mean(tf.square(y_pred - y))
return x,y, y_pred, loss
def run():
x_batch, y_batch = generate_dataset()
x, y, y_pred, loss = linear_regression()
optimizer = tf.train.GradientDescentOptimizer(0.2).minimize(loss)

init = tf.global_variables_initializer()
with tf.Session() as session:
    session.run(init) 
    feed_dict = {x: x_batch, y: y_batch}
    for _ in range(30):
        loss_val, _ = session.run([loss, optimizer], feed_dict)
        print('loss:', loss_val.mean())
    y_pred_batch = session.run(y_pred, {x:x_batch})

    print(tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES)) ############
    print(session.run(fetches = [w])) #############
run()      

我似乎无法通过对“w”或“lreg/w”的获取调用来获取变量的值(它实际上是一个操作吗?)“w”,如果我理解正确的话是由于 'w' 是在 linear_regression() 中定义的,并且它不会将其命名空间借给 run()。但是,我可以通过对其变量名称“lreg/vairable:0”的 fetch 调用来访问“w”。优化器工作正常,更新完美应用

优化器如何访问 'w' 并应用更新,如果您能稍微了解一下线性回归()和 run() 之间的操作 'w' 是如何共享的,那就太好了

【问题讨论】:

    标签: python tensorflow namespaces


    【解决方案1】:

    您创建的每个操作和变量都是张量流graph 中的一个节点。当您没有明确创建图表时,例如在您的情况下,则会使用默认图表。

    这一行将 w 添加到默认图形中。

     w = tf.Variable(np.random.normal())
    

    此行访问图形以执行计算

    loss_val, _ = session.run([loss, optimizer], feed_dict)
    

    你可以像这样检查图表

    tf.get_default_graph().as_graph_def()
    

    【讨论】:

    • 非常感谢您的回复。我有一个后续问题:为什么从 run() 运行时 print(session.run(fetches = [w])) 会抛出错误? NameError:未定义名称“w”。想提醒您 print(session.run(fetches = ['lreg/variable:0']) 确实为我获取了 'w' 的值
    • 您必须在脑海中将python变量和tensorflow变量分开。仅仅因为 Tensorflow 图中有一个名为 w 的变量,并不意味着 python 在当前范围内定义了一个名为 w 的变量。
    猜你喜欢
    • 1970-01-01
    • 1970-01-01
    • 2016-11-20
    • 1970-01-01
    • 2017-05-22
    • 1970-01-01
    • 2018-02-28
    • 1970-01-01
    • 1970-01-01
    相关资源
    最近更新 更多