-
# -*- coding:utf-8 -*-
-
# Filename: train2.2.py
-
# Author:hankcs
-
# Date: 2015/1/31 15:15
-
import numpy as np
-
from matplotlib import pyplot as plt
-
from matplotlib import animation
-
-
training_set = np.array([[[3, 3], 1], [[4, 3], 1], [[1, 1], -1], [[5, 2], -1]]) #训练样本
-
-
a = np.zeros(len(training_set), np.float) #矩阵a的长度为训练集样本数,类型为float
-
b = 0.0 #参数初始值为0
-
Gram = None #Gram矩阵
-
y = np.array(training_set[:, 1]) #y=[1 1 -1 -1]
-
x = np.empty((len(training_set), 2), np.float) #x为4*2的矩阵
-
for i in range(len(training_set)): #x=[[3., 3.], [4., 3.], [1., 1.], [5., 2.]]
-
x[i] = training_set[i][0]
-
history = [] #history记录每次迭代结果
-
-
def cal_gram():
-
"""
-
计算Gram矩阵
-
:return:
-
"""
-
g = np.empty((len(training_set), len(training_set)), np.int)
-
for i in range(len(training_set)):
-
for j in range(len(training_set)):
-
g[i][j] = np.dot(training_set[i][0], training_set[j][0]) #G=[xi*xj]
-
return g
-
-
-
def update(i):
-
"""
-
随机梯度下降更新参数
-
:param i:
-
:return:
-
"""
-
global a, b
-
a[i] += 1 #根据误分类点更新参数
-
b = b + 1 * y[i] #这里1是学习效率η
-
history.append([np.dot(a * y, x), b]) #history记录每次迭代结果
-
print a, b #输出每次迭代结果
-
-
-
#计算yi(Gram*xi+b),用来判断是否是误分类点
-
def cal(i):
-
global a, b, x, y
-
res = np.dot(a * y, Gram[i])
-
res = (res + b) * y[i] #返回
-
return res
-
-
-
#检查是否已经正确分类
-
def check():
-
global a, b, x, y
-
flag = False
-
for i in range(len(training_set)): #遍历每个点
-
if cal(i) <= 0: #如果yi(Gram*xi+b)<=0.则是误分类点
-
flag = True
-
update(i) #用误分类点更新参数
-
if not flag: #如果已正确分类
-
w = np.dot(a * y, x) #计算w
-
print "RESULT: w: " + str(w) + " b:" + str(b) #输出最后结果
-
return False
-
return True
-
-
-
if __name__ == "__main__":
-
Gram = cal_gram() #初始化 Gram矩阵
-
for i in range(1000): #迭代1000次
-
if not check(): break #如果已正确分类则结束循环
-
-
#以下代码是将迭代过程可视化,数据来源于history
-
# first set up the figure, the axis, and the plotelement we want to animate
-
fig = plt.figure()
-
ax = plt.axes(xlim=(0, 2), ylim=(-2, 2))
-
line, = ax.plot([], [], 'g', lw=2)
-
label = ax.text([], [], '')
-
-
# initialization function: plot the background of eachframe
-
def init():
-
line.set_data([], [])
-
x, y, x_, y_ = [], [], [], []
-
for p in training_set:
-
if p[1] > 0:
-
x.append(p[0][0])
-
y.append(p[0][1])
-
else:
-
x_.append(p[0][0])
-
y_.append(p[0][1])
-
-
plt.plot(x, y, 'bo', x_, y_, 'rx')
-
plt.axis([-6, 6, -6, 6])
-
plt.grid(True)
-
plt.xlabel('x')
-
plt.ylabel('y')
-
plt.title('PerceptronAlgorithm 2 (www.hankcs.com)')
-
return line, label
-
-
-
# animation function. this is called sequentially
-
def animate(i):
-
global history, ax, line, label
-
-
w = history[i][0]
-
b = history[i][1]
-
if w[1] == 0: return line, label
-
x1 = -7.0
-
y1 = -(b + w[0] * x1) / w[1]
-
x2 = 7.0
-
y2 = -(b + w[0] * x2) / w[1]
-
line.set_data([x1, x2], [y1, y2])
-
x1 = 0.0
-
y1 = -(b + w[0] * x1) / w[1]
-
label.set_text(str(history[i][0]) + ' ' + str(b))
-
label.set_position([x1, y1])
-
return line, label
-
-
# call the animator. blit=true means only re-draw the parts that have changed.
-
anim =animation.FuncAnimation(fig, animate, init_func=init, frames=len(history), interval=1000, repeat=True,
-
blit=True)
-
plt.show()
-
#anim.save('D:/perceptron2.gif',fps=2, writer='imagemagick')
|