【问题标题】:How to understand a function which splits the data如何理解拆分数据的函数
【发布时间】:2019-12-10 17:18:31
【问题描述】:

谁能帮我理解这个函数的作用?

我了解行打印,但之后我有点迷路了。从train_data开始。

def stratifiedShuffleSplit_data(X, y):
    sss = StratifiedShuffleSplit(n_splits=5, test_size=0.5, random_state=0)
    for train_index, test_index in sss.split(X, y):
        print("len(TRAIN):", len(train_index), "len(TEST):", len(test_index))
        print("TRAIN:", train_index, "TEST:", test_index)

        train_data = [df.loc[ind] for ind in train_index]
        test_data = [df.loc[ind] for ind in test_index]
        save_datarows(train_data, datafile+".train")
        save_datarows(test_data, datafile+".test")

【问题讨论】:

  • 所以,您的主要疑问是“train_data = [df.loc[ind] for ind in train_index]”这一行,对吧?
  • 是的,最后两个

标签: python scikit-learn training-data k-fold


【解决方案1】:

假设你使用的是 Panda 包,

 pd.DataFrame.loc 

是一种基于位置的索引器 - 这是一个过于简化的版本。我将发布一些资源,可以帮助您更好地理解它。

train_data = [df.loc[ind] for ind in train_index]

在这里,您基本上遍历列表 ind 并存储各自的值 train_data test_data 的情况类似

我假设 save_datarows 是一个自定义函数,用于将 train_data 存储到扩展名为 .train 的文件中

希望这会有所帮助。

这是一个非常好的参考资料,可以进一步澄清:

Selection with .loc in python

https://www.geeksforgeeks.org/python-pandas-dataframe-loc/

【讨论】:

    猜你喜欢
    • 1970-01-01
    • 1970-01-01
    • 2018-01-18
    • 1970-01-01
    • 2021-01-08
    • 2011-03-19
    相关资源
    最近更新 更多