wuxiaoqian

在上一篇文章中,笔记介绍了关于kNN算法的原理和实现过程,在这一节,我们来实现手写字符识别,数据集请参见机器学习实战的数据集

首先,我们的数据集分为训练集和测试集,训练集中含有将近2000个数据文档(在这个文档中将一个32*32的矢量图),故为了在计算机中进行处理,我们需要先打开文件,循环读出文件前32行,并将每行的头32个字符值存储在Numpy数组中,最后返回数组,代码如下:

def img2vector(filename):
    returnVect=zeros((1,1024))
    fr=open(filename)
    for i in range(32):
        lineStr=fr.readline()
        for j in range(32):
            returnVect[0,32*i+j]=int(lineStr[j])
    return returnVect

实现这一步后,我们只需在实现测试代码即可

def handwritingClassTest():
    hwLabels=[]
    trainingFileList=os.listdir(\'trainingDigits\')#get 1
    print trainingFileList
    m=len(trainingFileList)
    print m
    trainingMat=zeros((m,1024))
    for i in range(m):
        fileNameStr=trainingFileList[i]
        fileStr=fileNameStr.split(\'.\')[0]
        classNumStr=int(fileStr.split(\'_\')[0])
        hwLabels.append(classNumStr)
        trainingMat[i,:]=img2vector(\'trainingDigits/%s\' %fileNameStr)
    testFileList=os.listdir(\'testDigits\')
    errorCount=0.0
    mTest=len(testFileList)
    for i in range(mTest):
        fileNameStr=testFileList[i]
        fileStr=fileNameStr.split(\'.\')[0]
        classNumStr=int(fileStr.split(\'_\')[0])
        vectorUnderTest=img2vector(\'testDigits/%s\' %fileNameStr)
        classifierResult=classify0(vectorUnderTest,trainingMat,hwLabels,3)
        print "the classifier came back with:%d,the real answer is:%d" %(classifierResult,classNumStr)
        if(classifierResult!=classNumStr):errorCount+=1.0
    print "\nthe total number of errors is:%d" %errorCount
    print "\nthe total error rate od:%f:" %(errorCount/float(mTest)) 

结束

 

分类:

技术点:

相关文章: