管道GridSearchCV,不同步骤中的相应参数

问题描述

我正在尝试在管道中进行一些超参数调整,并具有以下设置:

model = KerasClassifier(build_fn = create_model,epochs = 5)  
pipeline = Pipeline(steps =[('Tokenizepadder',TokenizePadding()),('NN',model)] )

在Tokenizepadder和我的神经网络中都有一个变量“ maxlen”(对于神经网络,它称为max_length,我担心将它们命名为相同的名称会在以后的代码中引起错误)。当我尝试执行网格搜索时,我正在努力使这些值对应。如果我分别对这些值进行网格搜索,它们将不会满足要求,并且输入数据将与神经网络不匹配。

简而言之,我想做些类似的事情:


pipeline = Pipeline(steps =[('Tokenizepadder',KerasClassifier(build_fn = create_model,epochs = 5,max_length = pipeline.get_params()['Tokenizepadder__maxlen']))] )

因此,当我在网格中搜索参数'Tokenizepadder__maxlen'时,它将把'NN__max_length'值更改为相同的值。

解决方法

也许您可以更改分类器和标记器,以传递max_len参数。然后,仅使用标记符max_len参数进行网格搜索。 不是最干净的方法,但是可能会做到。

from sklearn.base import BaseEstimator,TransformerMixin,EstimatorMixin
class TokeinizePadding(BaseEstimator,TransformerMixin):
    def __init__(self,max_len,...):
        self.max_len = max_len
        ...

    def fit(self,X,y=None):
        ...
        return self
   
    def transform(self,y=None):
        data = ... # your stuff
        return {"array": data,"max_len": self.max_len}
 

class KerasClassifier(...):
    ...
    def fit(data,y):
        self.max_len = data["max_len"]
        self.build_model()
        X = data["array"]
        ... # your stuff

相关问答

依赖报错 idea导入项目后依赖报错,解决方案:https://blog....
错误1:代码生成器依赖和mybatis依赖冲突 启动项目时报错如下...
错误1:gradle项目控制台输出为乱码 # 解决方案:https://bl...
错误还原:在查询的过程中,传入的workType为0时,该条件不起...
报错如下,gcc版本太低 ^ server.c:5346:31: 错误:‘struct...