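【Question title】: Is there a way to solve this error concerning StratifiedShuffleSplit?
【Posted】: 2021-06-05 10:29:30
【Question description】:

I'm new to ML and have been working through the Udacity ML project. I've run into an error that I'm struggling to resolve. The code looks fine, but I can't seem to iterate over the data. I know this is related to the recent changes to StratifiedShuffleSplit. The code fails.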

def Stratified_Shuffle_Split(X,y,num_test):
    sss = StratifiedShuffleSplit(y, 1, test_size=num_test, random_state = None)
    for train, test in sss:
        X_train, X_test = X.iloc[train], X.iloc[test]
        y_train, y_test = y.iloc[train], y.iloc[test]
    return X_train, X_test, y_train, y_test

# First, decide how many training vs test samples you want
num_all = student_data.shape[0]  # same as len(student_data)
num_train = round(num_all*0.75)  # about 75% of the data
num_test = num_all - num_train
#print(num_test)

y = student_data['passed'] # identify target variable
X_train, X_test, y_train, y_test = Stratified_Shuffle_Split(X_all, y, num_test)

print("Training Set: {0:.2f} Samples".format(X_train.shape[0]))
print("Testing Set: {0:.2f} Samples".format(X_test.shape[0]))

This is the error I get:

---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
<ipython-input-20-2147158fcaf2> in <module>
     13 
     14 y = student_data['passed'] # identify target variable
---> 15 X_train, X_test, y_train, y_test = Stratified_Shuffle_Split(X_all, y, num_test)
     16 
     17 print("Training Set: {0:.2f} Samples".format(X_train.shape[0]))

<ipython-input-20-2147158fcaf2> in Stratified_Shuffle_Split(X, y, num_test)
      1 def Stratified_Shuffle_Split(X,y,num_test):
      2     sss = StratifiedShuffleSplit(y, 1, test_size=num_test, random_state = None)
----> 3     for train, test in sss:
      4         X_train, X_test = X.iloc[train], X.iloc[test]
      5         y_train, y_test = y.iloc[train], y.iloc[test]

TypeError: 'StratifiedShuffleSplit' object is not iterable

【Comments on the question】:

    Tags: python scikit-learn train-test-split


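    【Solution 1】:

    According to the documentation, you need to call the .split() method on the StratifiedShuffleSplit object. .split() is what generates the indices you are trying to slice with. So this part could be: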

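    from sklearn.model_selection import StratifiedShuffleSplit

    def Stratified_Shuffle_Split(X, y, num_test):
        # The constructor takes only the split configuration; y is no longer passed here
        sss = StratifiedShuffleSplit(n_splits=1, test_size=num_test, random_state=None)
        # split() yields (train, test) arrays of positional indices
        for train, test in sss.split(X, y):
            X_train, X_test = X.iloc[train], X.iloc[test]
            y_train, y_test = y.iloc[train], y.iloc[test]
        return X_train, X_test, y_train, y_test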
    

    I'm also not sure you need to define a new function at all: StratifiedShuffleSplit is already a ready-made splitter, and the for loop you have is enough to do what you want.
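
To illustrate the point about not needing a wrapper, here is a minimal sketch of the split done directly. It assumes a recent scikit-learn, where the splitters live in `sklearn.model_selection`. The small DataFrame is a made-up stand-in for `student_data` (the column names `absences` and `age` are hypothetical), and `random_state=0` is only there to make the run repeatable. The second option, `train_test_split` with `stratify=y`, is an alternative not mentioned in the thread that should give an equivalent single stratified split.

```python
import numpy as np
import pandas as pd
from sklearn.model_selection import StratifiedShuffleSplit, train_test_split

# Hypothetical stand-in for student_data: 100 rows, 70 "yes" / 30 "no"
rng = np.random.default_rng(0)
X_all = pd.DataFrame({
    "absences": rng.integers(0, 30, size=100),
    "age": rng.integers(15, 22, size=100),
})
y = pd.Series(["yes"] * 70 + ["no"] * 30, name="passed")

num_test = 25  # an int test_size is an absolute number of test samples

# Option 1: configuration goes to the constructor, data goes to split()
sss = StratifiedShuffleSplit(n_splits=1, test_size=num_test, random_state=0)
train_idx, test_idx = next(sss.split(X_all, y))  # only one split was requested
X_train, X_test = X_all.iloc[train_idx], X_all.iloc[test_idx]
y_train, y_test = y.iloc[train_idx], y.iloc[test_idx]

# Option 2: a single stratified split in one call
X_train2, X_test2, y_train2, y_test2 = train_test_split(
    X_all, y, test_size=num_test, stratify=y, random_state=0
)

print("Training set: {} samples".format(X_train.shape[0]))
print("Testing set: {} samples".format(X_test.shape[0]))
```

Using `next(...)` instead of a for loop makes it explicit that only one train/test split is consumed; with `n_splits=1` the loop in the answer above does the same thing.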

    【Comments】:
