【问题标题】:RAPIDS: How to use one dataframe in a UDF called with apply_rows of another dataframe?RAPIDS:如何在使用另一个数据帧的 apply_rows 调用的 UDF 中使用一个数据帧?
【发布时间】:2021-04-20 21:51:15
【问题描述】:

对于数据框 A 中的每一行,我需要查询 DF B。我需要执行以下操作:按列 b1 (B.b1) 中的值过滤 B 行,这些值在列 A.a1 和A.a2 并将组合值分配给 A.a3 列。

在熊猫中会是这样的:

A.a1 = B[(B.b1>A.a2) & (B.b1<A.a3)]['b2'].values

我尝试在 UDF 的函数参数中传递数据帧,但出现错误:

ValueError: Cannot determine Numba type of <class 'cudf.core.dataframe.DataFrame'>

下面是一个使用 Pandas 的 Python 示例。

toyevents = pd.DataFrame.from_dict({'end': {0: 8.748356416,
         1: 8.752231441000001,
         2: 8.756627850000001,
         3: 8.760818359,
         4: 8.765967569,
         5: 8.77041589,
         6: 8.774226174,
         7: 8.776358813,
         8: 8.77866835,
         9: 8.780719302000001},
 'name_id': {0: 18452.0,
             1: 20586.0,
             2: 20491.0,
             3: 20610.0,
             4: 20589.0,
             5: 20589.0,
             6: 19165.0,
             7: 20589.0,
             8: 20586.0,
             9: 19064.0},
 'start': {0: 8.748299848,
           1: 8.752229263,
           2: 8.756596980000001,
           3: 8.760816603,
           4: 8.765957310000001,
           5: 8.770381615,
           6: 8.77414259,
           7: 8.776349745000001,
           8: 8.778666861000001,
           9: 8.780674982}})

toynvtx = pd.DataFrame.from_dict({'NvtxEvent.Text': {0: 'Iteration 32',
                    1: 'FWD pass',
                    2: 'Prediction and loss',
                    3: 'BWD pass',
                    4: 'Optimizer update'},
 'end': {0: 8.802574018000001,
         1: 8.771325765,
         2: 8.771688249,
         3: 8.792846429,
         4: 8.802333183},
 'start': {0: 8.744061385,
           1: 8.747272157000001,
           2: 8.771329333,
           3: 8.771691628000001,
           4: 8.792851876}})

# Search NVTX ranges encompassing [start,end] range.
def pickNVTX(r,nvtx):
    start = r['start']
    end = r['end']
    start_early = nvtx[nvtx['start'] <= start]
    end_later = start_early[start_early['end'] >= end]
    return ','.join(end_later['NvtxEvent.Text'])

# Using apply()
toyevents.loc[:,'nvtx'] = toyevents_.apply(pickNVTX,nvtx=toynvtx,axis=1)

# Method 2. Using iterrows()
for i, row in toyevents.iterrows():
    toyevents.loc[i, 'nvtx'] = ','.join(
        toynvtx[(toynvtx.start <= row.start)
                & (toynvtx.end >= row.end)]['NvtxEvent.Text'].values)

【问题讨论】:

  • 如果你运行A.a1 = B[(B.b1&gt;A.a2) &amp; (B.b1&lt;A.a3)]['b2'].values会发生什么?为了让社区为您提供最好的帮助,请提供一个最小的、可重现的示例。
  • @NickBecker 感谢您的帮助! A.a1 = B[(B.b1&gt;A.a2) &amp; (B.b1&lt;A.a3)]['b2'].values 产生 ValueError: Can only compare identically-labeled Series objects。我将在下面的评论中发布一个有效的 Python 示例。
  • 代码太长,无法评论,所以我更新了原帖。

标签: python pandas rapids cudf


【解决方案1】:

您可能希望对此类问题使用不等式(条件)连接。 pandas、cuDF 或 BlazingSQL 目前不支持此功能。

如果您的数据不是很大,您可以结合使用交叉连接、布尔掩码和 groupby collect_list 来实现。如果您提供第二个数据帧作为参数,UDF 也可能会起作用,以便您可以对其进行索引并循环(但这会变得混乱且效率低下)。

你的例子的输出是:

        end  name_id     start                   nvtx
0  8.748356  18452.0  8.748300  Iteration 32,FWD pass
1  8.752231  20586.0  8.752229  Iteration 32,FWD pass
2  8.756628  20491.0  8.756597  Iteration 32,FWD pass
3  8.760818  20610.0  8.760817  Iteration 32,FWD pass
4  8.765968  20589.0  8.765957  Iteration 32,FWD pass
5  8.770416  20589.0  8.770382  Iteration 32,FWD pass
6  8.774226  19165.0  8.774143  Iteration 32,BWD pass
7  8.776359  20589.0  8.776350  Iteration 32,BWD pass
8  8.778668  20586.0  8.778667  Iteration 32,BWD pass
9  8.780719  19064.0  8.780675  Iteration 32,BWD pass

以下代码将提供相同的输出,使用 List 列而不是字符串列。

# put the example data on the GPU
toyevents = cudf.from_pandas(toyevents)
toynvtx = cudf.from_pandas(toynvtx)
​
# cross join
toyevents['key'] = 1
toynvtx['key'] = 1
merged = toyevents.merge(toynvtx, how="outer", on="key")
del merged["key"]

# filter
mask = (merged.start_y <= merged.start_x) & (merged.end_y >= merged.end_x)
del merged["start_y"], merged["end_y"]

# collect list
merged[mask].groupby(["end_x", "name_id", "start_x"])["NvtxEvent.Text"].agg(list)
end_x     name_id  start_x 
8.748356  18452.0  8.748300    [Iteration 32, FWD pass]
8.752231  20586.0  8.752229    [Iteration 32, FWD pass]
8.756628  20491.0  8.756597    [Iteration 32, FWD pass]
8.760818  20610.0  8.760817    [Iteration 32, FWD pass]
8.765968  20589.0  8.765957    [Iteration 32, FWD pass]
8.770416  20589.0  8.770382    [Iteration 32, FWD pass]
8.774226  19165.0  8.774143    [Iteration 32, BWD pass]
8.776359  20589.0  8.776350    [Iteration 32, BWD pass]
8.778668  20586.0  8.778667    [Iteration 32, BWD pass]
8.780719  19064.0  8.780675    [Iteration 32, BWD pass]
Name: NvtxEvent.Text, dtype: list

【讨论】:

  • 你的方法对我有用。谢谢!如果出现内存错误,我认为可以将第一个带有事件的较大表拆分为几个较小的表,并将它们用于循环中的交叉连接。
  • 太棒了!是的,您没有理由不能对较大的表进行分块并进行几个较小的交叉连接,然后合并结果
猜你喜欢
  • 2017-05-14
  • 2017-11-16
  • 1970-01-01
  • 2023-04-04
  • 1970-01-01
  • 2015-12-19
  • 2016-09-15
  • 1970-01-01
  • 1970-01-01
相关资源
最近更新 更多