在带有嵌套估算器的管道中使用GridSearchCV

问题描述

我尝试使用管道构建这样的模型:我想使用随机的Forst分类器预测多个输出。由于管道只允许最后一步成为分类器,因此我嵌套了管道。在没有GridSearch的情况下,该方法可以正常工作。

pipeline = Pipeline([
('vect',CountVectorizer()),('tfidf',TfidfTransformer()),('clf',MultiOutputClassifier(RandomForestClassifier(),n_jobs=-1)),])

现在,我尝试将多个参数传递给我的RF分类器,但是由于它是嵌套的,因此它将传递给MultiOutputClassifier,至少我认为是这样。

param_grid = { 
    'clf__n_estimators': [200,500],'clf__max_features': ['auto','sqrt','log2'],'clf__max_depth' : [4,5,6,7,8],'clf__criterion' :['gini','entropy']
}

cv = GridSearchCV(pipeline,param_grid=param_grid)

这将导致错误:ValueError:估算器的无效参数标准

是否可以将参数传递给我的RandomForestClassifier还是可以通过管道传递多个分类器?

解决方法

尝试一下:

pipeline = Pipeline([
('vect',CountVectorizer()),('tfidf',TfidfTransformer()),('clf',MultiOutputClassifier(RandomForestClassifier(),n_jobs=-1)),])

param_grid = { 
    'clf__estimator__n_estimators': [200,500],'clf__estimator__max_features': ['auto','sqrt','log2'],'clf__estimator__max_depth' : [4,5,6,7,8],'clf__estimator__criterion' :['gini','entropy']
}

cv = GridSearchCV(pipeline,param_grid=param_grid,n_jobs=2)

通常,您可以通过以下方式访问可调参数:

cv.get_params().keys()
dict_keys(['cv','error_score','estimator__memory','estimator__steps','estimator__verbose','estimator__vect','estimator__tfidf','estimator__clf','estimator__vect__analyzer','estimator__vect__binary','estimator__vect__decode_error','estimator__vect__dtype','estimator__vect__encoding','estimator__vect__input','estimator__vect__lowercase','estimator__vect__max_df','estimator__vect__max_features','estimator__vect__min_df','estimator__vect__ngram_range','estimator__vect__preprocessor','estimator__vect__stop_words','estimator__vect__strip_accents','estimator__vect__token_pattern','estimator__vect__tokenizer','estimator__vect__vocabulary','estimator__tfidf__norm','estimator__tfidf__smooth_idf','estimator__tfidf__sublinear_tf','estimator__tfidf__use_idf','estimator__clf__estimator__bootstrap','estimator__clf__estimator__ccp_alpha','estimator__clf__estimator__class_weight','estimator__clf__estimator__criterion','estimator__clf__estimator__max_depth','estimator__clf__estimator__max_features','estimator__clf__estimator__max_leaf_nodes','estimator__clf__estimator__max_samples','estimator__clf__estimator__min_impurity_decrease','estimator__clf__estimator__min_impurity_split','estimator__clf__estimator__min_samples_leaf','estimator__clf__estimator__min_samples_split','estimator__clf__estimator__min_weight_fraction_leaf','estimator__clf__estimator__n_estimators','estimator__clf__estimator__n_jobs','estimator__clf__estimator__oob_score','estimator__clf__estimator__random_state','estimator__clf__estimator__verbose','estimator__clf__estimator__warm_start','estimator__clf__estimator','estimator__clf__n_jobs','estimator','iid','n_jobs','param_grid','pre_dispatch','refit','return_train_score','scoring','verbose'])

相关问答

依赖报错 idea导入项目后依赖报错,解决方案:https://blog....
错误1:代码生成器依赖和mybatis依赖冲突 启动项目时报错如下...
错误1:gradle项目控制台输出为乱码 # 解决方案:https://bl...
错误还原:在查询的过程中,传入的workType为0时,该条件不起...
报错如下,gcc版本太低 ^ server.c:5346:31: 错误:‘struct...