mlr3:如何使用mlr过滤训练数据集并将结果应用于模型训练?

问题描述

在mlr3中创建过滤器时,如何仅基于训练数据创建过滤器?

一旦创建了过滤器,您如何将其应用于建模过程,并对训练数据进行子集处理,以仅包括超过特定阈值的过滤器值?

library(mlr3)
library(mlr3filters)
library(mlr3learners)
library(tidyverse)


data(iris)
iris <- iris %>%
  select(-Species)
  
tsk <- mlr3::TaskRegr$new("iris",backend = iris,target = "Sepal.Length")

#split train and test
trn_ids <- sample(tsk$row_ids,floor(0.8 * length(tsk$row_ids)),F)
tst_ids <- setdiff(tsk$row_ids,trn_ids)

#create a filter
filter = flt("correlation",method = "spearman")

# Question 1: how to calculate the filter only for the train IDs?
filter$calculate(tsk)
print(filter)

# Question 2: how to only use only variables with X correlation or greater in training?
learner <- mlr_learners$get("regr.glmnet")
learner$train(tsk,row_ids = trn_ids)
prediction <- learner$predict(tsk,row_ids = tst_ids)
prediction$response

解决方法

可以使用mlr3pipelines将过滤器包装到学习器中。

mlr3画廊有一个示例here(“功能过滤”部分)。

基本方法是创建一个像这样的图形:

fpipe = po("filter",flt("mim"),filter.nfeat = 3) $>>$ lrn("regr.glmnet")

并将其包装在GraphLearner中:

lrnr = GraphLearner$new(fpipe)

lrnr现在可以像其他学习者一样使用,并且可以在训练学习者之前根据指定的过滤器在内部过滤功能。

相关问答

Selenium Web驱动程序和Java。元素在(x,y)点处不可单击。其...
Python-如何使用点“。” 访问字典成员?
Java 字符串是不可变的。到底是什么意思?
Java中的“ final”关键字如何工作?(我仍然可以修改对象。...
“loop:”在Java代码中。这是什么,为什么要编译?
java.lang.ClassNotFoundException:sun.jdbc.odbc.JdbcOdbc...