【发布时间】:2017-10-30 22:38:05
【问题描述】:
我在 youtube 上偶然发现了 andrew Ng 的课程并观看了以下video (2min03)
我尝试实现以下函数 之后绘制它,让他在幻灯片中显示,但似乎我只得到一个斜率。此外,我试图将 theta0 和 theta1 与 JList 绘制为网格图,但我一直得到一个错误的图,欢迎任何关于如何获得与他的视频中相同的图的帮助
from mpl_toolkits.mplot3d import Axes3D
import matplotlib.pyplot as plt
import numpy as np
from fractions import Fraction
from matplotlib import cm
theta0=[]
theta1=[]
JList=[]
csv=np.genfromtxt('ex1.data', delimiter=",")
x=csv[:,0]
y=csv[:,1]
for a in np.arange(-50,50,1):
for b in np.arange(1,10,1):
theta0.append(a)
theta1.append(b)
result=0
for c in range(len(x)):
sum=float(a+(b*x[c])-(y[c]))
np.power(sum,2)
result+=float(sum)
Jt=0
Jt=Fraction(1,2*len(x))
Jt=Jt*result
JList.append(int(Jt))
fig = plt.figure()
ax = fig.gca(projection='3d')
X, Y = np.meshgrid(theta1, theta0)
surf = ax.plot_surface(X, Y, JList, cmap=cm.coolwarm,
linewidth=0, antialiased=False)
plt.show()
这是 ex1.data 的示例
6.1101,17.592
5.5277,9.1302
8.5186,13.662
7.0032,11.854
5.8598,6.8233
8.3829,11.886
7.4764,4.3483
8.5781,12
6.4862,6.5987
5.0546,3.8166
5.7107,3.2522
14.164,15.505
5.734,3.1551
8.4084,7.2258
5.6407,0.71618
5.3794,3.5129
6.3654,5.3048
5.1301,0.56077
6.4296,3.6518
7.0708,5.3893
6.1891,3.1386
20.27,21.767
5.4901,4.263
6.3261,5.1875
5.5649,3.0825
18.945,22.638
12.828,13.501
10.957,7.0467
13.176,14.692
22.203,24.147
5.2524,-1.22
6.5894,5.9966
9.2482,12.134
5.8918,1.8495
8.2111,6.5426
7.9334,4.5623
8.0959,4.1164
5.6063,3.3928
12.836,10.117
6.3534,5.4974
5.4069,0.55657
【问题讨论】:
-
你能提供
ex1.data吗?你遇到了什么错误? -
@RafaelBarros 我添加了 ex1.data 和我得到的图形示例。
标签: python machine-learning linear-regression