问题描述
我正在使用mlr3软件包,并且想绘制不同模型的ROC曲线。如果我按照documentation中所述使用交叉验证,则效果很好,但是如果我使用“ holdout”进行重采样,则会收到错误Error: Invalid show_cb. Inconsistent with calc_avg of evalmod.
。
代码如下:
library("mlr3")
library("mlr3learners")
library("mlr3viz")
# one task only
tasks = lapply(c("german_credit"),tsk)
# get some learners and for all learners ...
# * predict probabilities
# * predict also on the training set
learners = c("classif.featureless","classif.rpart","classif.ranger","classif.kknn")
learners = lapply(learners,lrn,predict_type = "prob")
# compare via 3-fold cross validation
resamplings = rsmp("holdout",ratio = .8) # holdout instead of cv
# create a BenchmarkDesign object
design = benchmark_grid(tasks,learners,resamplings)
print(design)
bmr = benchmark(design)
autoplot(bmr,type = "roc")
感谢您的帮助, 马修(Mathieu)
解决方法
如果其他人遇到相同的问题,这里是一个解决方案。发生问题是因为参数calc_avg
在TRUE
中默认设置为precrec::evalmod()
,并且该函数按原样在mlr3viz::autoplot()
中使用。由于as_precrec()
返回的对象没有不同的dsid(在交叉验证的情况下,来自不同折叠的不同值,且具有保留,只有一个元素),因此无法对precrec
进行平均,因此错误(尽管从理论上讲可以)。
这是一段代码,可用于绘制具有保持状态(或任何其他类型的重采样)的ROC曲线。使用答案中的代码,我们可以执行以下操作:
roc_data <- evalmod(as_precrec(bmr),mode = "rocprc",calc_avg = FALSE) %>% # setting calc_avg to FALSE is critical
fortify() %>% # precrec objects have a fortify generic function
.[.$curvetype == "ROC",] # both roc and prc are returned
# Tracer les courbes
ggplot(
data = roc_data,mapping = aes(x = x,y = y,color = modname)
) +
geom_line()
此代码还具有成为ggplot
对象的优点,因此可以轻松地使用ggplot2
进行修改,而precrec::autoplot()
则不是这样。