【发布时间】:2018-05-04 21:56:23
【问题描述】:
我正在使用 tensorflow 进行线性回归。这里我遇到了一个问题:
import pandas as pd
import tensorflow as tf
import matplotlib.pyplot as plt
plt.rcParams["figure.figsize"] = (8,6)
data = pd.read_csv('./data.csv')
xs = data["A"][:100]
ys = data["B"][:100]
X = tf.placeholder(tf.float32, name='X')
Y = tf.placeholder(tf.float32, name='Y')
W = tf.Variable(tf.random_normal([1]),name = 'weight')
b = tf.Variable(tf.random_normal([1]),name = 'bias')
Y_pred = tf.add(tf.multiply(X,W), b)
sample_num = xs.shape[0]
loss = tf.reduce_sum(tf.pow(Y_pred - Y,2))/sample_num
learning_rate = 0.0001
optimizer = tf.train.GradientDescentOptimizer(learning_rate).minimize(loss)
n_samples = xs.shape[0]
init = tf.global_variables_initializer()
with tf.Session() as sess:
sess.run(init)
for i in range(100):
for x,y in zip(xs,ys):
_, l = sess.run([optimizer, loss], feed_dict={X: x, Y:y})
W, b = sess.run([W, b])
plt.plot(xs, ys, 'bo', label='Real data')
plt.plot(xs, xs*W + b, 'r', label='Predicted data')
plt.legend()
plt.show()
data.csv 是here。
【问题讨论】:
-
你试过提高学习率吗? 0.0001 相当小。你得到什么样的损失值?
-
@Sunreef 我试图提高学习率,但变量“W”、“b”和“损失”将变为“无”。不会绘制预测数据。我不知道是什么问题。
-
@Sunreef 我的tensorflow是pip安装的,总是显示很多警告。这些会影响结果吗?
标签: python tensorflow regression