【问题标题】:XOR Neural Network using TensorFlow in Python在 Python 中使用 TensorFlow 的 XOR 神经网络
【发布时间】:2017-11-15 07:36:11
【问题描述】:

我目前正在学习神经网络背后的理论,我想学习如何编写此类模型。因此我开始研究 TensorFlow。

我找到了一个非常有趣的应用程序,我想编写它,但我目前无法让它工作,我真的不知道为什么!

示例来自Deep Learning, Goodfellow et al 2016第171-177页。

import tensorflow as tf

T = 1.
F = 0.
train_in = [
    [T, T],
    [T, F],
    [F, T],
    [F, F],
]
train_out = [
    [F],
    [T],
    [T],
    [F],
]
w1 = tf.Variable(tf.random_normal([2, 2]))
b1 = tf.Variable(tf.zeros([2]))

w2 = tf.Variable(tf.random_normal([2, 1]))
b2 = tf.Variable(tf.zeros([1]))

out1 = tf.nn.relu(tf.matmul(train_in, w1) + b1)
out2 = tf.nn.relu(tf.matmul(out1, w2) + b2)

error = tf.subtract(train_out, out2)
mse = tf.reduce_mean(tf.square(error))

train = tf.train.GradientDescentOptimizer(0.01).minimize(mse)

sess = tf.Session()
tf.global_variables_initializer()

err = 1.0
target = 0.01
epoch = 0
max_epochs = 1000

while err > target and epoch < max_epochs:
    epoch += 1
    err, _ = sess.run([mse, train])

print("epoch:", epoch, "mse:", err)
print("result: ", out2)

我在运行代码时在 Pycharm 中收到以下错误消息:Screenshot

【问题讨论】:

    标签: python-3.x tensorflow neural-network pycharm xor


    【解决方案1】:

    为了运行初始化操作,你应该写:

    sess.run(tf.global_variables_initializer())
    

    代替:

    tf.global_variables_initializer()
    

    这是一个工作版本:

    import tensorflow as tf
    
    T = 1.
    F = 0.
    train_in = [
        [T, T],
        [T, F],
        [F, T],
        [F, F],
    ]
    train_out = [
        [F],
        [T],
        [T],
        [F],
    ]
    w1 = tf.Variable(tf.random_normal([2, 2]))
    b1 = tf.Variable(tf.zeros([2]))
    
    w2 = tf.Variable(tf.random_normal([2, 1]))
    b2 = tf.Variable(tf.zeros([1]))
    
    out1 = tf.nn.relu(tf.matmul(train_in, w1) + b1)
    out2 = tf.nn.relu(tf.matmul(out1, w2) + b2)
    
    error = tf.subtract(train_out, out2)
    mse = tf.reduce_mean(tf.square(error))
    
    train = tf.train.GradientDescentOptimizer(0.01).minimize(mse)
    
    sess = tf.Session()
    sess.run(tf.global_variables_initializer())
    
    err = 1.0
    target = 0.01
    epoch = 0
    max_epochs = 1000
    
    while err > target and epoch < max_epochs:
        epoch += 1
        err, _ = sess.run([mse, train])
    
    print("epoch:", epoch, "mse:", err)
    print("result: ", out2)
    

    【讨论】:

      猜你喜欢
      • 2023-03-23
      • 2016-02-18
      • 2014-05-09
      • 2012-03-28
      • 2015-07-21
      • 2020-02-10
      • 1970-01-01
      • 2015-02-04
      • 2019-03-15
      相关资源
      最近更新 更多