How do I correctly apply a LabelEncoder in my scikit-learn pipeline?

Problem

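I have written a pipeline to preprocess my data, but my program fails with this error:

AttributeError: 'numpy.ndarray' object has no attribute 'fit'

I created a new class because I first tried to use LabelEncoder directly in the pipeline, and that gave me a different set of errors.

My class looks like this: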

class myLabelEncoder(BaseEstimator,TransformerMixin):
    
    def __init__(self):
        self.encoder = LabelEncoder()
        self.X = None
    def fit(self,X,y = None):
        self.X = X
        return self.encoder.fit(X)
    
    def transform(self,y = None):
        return self.encoder.transform(X)

This is the pipeline:

#other transformers...

label_transformer = Pipeline(steps = [('imputer',SimpleImputer(missing_values = np.nan,strategy = 'most_frequent')),('label',myLabelEncoder)])


pipeline =ColumnTransformer(
        transformers = [('cat',categorical_transformer,cat_features),('num',numeric_transformer,num_features),('ord',ordinal_transformer,ord_features),('lab',label_transformer,label_features)]) 

logistic = LogisticRegression()

pca = PCA()

clf = Pipeline(steps = [('preprocessor',pipeline),('pca',pca),('classifier',logistic)])

#creating parameters for tuning pca and logistic regression

X_train,X_test,y_train,y_test = train_test_split(X,y,test_size=0.2,random_state=42)


clf.fit(X_train,y_train)

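I think the problem is in the class code, but I searched for similar posts about this error and could not find a solution.

Here are some sample rows from the dataset:

2 sample rows of X: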

{'hotel': {54261: 'City Hotel',77908: 'City Hotel'},'meal': {54261: 'SC',77908: 'BB'},'lead_time': {54261: 49,77908: 9},'adr': {54261: 93.15,77908: 155.08},'stays_in_weekend_nights': {54261: 1,77908: 2},'stays_in_week_nights': {54261: 0,77908: 1},'adults': {54261: 2,'children': {54261: 0.0,77908: 0.0},'previous_cancellations': {54261: 0,77908: 0},'previous_bookings_not_canceled': {54261: 0,'total_of_special_requests': {54261: 0,'customer_type': {54261: 'Transient',77908: 'Transient'},'market_segment': {54261: 'Online TA',77908: 'Online TA'},'distribution_channel': {54261: 'TA/TO',77908: 'TA/TO'},'reserved_room_type': {54261: 'A',77908: 'A'},'assigned_room_type': {54261: 'A','arrival_date_month': {54261: 'July',77908: 'September'}}

Sample rows of y:

{54261: 1,77908: 0,38772: 0,100452: 0,2318: 1}

Thank you for your time.

Solution

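You forgot to instantiate your label encoder class in label_transformer: replace myLabelEncoder with myLabelEncoder(). Because the step holds the class itself rather than an instance, scikit-learn most likely ends up calling its methods unbound, so the data array is passed where self is expected. That is why the error complains about a 'numpy.ndarray' object.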
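
Once the step is instantiated, the wrapper from the question will probably still fail, because its transform method does not accept X and its fit method returns the inner encoder instead of the wrapper. Below is a minimal sketch of how the class could be adjusted, assuming it is applied to a single column. The name MyLabelEncoder and the demo array are made up for illustration and are not part of the original data.

```python
import numpy as np
from sklearn.base import BaseEstimator, TransformerMixin
from sklearn.impute import SimpleImputer
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import LabelEncoder


class MyLabelEncoder(BaseEstimator, TransformerMixin):
    """Hypothetical wrapper: lets LabelEncoder encode one feature column."""

    def fit(self, X, y=None):
        # LabelEncoder expects a 1-D array, so flatten the single column.
        self.encoder_ = LabelEncoder().fit(np.asarray(X).ravel())
        return self  # fit should return the estimator itself

    def transform(self, X, y=None):
        encoded = self.encoder_.transform(np.asarray(X).ravel())
        # Return a 2-D column so ColumnTransformer can stack it with other outputs.
        return encoded.reshape(-1, 1)


label_transformer = Pipeline(steps=[
    ("imputer", SimpleImputer(missing_values=np.nan, strategy="most_frequent")),
    ("label", MyLabelEncoder()),  # note the parentheses: an instance, not the class
])

# Made-up single column, only to show the call pattern.
demo = np.array([["July"], ["September"], [np.nan], ["July"]], dtype=object)
result = label_transformer.fit_transform(demo)
print(result.shape)
```

This sketch only handles one column per transformer. scikit-learn documents LabelEncoder as a tool for encoding the target y, so for several feature columns OrdinalEncoder may be a simpler fit, since it accepts 2-D input directly.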