在上一篇文章中,笔记介绍了关于kNN算法的原理和实现过程,在这一节,我们来实现手写字符识别,数据集请参见机器学习实战的数据集
首先,我们的数据集分为训练集和测试集,训练集中含有将近2000个数据文档(在这个文档中将一个32*32的矢量图),故为了在计算机中进行处理,我们需要先打开文件,循环读出文件前32行,并将每行的头32个字符值存储在Numpy数组中,最后返回数组,代码如下:
def img2vector(filename): returnVect=zeros((1,1024)) fr=open(filename) for i in range(32): lineStr=fr.readline() for j in range(32): returnVect[0,32*i+j]=int(lineStr[j]) return returnVect
实现这一步后,我们只需在实现测试代码即可
def handwritingClassTest(): hwLabels=[] trainingFileList=os.listdir(\'trainingDigits\')#get 1 print trainingFileList m=len(trainingFileList) print m trainingMat=zeros((m,1024)) for i in range(m): fileNameStr=trainingFileList[i] fileStr=fileNameStr.split(\'.\')[0] classNumStr=int(fileStr.split(\'_\')[0]) hwLabels.append(classNumStr) trainingMat[i,:]=img2vector(\'trainingDigits/%s\' %fileNameStr) testFileList=os.listdir(\'testDigits\') errorCount=0.0 mTest=len(testFileList) for i in range(mTest): fileNameStr=testFileList[i] fileStr=fileNameStr.split(\'.\')[0] classNumStr=int(fileStr.split(\'_\')[0]) vectorUnderTest=img2vector(\'testDigits/%s\' %fileNameStr) classifierResult=classify0(vectorUnderTest,trainingMat,hwLabels,3) print "the classifier came back with:%d,the real answer is:%d" %(classifierResult,classNumStr) if(classifierResult!=classNumStr):errorCount+=1.0 print "\nthe total number of errors is:%d" %errorCount print "\nthe total error rate od:%f:" %(errorCount/float(mTest))
结束