【Posted】: 2021-06-03 00:32:39
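【Question】:
I want to use train_test_split to create training, validation, and test sets from my data. According to other posts, the "simple" approach is to run train_test_split twice, which makes sense. But when I try it, the second split raises an error. (sklearn version: 0.23.2) Am I missing something?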
from sklearn.datasets import make_classification
from sklearn.model_selection import train_test_split
df = make_classification()
X = df[0]
y = df[1]
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, stratify=y)
print(X_train.shape, y_train.shape)
X_train, X_val, y_train, y_val = train_test_split(X_train, y_train, test_size=0.2, stratify=y)
Output: (80, 20) (80,)
The error it returns:
---------------------------------------------------------------------------
ValueError Traceback (most recent call last)
<ipython-input-74-bf895d511057> in <module>
6 X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, stratify=y)
7 print(X_train.shape, y_train.shape)
----> 8 X_train, X_val, y_train, y_val = train_test_split(X_train, y_train, test_size=0.2, stratify=y)
~\anaconda3\envs\trading\lib\site-packages\sklearn\model_selection\_split.py in train_test_split(*arrays, **options)
2150 random_state=random_state)
2151
-> 2152 train, test = next(cv.split(X=arrays[0], y=stratify))
2153
2154 return list(chain.from_iterable((_safe_indexing(a, train),
~\anaconda3\envs\trading\lib\site-packages\sklearn\model_selection\_split.py in split(self, X, y, groups)
1338 to an integer.
1339 """
-> 1340 X, y, groups = indexable(X, y, groups)
1341 for train, test in self._iter_indices(X, y, groups):
1342 yield train, test
~\anaconda3\envs\trading\lib\site-packages\sklearn\utils\validation.py in indexable(*iterables)
290 """
291 result = [_make_indexable(X) for X in iterables]
--> 292 check_consistent_length(*result)
293 return result
294
~\anaconda3\envs\trading\lib\site-packages\sklearn\utils\validation.py in check_consistent_length(*arrays)
254 if len(uniques) > 1:
255 raise ValueError("Found input variables with inconsistent numbers of"
--> 256 " samples: %r" % [int(l) for l in lengths])
257
258
ValueError: Found input variables with inconsistent numbers of samples: [80, 100]
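The numbers in the message, [80, 100], suggest where the mismatch comes from: after the first split X_train has 80 rows, but the second call still passes the original 100-element y as stratify. A minimal sketch of the two-step split, assuming the intent is to stratify the second split on the labels of the 80 training rows (y_train), is shown below. The random_state values are added here only to make the sketch reproducible and are not part of the original post.

```python
from sklearn.datasets import make_classification
from sklearn.model_selection import train_test_split

# make_classification returns (X, y); by default 100 samples and 20 features.
X, y = make_classification(random_state=0)

# First split: hold out 20% of all samples as the test set.
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, stratify=y, random_state=0
)

# Second split: carve a validation set out of the 80 training rows.
# stratify must have the same length as the arrays being split,
# so it is y_train (80 labels) here, not the original y (100 labels).
X_train, X_val, y_train, y_val = train_test_split(
    X_train, y_train, test_size=0.2, stratify=y_train, random_state=0
)

print(X_train.shape, X_val.shape, X_test.shape)
```

With these settings the final proportions are 64% train, 16% validation, and 20% test, because the second test_size=0.2 applies to the 80 remaining rows rather than to the full dataset.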
【Discussion】:
Tags: python scikit-learn