使用mlr3 PipeOps创建具有不同数据子集的分支

问题描述

我想使用mlr3在不同数据子集上训练模型,而我想知道是否有一种方法可以在管道中对不同数据子集训练模型。

我想做的事情与R for Data Science - Chapter 25: Many models中的示例相似。假设我们使用相同的数据集gapminder,该数据集包含世界各国的不同变量,例如GDP和预期寿命。如果我想训练每个国家的预期寿命模型,是否有一种简单的方法可以使用mlr3创建这样的渠道?

理想情况下,我想使用mlr3pipelines在图形中为每个子集(例如,每个国家/地区有一个单独的分支)创建一个分支,并在最后添加模型。因此,最终图将在单个节点上开始,并在末端节点上有n个受过训练的学习者,数据集中每个组(即国家/地区)一个,或者是一个汇总结果的最终节点。我也希望它能用于新数据,例如,如果我们在2020年将来获得新数据,我希望它能够使用针对特定国家/地区训练的模型为每个国家/地区创建预测。

我发现的所有mlr3示例似乎都涉及整个数据集的模型,或者对模型进行了训练集中的所有组的训练。

当前,我只是为每组数据手动创建一个单独的任务,但是将数据子集步骤合并到建模管道中会很好。

解决方法

如果您具有以下两个包中的功能,这将有所帮助:dplyrtidyr。以下代码显示了如何按国家/地区训练多个模型:

library(dplyr)
library(tidyr)

df <- gapminder::gapminder

by_country <- 
  df %>% 
  nest(data = -c(continent,country)) %>% 
  mutate(model = lapply(data,learn))

请注意,learn是将单个数据帧作为其输入的函数。稍后我将向您展示如何定义该函数。现在您需要知道从该管道返回的数据帧如下:

# A tibble: 142 x 4
   country     continent data              model     
   <fct>       <fct>     <list>            <list>    
 1 Afghanistan Asia      <tibble [12 x 4]> <LrnrRgrR>
 2 Albania     Europe    <tibble [12 x 4]> <LrnrRgrR>
 3 Algeria     Africa    <tibble [12 x 4]> <LrnrRgrR>
 4 Angola      Africa    <tibble [12 x 4]> <LrnrRgrR>
 5 Argentina   Americas  <tibble [12 x 4]> <LrnrRgrR>
 6 Australia   Oceania   <tibble [12 x 4]> <LrnrRgrR>
 7 Austria     Europe    <tibble [12 x 4]> <LrnrRgrR>
 8 Bahrain     Asia      <tibble [12 x 4]> <LrnrRgrR>
 9 Bangladesh  Asia      <tibble [12 x 4]> <LrnrRgrR>
10 Belgium     Europe    <tibble [12 x 4]> <LrnrRgrR>

要定义learn函数,请按照mlr3网站上提供的步骤进行操作。功能是

learn <- function(df) {
  # I create a regression task as the target `lifeExp` is a numeric variable.
  task <- mlr3::TaskRegr$new(id = "gapminder",backend = df,target = "lifeExp")
  # define the learner you want to use.
  learner <- mlr3::lrn("regr.rpart")
  # train your dataset and return the trained model as an output
  learner$train(task)
}

我希望这可以解决您的问题。

请考虑以下步骤来训练模型并预测每个国家/地区的结果。

create_task <- function(id,df,ratio) {
  train <- sample(nrow(df),ratio * nrow(df))
  task <- mlr3::TaskRegr$new(id = as.character(id),target = "lifeExp")
  list(task = task,train = train,test = seq_len(nrow(df))[-train])
}

model_task <- function(learner,task_list) {
  learner$train(task_list[["task"]],row_ids = task_list[["train"]])
}

predict_result <- function(learner,task_list) {
  learner$predict(task_list[["task"]],row_ids = task_list[["test"]])
}

by_country <- 
  df %>% 
  nest(data = -c(continent,country)) %>% 
  mutate(
    task_list = Map(create_task,country,data,0.8),learner = list(mlr3::lrn("regr.rpart"))
  ) %>% 
  within({
    Map(model_task,learner,task_list)
    prediction <- Map(predict_result,task_list)
  })

相关问答

Selenium Web驱动程序和Java。元素在(x,y)点处不可单击。其...
Python-如何使用点“。” 访问字典成员?
Java 字符串是不可变的。到底是什么意思?
Java中的“ final”关键字如何工作?(我仍然可以修改对象。...
“loop:”在Java代码中。这是什么,为什么要编译?
java.lang.ClassNotFoundException:sun.jdbc.odbc.JdbcOdbc...