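【Question title】: Sklearn train_test_split reporting error when running twice
【Posted】: 2021-06-03 00:32:39
【Question】:

I want to use train_test_split to create training, validation, and test sets from my data. According to other posts, the "easy" way is to run train_test_split twice, which makes sense. But when I try this, the second call to the split raises an error. (Sklearn version: 0.23.2) Am I missing something?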

from sklearn.datasets import make_classification
from sklearn.model_selection import train_test_split

X, y = make_classification()  # returns an (X, y) tuple: 100 samples, 20 features by default

X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, stratify=y)
print(X_train.shape, y_train.shape)
X_train, X_val, y_train, y_val = train_test_split(X_train, y_train, test_size=0.2, stratify=y)  # raises ValueError

Output: (80, 20) (80,)

The error it returns:

---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
<ipython-input-74-bf895d511057> in <module>
      6 X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, stratify=y)
      7 print(X_train.shape, y_train.shape)
----> 8 X_train, X_val, y_train, y_val = train_test_split(X_train, y_train, test_size=0.2, stratify=y)

~\anaconda3\envs\trading\lib\site-packages\sklearn\model_selection\_split.py in train_test_split(*arrays, **options)
   2150                      random_state=random_state)
   2151 
-> 2152         train, test = next(cv.split(X=arrays[0], y=stratify))
   2153 
   2154     return list(chain.from_iterable((_safe_indexing(a, train),

~\anaconda3\envs\trading\lib\site-packages\sklearn\model_selection\_split.py in split(self, X, y, groups)
   1338         to an integer.
   1339         """
-> 1340         X, y, groups = indexable(X, y, groups)
   1341         for train, test in self._iter_indices(X, y, groups):
   1342             yield train, test

~\anaconda3\envs\trading\lib\site-packages\sklearn\utils\validation.py in indexable(*iterables)
    290     """
    291     result = [_make_indexable(X) for X in iterables]
--> 292     check_consistent_length(*result)
    293     return result
    294 

~\anaconda3\envs\trading\lib\site-packages\sklearn\utils\validation.py in check_consistent_length(*arrays)
    254     if len(uniques) > 1:
    255         raise ValueError("Found input variables with inconsistent numbers of"
--> 256                          " samples: %r" % [int(l) for l in lengths])
    257 
    258 

ValueError: Found input variables with inconsistent numbers of samples: [80, 100]

【Discussion】:

    Tags: python scikit-learn


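    【Solution 1】:

    The problem is the stratify argument. In the second call you pass stratify=y, but you need stratify=y_train. After the first split, X_train and y_train hold 80 samples while y still holds all 100, so the lengths no longer match and scikit-learn raises the "inconsistent numbers of samples: [80, 100]" error. Try the following code: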

    X_train, X_val, y_train, y_val = train_test_split(X_train, y_train, test_size=0.2, stratify=y_train)
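
For reference, here is a minimal end-to-end sketch of the two-step split with the fix applied. It reuses the data from the question; the `random_state=0` arguments are not in the original post and are an assumption added only to make the run reproducible.

```python
from sklearn.datasets import make_classification
from sklearn.model_selection import train_test_split

# Same synthetic data as in the question: 100 samples, 20 features by default.
# random_state is an assumption added here for reproducibility.
X, y = make_classification(random_state=0)

# First split: hold out 20% as the test set, stratified on the full label array.
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, stratify=y, random_state=0
)

# Second split: stratify on y_train, which has the same length as X_train (80).
X_train, X_val, y_train, y_val = train_test_split(
    X_train, y_train, test_size=0.2, stratify=y_train, random_state=0
)

print(X_train.shape, X_val.shape, X_test.shape)
# (64, 20) (16, 20) (20, 20)
```

Note that the second test_size=0.2 applies to the remaining 80 samples, so the final proportions are 64% train, 16% validation, and 20% test.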
    

    【Comments】:

    • Oh my god, of course. For some reason I was reading stratify=y as "yes, I want to stratify".