【发布时间】:2021-06-28 00:26:07
【问题描述】:
我使用resamples_fit 和work_flow() 和V-Fold Cross-Validation。
我的模型是逻辑回归。
如何使用 V 折交叉验证获得欧洲防风逻辑回归模型的系数?
如果我的V-Fold Cross-Validation v=5,我想得到5倍系数。
【问题讨论】:
标签: r tidymodels
我使用resamples_fit 和work_flow() 和V-Fold Cross-Validation。
我的模型是逻辑回归。
如何使用 V 折交叉验证获得欧洲防风逻辑回归模型的系数?
如果我的V-Fold Cross-Validation v=5,我想得到5倍系数。
【问题讨论】:
标签: r tidymodels
您通常不想使用fit_resamples() 来训练和保留五个模型; fit_resamples() 函数的主要用途是使用resampling to estimate performance。五个模型合身,然后扔掉。
但是,如果您确实有一些用例想要保留合适的模型,例如in this article,那么您将使用extract_model。
library(tidymodels)
#> Registered S3 method overwritten by 'tune':
#> method from
#> required_pkgs.model_spec parsnip
data(penguins)
set.seed(2021)
penguin_split <- penguins %>%
filter(!is.na(sex)) %>%
initial_split(strata = sex)
penguin_train <- training(penguin_split)
penguin_test <- testing(penguin_split)
penguin_folds <- vfold_cv(penguin_train, v = 5, strata = sex)
penguin_folds
#> # 5-fold cross-validation using stratification
#> # A tibble: 5 x 2
#> splits id
#> <list> <chr>
#> 1 <split [198/51]> Fold1
#> 2 <split [199/50]> Fold2
#> 3 <split [199/50]> Fold3
#> 4 <split [200/49]> Fold4
#> 5 <split [200/49]> Fold5
glm_spec <- logistic_reg() %>%
set_engine("glm")
glm_rs <- workflow() %>%
add_formula(sex ~ species + bill_length_mm + bill_depth_mm + body_mass_g) %>%
add_model(glm_spec) %>%
fit_resamples(
resamples = penguin_folds,
control = control_resamples(extract = extract_model, save_pred = TRUE)
)
既然您已经在重采样中使用了extract_model,它就在您的结果中,并且您拥有可用于每个折叠的模型。
glm_rs
#> # Resampling results
#> # 5-fold cross-validation using stratification
#> # A tibble: 5 x 6
#> splits id .metrics .notes .extracts .predictions
#> <list> <chr> <list> <list> <list> <list>
#> 1 <split [198/… Fold1 <tibble [2 × … <tibble [0 ×… <tibble [1 ×… <tibble [51 × …
#> 2 <split [199/… Fold2 <tibble [2 × … <tibble [0 ×… <tibble [1 ×… <tibble [50 × …
#> 3 <split [199/… Fold3 <tibble [2 × … <tibble [0 ×… <tibble [1 ×… <tibble [50 × …
#> 4 <split [200/… Fold4 <tibble [2 × … <tibble [0 ×… <tibble [1 ×… <tibble [49 × …
#> 5 <split [200/… Fold5 <tibble [2 × … <tibble [0 ×… <tibble [1 ×… <tibble [49 × …
glm_rs$.extracts[[1]]
#> # A tibble: 1 x 2
#> .extracts .config
#> <list> <chr>
#> 1 <glm> Preprocessor1_Model1
您可以使用tidyr 和broom 函数来获取系数,如果这是您正在寻找的。p>
glm_rs %>%
dplyr::select(id, .extracts) %>%
unnest(cols = .extracts) %>%
mutate(tidied = map(.extracts, tidy)) %>%
unnest(tidied)
#> # A tibble: 30 x 8
#> id .extracts .config term estimate std.error statistic p.value
#> <chr> <list> <chr> <chr> <dbl> <dbl> <dbl> <dbl>
#> 1 Fold1 <glm> Preprocesso… (Interce… -7.44e+1 12.6 -5.89 3.75e-9
#> 2 Fold1 <glm> Preprocesso… speciesC… -6.59e+0 1.82 -3.61 3.03e-4
#> 3 Fold1 <glm> Preprocesso… speciesG… -7.49e+0 2.54 -2.95 3.18e-3
#> 4 Fold1 <glm> Preprocesso… bill_len… 5.56e-1 0.151 3.67 2.40e-4
#> 5 Fold1 <glm> Preprocesso… bill_dep… 1.72e+0 0.424 4.06 4.83e-5
#> 6 Fold1 <glm> Preprocesso… body_mas… 5.88e-3 0.00130 4.51 6.44e-6
#> 7 Fold2 <glm> Preprocesso… (Interce… -6.87e+1 11.3 -6.06 1.37e-9
#> 8 Fold2 <glm> Preprocesso… speciesC… -5.59e+0 1.75 -3.20 1.39e-3
#> 9 Fold2 <glm> Preprocesso… speciesG… -7.61e+0 2.80 -2.71 6.65e-3
#> 10 Fold2 <glm> Preprocesso… bill_len… 4.88e-1 0.145 3.36 7.88e-4
#> # … with 20 more rows
由reprex package (v2.0.0) 于 2021-06-27 创建
【讨论】: