【发布时间】:2020-09-30 23:25:10
【问题描述】:
我有预先分组的数据。具体来说,它们是 3 个不同类别的 PR 曲线,我想将它们绘制在相同的轴上:
import numpy as np
data_groups = {
'ap=0.16: cat_3 (4/19)': {
'precision': np.array([0. , 0. , 0. , 0. , 0.2 ,
0.16666667, 0.14285714, 0.25 , 0.22222222, 0.2 ,
0.18181818, 0.16666667, 0.15384615, 0.14285714, 0.13333333,
0.21052632], dtype=np.float64),
'recall': np.array([0. , 0. , 0. , 0. , 0.25, 0.25, 0.25, 0.5 , 0.5 , 0.5 , 0.5 ,
0.5 , 0.5 , 0.5 , 0.5 , 1. ], dtype=np.float64),
},
'ap=0.20: cat_1 (3/19)': {
'precision': np.array([0. , 0.5 , 0.33333333, 0.25 , 0.2 ,
0.16666667, 0.14285714, 0.25 , 0.22222222, 0.2 ,
0.18181818, 0.16666667, 0.15384615, 0.14285714, 0.13333333,
0.15789474], dtype=np.float64),
'recall': np.array([0. , 0.33333333, 0.33333333, 0.33333333, 0.33333333,
0.33333333, 0.33333333, 0.66666667, 0.66666667, 0.66666667,
0.66666667, 0.66666667, 0.66666667, 0.66666667, 0.66666667,
1. ], dtype=np.float64),
},
'ap=0.54: cat_2 (8/19)': {
'precision': np.array([0. , 0.5 , 0.33333333, 0.5 , 0.6 ,
0.66666667, 0.71428571, 0.75 , 0.66666667, 0.6 ,
0.63636364, 0.58333333, 0.53846154, 0.5 , 0.46666667,
0.42105263], dtype=np.float64),
'recall': np.array([0. , 0.125, 0.125, 0.25 , 0.375, 0.5 , 0.625, 0.75 , 0.75 ,
0.75 , 0.875, 0.875, 0.875, 0.875, 0.875, 1. ], dtype=np.float64),
},
}
我想使用 seaborn 在一个图中绘制这些多条线,但为此我似乎需要将这些分组数据转换为一个长格式的 pandas 表。
import pandas as pd
longform = []
for key, subdata in data_groups.items():
subdata = pd.DataFrame.from_dict(subdata)
subdata['label'] = key
longform.append(subdata)
data = pd.concat(longform)
这有效地为列表中的每个项目复制了这个“标签”属性:
recall precision label
0 0.000000 0.000000 ap=0.54: cat_2 (8/19)
1 0.125000 0.500000 ap=0.54: cat_2 (8/19)
2 0.125000 0.333333 ap=0.54: cat_2 (8/19)
...
0 0.000000 0.000000 ap=0.20: cat_1 (3/19)
1 0.333333 0.500000 ap=0.20: cat_1 (3/19)
2 0.333333 0.333333 ap=0.20: cat_1 (3/19)
3 0.333333 0.250000 ap=0.20: cat_1 (3/19)
...
0 0.000000 0.000000 ap=0.16: cat_3 (4/19)
1 0.000000 0.000000 ap=0.16: cat_3 (4/19)
2 0.000000 0.000000 ap=0.16: cat_3 (4/19)
此时我可以绘制它:
import seaborn as sns
sns.lineplot(
data=data, x='recall', y='precision',
hue='label', style='label')
但我想知道是否有更有效的方法将预先分组的数据发送到 seaborn。我想避免重复“标签”属性,我想它必须有效地反转我刚刚执行的pd.concat 操作。
在 seaborn (https://seaborn.pydata.org/tutorial/data_structure.html) 接受的数据结构中,他们只提到了这种长格式(我非常了解)和宽格式数据(这对我来说意义不大)。
这个预先分组的数据不是宽格式的变体,对吧?我只是想验证执行额外的 concat 是当前执行此操作的唯一方法。
【问题讨论】:
-
您没有宽格式数据,因为您的字典是嵌套的,但您可以通过以下方式更轻松地获取长格式数据:
data_long = pd.concat({k: pd.DataFrame(v) for k, v in data_groups.items()}, names=["label"])