GLMNET 调参在所有超参数组合下输出相同(平坦)的 MAE —— 对结果列的简单变换应在配方之外进行

问题描述

我想知道下面这个 reprex 里的配方(recipe)哪里出了问题:无论超参数取什么值,得到的 MAE 都完全相同(曲线是平的)。我还在学习 tidymodels,可能遗漏了某个要点,所以想请有经验的朋友帮忙看看。谢谢!

# Attach the tidymodels meta-package (recipes, parsnip, tune, workflows,
# rsample, yardstick, dials, ...); its startup banner is reproduced below.
library(tidymodels)
#> ── Attaching packages ──────────────────────────── tidymodels 0.1.1 ──
#> ✓ broom     0.7.0      ✓ recipes   0.1.13
#> ✓ dials     0.0.8      ✓ rsample   0.0.7 
#> ✓ dplyr     1.0.2      ✓ tibble    3.0.3 
#> ✓ ggplot2   3.3.2      ✓ tidyr     1.1.2 
#> ✓ infer     0.5.3      ✓ tune      0.1.1 
#> ✓ modeldata 0.0.2      ✓ workflows 0.1.3 
#> ✓ parsnip   0.1.3      ✓ yardstick 0.0.7 
#> ✓ purrr     0.3.4
#> ── Conflicts ─────────────────────────────── tidymodels_conflicts() ──
#> x purrr::discard() masks scales::discard()
#> x dplyr::filter()  masks stats::filter()
#> x dplyr::lag()     masks stats::lag()
#> x recipes::step()  masks stats::step()

# Reproducible 80/20 train/test split of the built-in mtcars data.
data(mtcars)

set.seed(101)
clean_split <- initial_split(mtcars, prop = 0.8)

train <- training(clean_split)
test <- testing(clean_split)


# Preprocessing recipe: mpg is the outcome, every other column a predictor.
# NOTE(review): the first step is what makes the MAE flat. With
# `skip = TRUE` the log10 transform of the outcome is applied while the
# recipe is prepped on each analysis set, but it is skipped when the
# assessment set is baked. The model is therefore fit on log10(mpg) yet
# scored against raw mpg, so every grid point gets roughly the same large
# MAE. Transform the outcome before the split instead.
recipe <- recipe(mpg ~ ., data = train) %>%
  step_log(all_outcomes(), base = 10, skip = TRUE) %>%
  step_center(all_predictors(), -all_nominal()) %>%
  step_scale(all_predictors(), -all_nominal()) %>%
  step_other(all_nominal(), threshold = 0.01) %>%
  step_dummy(all_nominal()) %>%
  step_zv(all_predictors(), -all_outcomes())

### Creating models
#### GLMNET
# Penalized linear regression fit with glmnet. Both the penalty and the
# mixture are tune() placeholders, filled in later by tune_grid().
glmnet_model <- linear_reg(
  mode = "regression",
  penalty = tune(),
  mixture = tune()
) %>%
  set_engine("glmnet")

glmnet_model
#> Linear Regression Model Specification (regression)
#> 
#> Main Arguments:
#>   penalty = tune()
#>   mixture = tune()
#> 
#> computational engine: glmnet

# dials parameter set describing the two tuned glmnet arguments.
glmnet_params <- parameters(
  penalty(),
  mixture()
)
glmnet_params
#> Collection of 2 parameters for tuning
#> 
#>       id parameter type object class
#>  penalty        penalty    nparam[+]
#>  mixture        mixture    nparam[+]

# Space-filling (maximum entropy) design of 20 penalty/mixture candidates;
# the seed makes the design reproducible.
set.seed(123)
glmnet_grid <- grid_max_entropy(
  glmnet_params,
  size = 20
)
glmnet_grid
#> # A tibble: 20 x 2
#>     penalty mixture
#>       <dbl>   <dbl>
#>  1 2.94e- 1  0.702 
#>  2 1.48e- 4  0.996 
#>  3 1.60e- 1  0.444 
#>  4 5.86e- 1  0.975 
#>  5 1.69e- 9  0.0491
#>  6 1.10e- 5  0.699 
#>  7 2.76e- 2  0.988 
#>  8 4.95e- 8  0.753 
#>  9 1.07e- 5  0.382 
#> 10 7.87e- 8  0.331 
#> 11 4.07e- 1  0.180 
#> 12 1.70e- 3  0.590 
#> 13 2.52e-10  0.382 
#> 14 2.47e-10  0.666 
#> 15 2.31e- 9  0.921 
#> 16 1.31e- 7  0.546 
#> 17 1.49e- 6  0.973 
#> 18 1.28e- 3  0.0224
#> 19 7.49e- 7  0.0747
#> 20 2.37e- 3  0.351


### Creating a workflow for each model
# Bundle the preprocessing recipe and the glmnet specification together.
glmnet_wfl <- workflow() %>%
  add_recipe(recipe) %>%
  add_model(glmnet_model)

### Creating cross-validation splits
# 5-fold cross-validation drawn from the training set only.
cv_splits <- vfold_cv(train, v = 5)
cv_splits
#> #  5-fold cross-validation 
#> # A tibble: 5 x 2
#>   splits         id   
#>   <list>         <chr>
#> 1 <split [20/6]> Fold1
#> 2 <split [21/5]> Fold2
#> 3 <split [21/5]> Fold3
#> 4 <split [21/5]> Fold4
#> 5 <split [21/5]> Fold5


### Training and tuning models
#### GLMNET
# Fit every grid candidate on every fold and collect MAE, MAPE, RMSE and R^2.
# verbose = TRUE prints the per-fold progress log reproduced below.
glmnet_stage_1_cv_results_tbl <- tune_grid(
  glmnet_wfl,
  resamples = cv_splits,
  grid = glmnet_grid,
  metrics = metric_set(mae, mape, rmse, rsq),
  control = control_grid(verbose = TRUE)
)
#> i Fold1: recipe
#> ✓ Fold1: recipe
#> i Fold1: model  1/20
#> ✓ Fold1: model  1/20
#> i Fold1: model  1/20 (predictions)
#> i Fold1: model  2/20
#> ✓ Fold1: model  2/20
#> i Fold1: model  2/20 (predictions)
#> i Fold1: model  3/20
#> ✓ Fold1: model  3/20
#> i Fold1: model  3/20 (predictions)
#> i Fold1: model  4/20
#> ✓ Fold1: model  4/20
#> i Fold1: model  4/20 (predictions)
#> i Fold1: model  5/20
#> ✓ Fold1: model  5/20
#> i Fold1: model  5/20 (predictions)
#> i Fold1: model  6/20
#> ✓ Fold1: model  6/20
#> i Fold1: model  6/20 (predictions)
#> i Fold1: model  7/20
#> ✓ Fold1: model  7/20
#> i Fold1: model  7/20 (predictions)
#> i Fold1: model  8/20
#> ✓ Fold1: model  8/20
#> i Fold1: model  8/20 (predictions)
#> i Fold1: model  9/20
#> ✓ Fold1: model  9/20
#> i Fold1: model  9/20 (predictions)
#> i Fold1: model 10/20
#> ✓ Fold1: model 10/20
#> i Fold1: model 10/20 (predictions)
#> i Fold1: model 11/20
#> ✓ Fold1: model 11/20
#> i Fold1: model 11/20 (predictions)
#> i Fold1: model 12/20
#> ✓ Fold1: model 12/20
#> i Fold1: model 12/20 (predictions)
#> i Fold1: model 13/20
#> ✓ Fold1: model 13/20
#> i Fold1: model 13/20 (predictions)
#> i Fold1: model 14/20
#> ✓ Fold1: model 14/20
#> i Fold1: model 14/20 (predictions)
#> i Fold1: model 15/20
#> ✓ Fold1: model 15/20
#> i Fold1: model 15/20 (predictions)
#> i Fold1: model 16/20
#> ✓ Fold1: model 16/20
#> i Fold1: model 16/20 (predictions)
#> i Fold1: model 17/20
#> ✓ Fold1: model 17/20
#> i Fold1: model 17/20 (predictions)
#> i Fold1: model 18/20
#> ✓ Fold1: model 18/20
#> i Fold1: model 18/20 (predictions)
#> i Fold1: model 19/20
#> ✓ Fold1: model 19/20
#> i Fold1: model 19/20 (predictions)
#> i Fold1: model 20/20
#> ✓ Fold1: model 20/20
#> i Fold1: model 20/20 (predictions)
#> ! Fold1: internal: A correlation computation is required,but `estimate` is const...
#> ✓ Fold1: internal
#> i Fold2: recipe
#> ✓ Fold2: recipe
#> i Fold2: model  1/20
#> ✓ Fold2: model  1/20
#> i Fold2: model  1/20 (predictions)
#> i Fold2: model  2/20
#> ✓ Fold2: model  2/20
#> i Fold2: model  2/20 (predictions)
#> i Fold2: model  3/20
#> ✓ Fold2: model  3/20
#> i Fold2: model  3/20 (predictions)
#> i Fold2: model  4/20
#> ✓ Fold2: model  4/20
#> i Fold2: model  4/20 (predictions)
#> i Fold2: model  5/20
#> ✓ Fold2: model  5/20
#> i Fold2: model  5/20 (predictions)
#> i Fold2: model  6/20
#> ✓ Fold2: model  6/20
#> i Fold2: model  6/20 (predictions)
#> i Fold2: model  7/20
#> ✓ Fold2: model  7/20
#> i Fold2: model  7/20 (predictions)
#> i Fold2: model  8/20
#> ✓ Fold2: model  8/20
#> i Fold2: model  8/20 (predictions)
#> i Fold2: model  9/20
#> ✓ Fold2: model  9/20
#> i Fold2: model  9/20 (predictions)
#> i Fold2: model 10/20
#> ✓ Fold2: model 10/20
#> i Fold2: model 10/20 (predictions)
#> i Fold2: model 11/20
#> ✓ Fold2: model 11/20
#> i Fold2: model 11/20 (predictions)
#> i Fold2: model 12/20
#> ✓ Fold2: model 12/20
#> i Fold2: model 12/20 (predictions)
#> i Fold2: model 13/20
#> ✓ Fold2: model 13/20
#> i Fold2: model 13/20 (predictions)
#> i Fold2: model 14/20
#> ✓ Fold2: model 14/20
#> i Fold2: model 14/20 (predictions)
#> i Fold2: model 15/20
#> ✓ Fold2: model 15/20
#> i Fold2: model 15/20 (predictions)
#> i Fold2: model 16/20
#> ✓ Fold2: model 16/20
#> i Fold2: model 16/20 (predictions)
#> i Fold2: model 17/20
#> ✓ Fold2: model 17/20
#> i Fold2: model 17/20 (predictions)
#> i Fold2: model 18/20
#> ✓ Fold2: model 18/20
#> i Fold2: model 18/20 (predictions)
#> i Fold2: model 19/20
#> ✓ Fold2: model 19/20
#> i Fold2: model 19/20 (predictions)
#> i Fold2: model 20/20
#> ✓ Fold2: model 20/20
#> i Fold2: model 20/20 (predictions)
#> ! Fold2: internal: A correlation computation is required,but `estimate` is const...
#> ✓ Fold2: internal
#> i Fold3: recipe
#> ✓ Fold3: recipe
#> i Fold3: model  1/20
#> ✓ Fold3: model  1/20
#> i Fold3: model  1/20 (predictions)
#> i Fold3: model  2/20
#> ✓ Fold3: model  2/20
#> i Fold3: model  2/20 (predictions)
#> i Fold3: model  3/20
#> ✓ Fold3: model  3/20
#> i Fold3: model  3/20 (predictions)
#> i Fold3: model  4/20
#> ✓ Fold3: model  4/20
#> i Fold3: model  4/20 (predictions)
#> i Fold3: model  5/20
#> ✓ Fold3: model  5/20
#> i Fold3: model  5/20 (predictions)
#> i Fold3: model  6/20
#> ✓ Fold3: model  6/20
#> i Fold3: model  6/20 (predictions)
#> i Fold3: model  7/20
#> ✓ Fold3: model  7/20
#> i Fold3: model  7/20 (predictions)
#> i Fold3: model  8/20
#> ✓ Fold3: model  8/20
#> i Fold3: model  8/20 (predictions)
#> i Fold3: model  9/20
#> ✓ Fold3: model  9/20
#> i Fold3: model  9/20 (predictions)
#> i Fold3: model 10/20
#> ✓ Fold3: model 10/20
#> i Fold3: model 10/20 (predictions)
#> i Fold3: model 11/20
#> ✓ Fold3: model 11/20
#> i Fold3: model 11/20 (predictions)
#> i Fold3: model 12/20
#> ✓ Fold3: model 12/20
#> i Fold3: model 12/20 (predictions)
#> i Fold3: model 13/20
#> ✓ Fold3: model 13/20
#> i Fold3: model 13/20 (predictions)
#> i Fold3: model 14/20
#> ✓ Fold3: model 14/20
#> i Fold3: model 14/20 (predictions)
#> i Fold3: model 15/20
#> ✓ Fold3: model 15/20
#> i Fold3: model 15/20 (predictions)
#> i Fold3: model 16/20
#> ✓ Fold3: model 16/20
#> i Fold3: model 16/20 (predictions)
#> i Fold3: model 17/20
#> ✓ Fold3: model 17/20
#> i Fold3: model 17/20 (predictions)
#> i Fold3: model 18/20
#> ✓ Fold3: model 18/20
#> i Fold3: model 18/20 (predictions)
#> i Fold3: model 19/20
#> ✓ Fold3: model 19/20
#> i Fold3: model 19/20 (predictions)
#> i Fold3: model 20/20
#> ✓ Fold3: model 20/20
#> i Fold3: model 20/20 (predictions)
#> ! Fold3: internal: A correlation computation is required,but `estimate` is const...
#> ✓ Fold3: internal
#> i Fold4: recipe
#> ✓ Fold4: recipe
#> i Fold4: model  1/20
#> ✓ Fold4: model  1/20
#> i Fold4: model  1/20 (predictions)
#> i Fold4: model  2/20
#> ✓ Fold4: model  2/20
#> i Fold4: model  2/20 (predictions)
#> i Fold4: model  3/20
#> ✓ Fold4: model  3/20
#> i Fold4: model  3/20 (predictions)
#> i Fold4: model  4/20
#> ✓ Fold4: model  4/20
#> i Fold4: model  4/20 (predictions)
#> i Fold4: model  5/20
#> ✓ Fold4: model  5/20
#> i Fold4: model  5/20 (predictions)
#> i Fold4: model  6/20
#> ✓ Fold4: model  6/20
#> i Fold4: model  6/20 (predictions)
#> i Fold4: model  7/20
#> ✓ Fold4: model  7/20
#> i Fold4: model  7/20 (predictions)
#> i Fold4: model  8/20
#> ✓ Fold4: model  8/20
#> i Fold4: model  8/20 (predictions)
#> i Fold4: model  9/20
#> ✓ Fold4: model  9/20
#> i Fold4: model  9/20 (predictions)
#> i Fold4: model 10/20
#> ✓ Fold4: model 10/20
#> i Fold4: model 10/20 (predictions)
#> i Fold4: model 11/20
#> ✓ Fold4: model 11/20
#> i Fold4: model 11/20 (predictions)
#> i Fold4: model 12/20
#> ✓ Fold4: model 12/20
#> i Fold4: model 12/20 (predictions)
#> i Fold4: model 13/20
#> ✓ Fold4: model 13/20
#> i Fold4: model 13/20 (predictions)
#> i Fold4: model 14/20
#> ✓ Fold4: model 14/20
#> i Fold4: model 14/20 (predictions)
#> i Fold4: model 15/20
#> ✓ Fold4: model 15/20
#> i Fold4: model 15/20 (predictions)
#> i Fold4: model 16/20
#> ✓ Fold4: model 16/20
#> i Fold4: model 16/20 (predictions)
#> i Fold4: model 17/20
#> ✓ Fold4: model 17/20
#> i Fold4: model 17/20 (predictions)
#> i Fold4: model 18/20
#> ✓ Fold4: model 18/20
#> i Fold4: model 18/20 (predictions)
#> i Fold4: model 19/20
#> ✓ Fold4: model 19/20
#> i Fold4: model 19/20 (predictions)
#> i Fold4: model 20/20
#> ✓ Fold4: model 20/20
#> i Fold4: model 20/20 (predictions)
#> ! Fold4: internal: A correlation computation is required,but `estimate` is const...
#> ✓ Fold4: internal
#> i Fold5: recipe
#> ✓ Fold5: recipe
#> i Fold5: model  1/20
#> ✓ Fold5: model  1/20
#> i Fold5: model  1/20 (predictions)
#> i Fold5: model  2/20
#> ✓ Fold5: model  2/20
#> i Fold5: model  2/20 (predictions)
#> i Fold5: model  3/20
#> ✓ Fold5: model  3/20
#> i Fold5: model  3/20 (predictions)
#> i Fold5: model  4/20
#> ✓ Fold5: model  4/20
#> i Fold5: model  4/20 (predictions)
#> i Fold5: model  5/20
#> ✓ Fold5: model  5/20
#> i Fold5: model  5/20 (predictions)
#> i Fold5: model  6/20
#> ✓ Fold5: model  6/20
#> i Fold5: model  6/20 (predictions)
#> i Fold5: model  7/20
#> ✓ Fold5: model  7/20
#> i Fold5: model  7/20 (predictions)
#> i Fold5: model  8/20
#> ✓ Fold5: model  8/20
#> i Fold5: model  8/20 (predictions)
#> i Fold5: model  9/20
#> ✓ Fold5: model  9/20
#> i Fold5: model  9/20 (predictions)
#> i Fold5: model 10/20
#> ✓ Fold5: model 10/20
#> i Fold5: model 10/20 (predictions)
#> i Fold5: model 11/20
#> ✓ Fold5: model 11/20
#> i Fold5: model 11/20 (predictions)
#> i Fold5: model 12/20
#> ✓ Fold5: model 12/20
#> i Fold5: model 12/20 (predictions)
#> i Fold5: model 13/20
#> ✓ Fold5: model 13/20
#> i Fold5: model 13/20 (predictions)
#> i Fold5: model 14/20
#> ✓ Fold5: model 14/20
#> i Fold5: model 14/20 (predictions)
#> i Fold5: model 15/20
#> ✓ Fold5: model 15/20
#> i Fold5: model 15/20 (predictions)
#> i Fold5: model 16/20
#> ✓ Fold5: model 16/20
#> i Fold5: model 16/20 (predictions)
#> i Fold5: model 17/20
#> ✓ Fold5: model 17/20
#> i Fold5: model 17/20 (predictions)
#> i Fold5: model 18/20
#> ✓ Fold5: model 18/20
#> i Fold5: model 18/20 (predictions)
#> i Fold5: model 19/20
#> ✓ Fold5: model 19/20
#> i Fold5: model 19/20 (predictions)
#> i Fold5: model 20/20
#> ✓ Fold5: model 20/20
#> i Fold5: model 20/20 (predictions)
#> ! Fold5: internal: A correlation computation is required,but `estimate` is const...
#> ✓ Fold5: internal
# Ten best candidates ranked by cross-validated MAE.
show_best(glmnet_stage_1_cv_results_tbl, "mae", n = 10)
#> # A tibble: 10 x 8
#>     penalty mixture .metric .estimator  mean     n std_err .config
#>       <dbl>   <dbl> <chr>   <chr>      <dbl> <int>   <dbl> <chr>  
#>  1 4.95e- 8  0.753  mae     standard    19.1     5    1.56 Model15
#>  2 1.49e- 6  0.973  mae     standard    19.1     5    1.56 Model17
#>  3 2.31e- 9  0.921  mae     standard    19.1     5    1.56 Model16
#>  4 1.10e- 5  0.699  mae     standard    19.1     5    1.56 Model13
#>  5 1.31e- 7  0.546  mae     standard    19.1     5    1.56 Model10
#>  6 2.47e-10  0.666  mae     standard    19.1     5    1.56 Model12
#>  7 7.87e- 8  0.331  mae     standard    19.1     5    1.56 Model05
#>  8 1.07e- 5  0.382  mae     standard    19.1     5    1.56 Model08
#>  9 2.52e-10  0.382  mae     standard    19.1     5    1.56 Model07
#> 10 7.49e- 7  0.0747 mae     standard    19.1     5    1.56 Model03

由 reprex 包(v0.3.0)于 2020-09-04 创建

解决方法

在配方中变换结果变量可能非常棘手。假设您确实想在配方里执行如下操作:

# Illustration only: log-transforming the outcome inside a recipe step
step_log(mpg, base = 10)

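当把配方应用到结果未知的新数据上时,这一步会失败:燃油效率(mpg)正是我们要预测的结果,新数据里通常根本没有这一列。实际上,为避免信息泄漏,tune_grid() 在调参时会把用于训练每个模型的数据与用于评估该模型的数据隔离开,因此在预测阶段无法使用训练集或任何结果列。问题中设置的 skip = TRUE 虽然避免了报错,却引入了另一个问题:对数变换只在 prep 时作用于分析集(训练数据),在对评估集执行 bake 时会被跳过。于是模型预测的是 log10(mpg)(大约 1.3),却被拿来与原始尺度的 mpg(大约 20)比较。MAE 几乎完全由这种尺度不匹配决定(约 19),与 penalty、mixture 的取值无关,这就是 MAE 曲线平坦的原因。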

对于结果列的简单变换,我们强烈建议在配方之外完成,例如在划分数据之前先用 mutate() 处理,如下所示:

library(tidymodels)

# Transform the outcome *before* splitting, so analysis and assessment sets
# alike carry mpg on the log10 scale. Every prediction and metric computed
# from this data is therefore in log10(mpg) units; back-transform with
# 10^.pred to report miles per gallon.
log_car <- mtcars %>%
  mutate(mpg = log10(mpg))

set.seed(101)
clean_split <- initial_split(log_car, prop = 0.8)

train <- training(clean_split)
test <- testing(clean_split)


# Predictor-only preprocessing; the outcome is left untouched by the recipe.
# - step_other()/step_dummy(): pool rare factor levels and dummy-encode
#   nominal predictors (they select nothing if the data has no nominal
#   columns, but keep the recipe reusable).
# - step_zv() runs before step_normalize() so a constant predictor is
#   removed rather than being divided by a zero standard deviation.
#   all_predictors() already excludes the outcome, so no -all_outcomes().
recipe <- recipe(mpg ~ ., data = train) %>%
  step_other(all_nominal(), threshold = 0.01) %>%
  step_dummy(all_nominal()) %>%
  step_zv(all_predictors()) %>%
  step_normalize(all_predictors(), -all_nominal())

# glmnet linear regression with penalty and mixture left for tune_grid().
glmnet_model <- linear_reg(
  mode = "regression",
  penalty = tune(),
  mixture = tune()
) %>%
  set_engine("glmnet")

glmnet_model
#> Linear Regression Model Specification (regression)
#> 
#> Main Arguments:
#>   penalty = tune()
#>   mixture = tune()
#> 
#> Computational engine: glmnet

# 20 space-filling penalty/mixture candidates built directly from the dials
# parameter objects; the seed makes the design reproducible.
set.seed(123)
glmnet_grid <- grid_max_entropy(
  penalty(),
  mixture(),
  size = 20
)
glmnet_grid
#> # A tibble: 20 x 2
#>     penalty mixture
#>       <dbl>   <dbl>
#>  1 2.94e- 1  0.702 
#>  2 1.48e- 4  0.996 
#>  3 1.60e- 1  0.444 
#>  4 5.86e- 1  0.975 
#>  5 1.69e- 9  0.0491
#>  6 1.10e- 5  0.699 
#>  7 2.76e- 2  0.988 
#>  8 4.95e- 8  0.753 
#>  9 1.07e- 5  0.382 
#> 10 7.87e- 8  0.331 
#> 11 4.07e- 1  0.180 
#> 12 1.70e- 3  0.590 
#> 13 2.52e-10  0.382 
#> 14 2.47e-10  0.666 
#> 15 2.31e- 9  0.921 
#> 16 1.31e- 7  0.546 
#> 17 1.49e- 6  0.973 
#> 18 1.28e- 3  0.0224
#> 19 7.49e- 7  0.0747
#> 20 2.37e- 3  0.351

# Recipe and glmnet specification bundled into a single workflow.
glmnet_wfl <- workflow() %>%
  add_recipe(recipe) %>%
  add_model(glmnet_model)


# 5-fold cross-validation drawn from the (log-scale) training set.
cv_splits <- vfold_cv(train, v = 5)
cv_splits
#> #  5-fold cross-validation 
#> # A tibble: 5 x 2
#>   splits         id   
#>   <list>         <chr>
#> 1 <split [20/6]> Fold1
#> 2 <split [21/5]> Fold2
#> 3 <split [21/5]> Fold3
#> 4 <split [21/5]> Fold4
#> 5 <split [21/5]> Fold5

# Tune over the grid. MAE is measured in log10(mpg) units because the
# outcome was transformed before the split.
glmnet_stage_1_cv_results_tbl <- tune_grid(
  glmnet_wfl,
  resamples = cv_splits,
  grid = glmnet_grid,
  metrics = metric_set(mae)
)
#> Loading required package: Matrix
#> 
#> Attaching package: 'Matrix'
#> The following objects are masked from 'package:tidyr':
#> 
#>     expand,pack,unpack
#> Loaded glmnet 4.0-2

# Ten best candidates by cross-validated MAE (log10 scale).
show_best(glmnet_stage_1_cv_results_tbl, "mae", n = 10)
#> # A tibble: 10 x 8
#>     penalty mixture .metric .estimator   mean     n std_err .config
#>       <dbl>   <dbl> <chr>   <chr>       <dbl> <int>   <dbl> <chr>  
#>  1 1.70e- 3  0.590  mae     standard   0.0396     5 0.00583 Model11
#>  2 2.37e- 3  0.351  mae     standard   0.0400     5 0.00597 Model06
#>  3 1.28e- 3  0.0224 mae     standard   0.0433     5 0.00592 Model01
#>  4 1.48e- 4  0.996  mae     standard   0.0462     5 0.00538 Model20
#>  5 2.76e- 2  0.988  mae     standard   0.0465     5 0.00699 Model19
#>  6 1.69e- 9  0.0491 mae     standard   0.0470     5 0.00536 Model02
#>  7 7.49e- 7  0.0747 mae     standard   0.0474     5 0.00539 Model03
#>  8 2.47e-10  0.666  mae     standard   0.0476     5 0.00545 Model12
#>  9 1.10e- 5  0.699  mae     standard   0.0476     5 0.00546 Model13
#> 10 2.31e- 9  0.921  mae     standard   0.0476     5 0.00547 Model16

由 reprex 包(v0.3.0.9001)于 2020-09-06 创建