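【Posted】: 2015-12-18 20:21:07
【Question】:
I'm writing a Bayes classifier for normally distributed data. I have nearly identical code in Python and MATLAB, yet the MATLAB code runs about 50 times faster than my Python script. I'm new to Python, so maybe I'm doing something wrong. I assume the problem is in the part where I loop over the dataset.
Could numpy.argmax() be much slower than [~,idx]=max()? Is looping over a data frame slow? Am I misusing dictionaries (I tried an object earlier and it was even slower)?
Any suggestions are welcome.
Python code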
import numpy as np
import pandas as pd
#import the data as a data frame
train_df = pd.read_table('hw1_traindata.txt',header = None)#training
train_df.columns = [1, 2] #rename column titles
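The data here has 2 columns (300 rows/samples for training, 300000 for testing). These are the function parameters; mi and Si are the sample means and covariances.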
case3_p = {'w': [], 'w0': [], 'W': []}
case3_p['w']={1:S1.I*m1,2:S2.I*m2,3:S3.I*m3}
case3_p['w0']={1: -1.0/2.0*(m1.T*S1.I*m1)-
1.0/2.0*np.log(np.linalg.det(S1)),
2: -1.0/2.0*(m2.T*S2.I*m2)-1.0/2.0*np.log(np.linalg.det(S2)),
3: -1.0/2.0*(m3.T*S3.I*m3)-1.0/2.0*np.log(np.linalg.det(S3))}
case3_p['W']={1: -1.0/2.0*S1.I,
2: -1.0/2.0*S2.I,
3: -1.0/2.0*S3.I}
#W1=-1.0/2.0*S1.I
#w1_3=S1.I*m1
#w01_3=-1.0/2.0*(m1.T*S1.I*m1)-1.0/2.0*np.log(np.linalg.det(S1))
def g3(x,W,w,w0):
    return x.T*W*x+w.T*x+w0
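For reference, the parameters and g3 above can be written as the following quadratic discriminant, read directly off the code. No class-prior term appears in the code, which suggests equal priors; that is an inference, not something stated in the post.

```latex
\begin{aligned}
g_i(x) &= x^{\mathsf T} W_i\, x + w_i^{\mathsf T} x + w_{i0}, \\
W_i &= -\tfrac{1}{2} S_i^{-1}, \qquad w_i = S_i^{-1} m_i, \\
w_{i0} &= -\tfrac{1}{2}\, m_i^{\mathsf T} S_i^{-1} m_i - \tfrac{1}{2} \ln \det S_i .
\end{aligned}
```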
Here is the classifier/loop
train_df['case3'] = 0
for i in range(train_df.shape[0]):
    x = np.mat(train_df.loc[i,[1, 2]]).T#observation
    #case 3
    vals = [g3(x,case3_p['W'][1],case3_p['w'][1],case3_p['w0'][1]),
            g3(x,case3_p['W'][2],case3_p['w'][2],case3_p['w0'][2]),
            g3(x,case3_p['W'][3],case3_p['w'][3],case3_p['w0'][3])]
    train_df.loc[i,'case3'] = np.argmax(vals) + 1 #add one to make it the class value
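One possible direction, sketched under assumptions: the loop above does a `.loc` lookup, builds an `np.mat`, and writes back into the data frame once per row, and that per-row overhead is a plausible source of the slowdown. The sketch below evaluates the same discriminant for all rows at once with plain NumPy arrays. The function name `quadratic_discriminants`, the stand-in means and covariances, and the random data are all hypothetical; they are not taken from the original post.

```python
import numpy as np

def quadratic_discriminants(X, means, covs):
    """Evaluate g_i(x) = x'W_i x + w_i'x + w_i0 for every row of X and every class.

    X is an (n, d) array of samples, means a list of (d,) arrays,
    covs a list of (d, d) arrays. Returns an (n, k) array of scores.
    """
    scores = np.empty((X.shape[0], len(means)))
    for k, (m, S) in enumerate(zip(means, covs)):
        S_inv = np.linalg.inv(S)
        W = -0.5 * S_inv
        w = S_inv @ m
        w0 = -0.5 * (m @ S_inv @ m) - 0.5 * np.log(np.linalg.det(S))
        # Quadratic term for all rows at once: sum over j, l of X[n,j] * W[j,l] * X[n,l]
        quad = np.einsum('nj,jl,nl->n', X, W, X)
        scores[:, k] = quad + X @ w + w0
    return scores

# Hypothetical stand-in data: three 2-D Gaussian classes, 100 samples each
rng = np.random.default_rng(0)
means = [np.array([0.0, 0.0]), np.array([5.0, 5.0]), np.array([-5.0, 5.0])]
covs = [np.eye(2), 2.0 * np.eye(2), np.array([[1.0, 0.3], [0.3, 1.0]])]
X = np.vstack([rng.multivariate_normal(m, S, size=100) for m, S in zip(means, covs)])

scores = quadratic_discriminants(X, means, covs)
labels = np.argmax(scores, axis=1) + 1  # add one to make it the class value
```

With the data frame from the question, `X` would presumably be `train_df[[1, 2]].values`, and the resulting `labels` array could be assigned to `train_df['case3']` in a single step instead of row by row.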
The corresponding MATLAB code
train = load('hw1_traindata.txt');
Discriminant functions
W1=-1/2*S1^-1;%there isn't one for the other cases
w1_3=S1^-1*m1';%fix the transpose thing
w10_3=-1/2*(m1*S1^-1*m1')-1/2*log(det(S1));
g1_3=@(x) x'*W1*x+w1_3'*x+w10_3';
W2=-1/2*S2^-1;
w2_3=S2^-1*m2';
w20_3=-1/2*(m2*S2^-1*m2')-1/2*log(det(S2));
g2_3=@(x) x'*W2*x+w2_3'*x+w20_3';
W3=-1/2*S3^-1;
w3_3=S3^-1*m3';
w30_3=-1/2*(m3*S3^-1*m3')-1/2*log(det(S3));
g3_3=@(x) x'*W3*x+w3_3'*x+w30_3';
Classifier
case3_class_tr = Inf(size(act_class_tr));
for i=1:length(train)
    x=train(i,:)';%current sample
    %case3
    vals = [g1_3(x),g2_3(x),g3_3(x)];%compute discriminant function value
    [~, case3_class_tr(i)] = max(vals);%get location of max
end
【Discussion】:
Tags: python performance matlab numpy pandas