【问题标题】:Is numpy.argmax slower than MATLAB [~,idx] = max()?numpy.argmax 是否比 MATLAB [~,idx] = max() 慢?
【发布时间】:2015-12-18 20:21:07
【问题描述】:

我正在为正态分布编写贝赛分类器。我在 python 和 MATLAB 中都有几乎相同的代码。然而,MATLAB 代码的运行速度比我的 Python 脚本快 50 倍。我是 Python 新手,所以也许我做错了什么。我假设它在我循环数据集的地方。

可能 numpy.argmax() 比 [~,idx]=max() 慢很多?循环遍历数据框很慢?字典使用不当(之前我尝试过一个对象,它甚至很慢)?

欢迎任何建议。

Python 代码

import numpy as np
import pandas as pd

#import the data as a data frame
train_df = pd.read_table('hw1_traindata.txt',header = None)#training
train_df.columns = [1, 2] #rename column titles

这里的数据是 2 列(300 行/样本用于训练,300000 用于测试)。这是函数参数; mi 和 Si 是样本均值和协方差。

case3_p = {'w': [], 'w0': [], 'W': []}
case3_p['w']={1:S1.I*m1,2:S2.I*m2,3:S3.I*m3}
case3_p['w0']={1: -1.0/2.0*(m1.T*S1.I*m1)-

1.0/2.0*np.log(np.linalg.det(S1)),
            2: -1.0/2.0*(m2.T*S2.I*m2)-1.0/2.0*np.log(np.linalg.det(S2)),
            3: -1.0/2.0*(m3.T*S3.I*m3)-1.0/2.0*np.log(np.linalg.det(S3))}
case3_p['W']={1: -1.0/2.0*S1.I,
           2: -1.0/2.0*S2.I,
           3: -1.0/2.0*S3.I}
#W1=-1.0/2.0*S1.I
#w1_3=S1.I*m1
#w01_3=-1.0/2.0*(m1.T*S1.I*m1)-1.0/2.0*np.log(np.linalg.det(S1))    
def g3(x,W,w,w0):
    return x.T*W*x+w.T*x+w0

这是分类器/循环

train_df['case3'] = 0

for i in range(train_df.shape[0]):
    x = np.mat(train_df.loc[i,[1, 2]]).T#observation

    #case 3    
    vals = [g3(x,case3_p['W'][1],case3_p['w'][1],case3_p['w0'][1]),
            g3(x,case3_p['W'][2],case3_p['w'][2],case3_p['w0'][2]),
            g3(x,case3_p['W'][3],case3_p['w'][3],case3_p['w0'][3])]
    train_df.loc[i,'case3'] = np.argmax(vals) + 1 #add one to make it the class value

对应的MATLAB代码

train = load('hw1_traindata.txt');

判别函数

W1=-1/2*S1^-1;%there isn't one for the other cases
w1_3=S1^-1*m1';%fix the transpose thing
w10_3=-1/2*(m1*S1^-1*m1')-1/2*log(det(S1));
g1_3=@(x) x'*W1*x+w1_3'*x+w10_3';

W2=-1/2*S2^-1;
w2_3=S2^-1*m2';
w20_3=-1/2*(m2*S2^-1*m2')-1/2*log(det(S2));
g2_3=@(x) x'*W2*x+w2_3'*x+w20_3';

W3=-1/2*S3^-1;
w3_3=S3^-1*m3';
w30_3=-1/2*(m3*S3^-1*m3')-1/2*log(det(S3));
g3_3=@(x) x'*W3*x+w3_3'*x+w30_3';

分类器

case3_class_tr = Inf(size(act_class_tr));
for i=1:length(train)
    x=train(i,:)';%current sample

    %case3
    vals = [g1_3(x),g2_3(x),g3_3(x)];%compute discriminant function value
    [~, case3_class_tr(i)] = max(vals);%get location of max

end

【问题讨论】:

    标签: python performance matlab numpy pandas


    【解决方案1】:

    这真的很难说,但是直接从包中取出 Matlab 会比 Numpy 快。主要是因为它带有自己的Math Kernel Library

    50x 是否是一个合理的近似值,很难比较基本的 Numpy 与 Matlab 的 MKL。

    还有其他带有自己的 MKL 的 Python 发行版,例如 EnthoughtAnaconda

    在 Anaconda 的 MKL Optimizations 页面中,您将看到比较普通 Anaconda 和 MKL 之间差异的图表。改进不是线性的,但肯定存在。

    【讨论】:

    • MKL 可以帮助一些操作,但我认为它不会为简单的操作带来任何好处,比如找到向量的最大值。其实它看起来连这样的功能都没有。
    【解决方案2】:

    在这种情况下,最好对您的代码进行概要分析。首先我创建了一些模拟数据:

    import numpy as np
    import pandas as pd
    
    fname = 'hw1_traindata.txt'
    ar = np.random.rand(1000, 2)
    np.savetxt(fname, ar, delimiter='\t')
    
    m1, m2, m3 = [np.mat(ar).T for ar in np.random.rand(3, 2)]
    S1, S2, S3 = [np.mat(ar) for ar in np.random.rand(3, 2, 2)]
    

    然后,我将您的代码包装在一个函数中,并使用 lprun (line_profiler) IPython 魔法进行分析。结果如下:

    %lprun -f train train(fname, m1, S1, m2, S2, m3, S3)
    Timer unit: 5.59946e-07 s
    
    Total time: 4.77361 s
    File: <ipython-input-164-563f57dadab3>
    Function: train at line 1
    
    Line #   Hits     Time  Per Hit  %Time  Line Contents
    =====================================================
         1                                 def train(fname, m1, S1, m2, S2, m3, S3):
         2      1     9868   9868.0   0.1      train_df = pd.read_table(fname ,header = None)#training
         3      1      328    328.0   0.0      train_df.columns = [1, 2] #rename column titles
         4                                 
         5      1       17     17.0   0.0      case3_p = {'w': [], 'w0': [], 'W': []}
         6      1      877    877.0   0.0      case3_p['w']={1:S1.I*m1,2:S2.I*m2,3:S3.I*m3}
         7      1      356    356.0   0.0      case3_p['w0']={1: -1.0/2.0*(m1.T*S1.I*m1)-
         8                                 
         9      1      204    204.0   0.0      1.0/2.0*np.log(np.linalg.det(S1)),
        10      1      498    498.0   0.0                  2: -1.0/2.0*(m2.T*S2.I*m2)-1.0/2.0*np.log(np.linalg.det(S2)),
        11      1      502    502.0   0.0                  3: -1.0/2.0*(m3.T*S3.I*m3)-1.0/2.0*np.log(np.linalg.det(S3))}
        12      1      235    235.0   0.0      case3_p['W']={1: -1.0/2.0*S1.I,
        13      1      229    229.0   0.0                 2: -1.0/2.0*S2.I,
        14      1      230    230.0   0.0                 3: -1.0/2.0*S3.I}
        15                                 
        16      1     1818   1818.0   0.0      train_df['case3'] = 0
        17                                 
        18   1001    17409     17.4   0.2      for i in range(train_df.shape[0]):
        19   1000  4254511   4254.5  49.9          x = np.mat(train_df.loc[i,[1, 2]]).T#observation
        20                                 
        21                                         #case 3    
        22   1000   298245    298.2   3.5          vals = [g3(x,case3_p['W'][1],case3_p['w'][1],case3_p['w0'][1]),
        23   1000   269825    269.8   3.2                  g3(x,case3_p['W'][2],case3_p['w'][2],case3_p['w0'][2]),
        24   1000   274279    274.3   3.2                  g3(x,case3_p['W'][3],case3_p['w'][3],case3_p['w0'][3])]
        25   1000  3395654   3395.7  39.8          train_df.loc[i,'case3'] = np.argmax(vals) + 1
        26                                 
        27      1       45     45.0   0.0      return train_df
    

    有两条线路总共占用了 90% 的时间。因此,让我们将这些行拆分一下,然后重新运行分析器:

    %lprun -f train train(fname, m1, S1, m2, S2, m3, S3)
    Timer unit: 5.59946e-07 s
    
    Total time: 6.15358 s
    File: <ipython-input-197-92d9866b57dc>
    Function: train at line 1
    
    Line #   Hits      Time  Per Hit  %Time  Line Contents
    ======================================================
    ...     
        19   1000   5292988   5293.0   48.2          thing = train_df.loc[i,[1, 2]]  # Observation
        20   1000    265101    265.1    2.4          x = np.mat(thing).T
    ...     
        26   1000    143142    143.1    1.3          index = np.argmax(vals) + 1  # Add one to make it the class value
        27   1000   4164122   4164.1   37.9          train_df.loc[i,'case3'] = index
    

    大部分时间都花在索引 Pandas 数据框上!取argmax 只占总执行时间的 1.5%。

    通过预分配train_df['case3']和使用.iloc可以在一定程度上改善这种情况:

    %lprun -f train train(fname, m1, S1, m2, S2, m3, S3)
    Timer unit: 5.59946e-07 s
    
    Total time: 3.26716 s
    File: <ipython-input-192-f6173cdf9990>
    Function: train at line 1
    
    Line #   Hits      Time  Per Hit  %Time  Line Contents
    ======= ======= ======================================
        16      1      1548   1548.0    0.0      train_df['case3'] = np.zeros(len(train_df))
    ...             
        19   1000   2608489   2608.5   44.7          thing = train_df.iloc[i,[0, 1]]  # Observation
        20   1000    228959    229.0    3.9          x = np.mat(thing).T
    ...             
        26   1000    123165    123.2    2.1          index = np.argmax(vals) + 1  # Add one to make it the class value
        27   1000   1849283   1849.3   31.7          train_df.iloc[i,2] = index
    

    尽管如此,在紧密循环中迭代 Pandas 数据帧中的单个值是一个坏主意。在这种情况下,仅使用 Pandas 加载文本数据(它非常擅长),但除此之外使用“原始”Numpy 数组。例如。使用train_data = pd.read_table(fname, header=None).values。当你到达分析阶段时,可能会回到 Pandas。

    其他一些杂碎:

    • 使用 Python 的从零开始的索引,不要特意使用 基于 1 的索引。
    • 考虑使用普通的 Numpy 数组而不是矩阵。当你使用 您倾向于将它们与数组混合在一起并且难以调试的矩阵 问题。
    • MATLAB 有一个 JIT 编译器,所以 Python 和 MATLAB 应用于循环繁重的代码。

    【讨论】:

    • 其他几点: 1:循环遍历行而不是使用python索引更容易和更快。 2:每次循环都使用字典是不必要的,也会稍微减慢速度。 3.lambdapartial可用于g3的各种功能。
    • 感谢您的信息。正是我想要的。 TheBlackCat,你能详细说明一下你的意思是循环遍历行而不是索引吗?
    • 假设你有一个 2D numpy 数组 myarr。与其使用for i in range(myarr.shape[0]): 并使用row = myarr[i, :] 获取行,不如使用for row in myarr:。或者,如果您还需要索引,可以使用for i, row in enumerate(myarr):。 Numpy 使用视图,所以row 仍然是myarr 的一部分。 row 不占用额外内存,对它的任何更改也会更改myarr 中的相应行。您可以使用 df.iterrows() 对 pandas a DataFrame 执行类似的操作,尽管这会自动为您提供索引和行:for index, row in df.iterrows():
    猜你喜欢
    • 2022-01-23
    • 2019-08-03
    • 1970-01-01
    • 1970-01-01
    • 2023-03-31
    • 2023-03-10
    • 1970-01-01
    • 1970-01-01
    • 1970-01-01
    相关资源
    最近更新 更多