【发布时间】:2020-05-11 07:57:20
【问题描述】:
我有一个数据框,data,具有以下结构(我的实际上要大得多,但这只是为了说明目的):
a b c tag
A 3 2 4
B 2 1 3
A 5 3 3
A 4 3 2
B 2 4 3
A 3 5 2
B 4 1 1
C 2 3 1
C 1 3 4
B 5 2 4
我正在使用 scikit-learn 来拆分数据:
train, test = train_test_split(data, test_size=test_size)
但是,我想找到一种方法来拆分数据,以保证我在两组中每个标签至少有一行。对于示例数据集,这意味着有这样的东西(当然是打乱的):
火车
a b c tag
A 3 2 4
B 2 1 3
A 4 3 2
B 4 1 1
测试
similar but with the remaining elements (according to the proportion)
基本上,我希望在两组中都有整个范围/种类的标签。
我提前感谢所有帮助。
`
【问题讨论】:
-
检查函数内部参数
stratify =
标签: python pandas scikit-learn