【问题标题】:For each row, what is the fastest way to find the column holding nth element that is not NaN?对于每一行,找到第 n 个非 NaN 元素的列的最快方法是什么?
【发布时间】:2015-11-04 10:10:12
【问题描述】:

我有一个 Python pandas DataFrame,其中每个元素都是浮点数或 NaN。 对于每一行,我需要找到包含该行第 n 个数字的列。也就是说,我需要获取包含非 NaN 行的第 n 个元素的列。我知道第 n 个这样的专栏总是存在的。

因此,如果 n 为 4,并且名为 myDF 的 pandas 数据帧如下:

      10   20   30   40   50   60  70  80  90  100

'A'  4.5  5.5  2.5  NaN  NaN  2.9 NaN NaN 1.1 1.8
'B'  4.7  4.1  NaN  NaN  NaN  2.0 1.2 NaN NaN NaN
'C'  NaN  NaN  NaN  NaN  NaN  1.9 9.2 NaN 4.4 2.1
'D'  1.1  2.2  3.5  3.4  4.5  NaN NaN NaN 1.9 5.5

我想获得:

'A'  60
'B'  70
'C'  100 
'D'  40

我能做到:

import pandas as pd
import math

n = some arbitrary int
for row in myDF.indexes:
   num_not_NaN = 0   
   for c in myDF.columns:    
      if math.isnan(myDF[c][row]) == False: 
           num_not_NaN +=1
      if num_not_NaN==n:
           print row, c
           break

我确信这很慢而且不是 Pythonic。如果我要处理非常大的 DataFrame 和较大的 n 值,是否有更快的方法?

【问题讨论】:

    标签: python performance pandas dataframe nan


    【解决方案1】:

    如果您的目标是速度,最好尽可能利用 Pandas 的矢量化方法:

    >>> (df.notnull().cumsum(axis=1) == 4).idxmax(axis=1) # replace 4 with any number you like
    'A'     60
    'B'     70
    'C'    100
    'D'     40
    dtype: object
    

    其他答案很好,语法上可能更清晰一些。就速度而言,对于您的小示例,它们之间没有太大区别。但是,对于稍大的 DataFrame,向量化方法已经快了大约 60 倍:

    >>> df2 = pd.concat([df]*1000) # 4000 row DataFrame
    >>> %timeit df2.apply(lambda row: get_nth(row, n), axis=1)
    1 loops, best of 3: 749 ms per loop
    
    >>> %timeit df2.T.apply(lambda x: x.dropna()[n-1:].index[0])
    1 loops, best of 3: 673 ms per loop
    
    >>> %timeit (df2.notnull().cumsum(1) == 4).idxmax(axis=1)
    100 loops, best of 3: 10.5 ms per loop
    

    【讨论】:

    • 非常感谢。这将大大加快我的速度。
    【解决方案2】:

    您可以转置 df 并应用 lambda 来删除 NaN 行,从第 4 个值开始切片并返回第一个有效索引:

    In [72]:
    n=4
    
    df.T.apply(lambda x: x.dropna()[n-1:].index[0])
    Out[72]:
    'A'     60
    'B'     70
    'C'    100
    'D'     40
    dtype: object
    

    【讨论】:

      【解决方案3】:

      您可以创建一个函数,然后将其传递给lambda 函数。

      该函数过滤系列中的空值,然后返回n 项的索引值(如果索引长度小于n,则返回None)。

      lambda 函数需要axis=1 以确保将其应用于 DataFrame 的每一行。

      def get_nth(series, n):
          s = series[series.notnull()]
          if len(s) >= n:
              return s.index[n - 1]
      
      >>> n = 4
      >>> df.apply(lambda row: get_nth(row, n), axis=1)
      A     60
      B     70
      C    100
      D     40
      dtype: object
      

      【讨论】:

        猜你喜欢
        • 1970-01-01
        • 1970-01-01
        • 2014-11-20
        • 1970-01-01
        • 2019-10-11
        • 1970-01-01
        • 2013-03-28
        • 2014-08-18
        • 1970-01-01
        相关资源
        最近更新 更多