Problem description
After writing a decision-tree function, I wanted to check the tree's accuracy and confirm that, if I built another tree from the same data, at least the first split would be the same.
from sklearn.model_selection import train_test_split
import pandas as pd
import numpy as np
import os
from sklearn import tree
from sklearn import preprocessing
import sys
from sklearn.tree import DecisionTreeClassifier
from sklearn.model_selection import cross_val_score
from sklearn.model_selection import KFold
import matplotlib.pyplot as plt  # needed for plt.figure / plt.show below
.....
def desicion_tree(data_set: pd.DataFrame, val_1: str, val_2: str):
    # Encoder --> fit doesn't accept strings
    feature_cols = data_set.columns[0:-1]
    X = data_set[feature_cols]  # independent variables
    y = data_set.Mut  # class
    y = y.to_list()
    le = preprocessing.LabelBinarizer()
    y = le.fit_transform(y)
    # Split the data set into training set and test set (75% / 25%)
    X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.25, random_state=1)
    # Create Decision Tree classifier object
    clf = DecisionTreeClassifier(max_depth=4, criterion='entropy')
    # Train Decision Tree classifier
    clf.fit(X_train, y_train)
    # Predict the response for the test dataset
    y_pred = clf.predict(X_test)
    # Perform cross validation
    for i in range(2, 8):
        plt.figure(figsize=(14, 7))
        # Perform K-fold cross validation
        # cv = ShuffleSplit(test_size=0.25, random_state=0)
        kf = KFold(n_splits=5, shuffle=True)
        scores = cross_val_score(estimator=clf, X=X, y=y, n_jobs=4, cv=kf)
        print("%0.2f accuracy with a standard deviation of %0.2f" % (scores.mean(), scores.std()))
        tree.plot_tree(clf, filled=True, feature_names=feature_cols, class_names=[val_1, val_2])
        plt.show()
desicion_tree(car_rep_sep_20,'Categorial','Non categorial')
Further down, I wrote a loop to recreate the tree with the split values using KFold. The accuracy varies (around 90%), but the tree is always the same. Where did I go wrong?
Solution
cross_val_score clones the estimator in order to fit and score it on each fold, so the clf object stays exactly as it was when you fit it on the training split before the loop; the tree being plotted is therefore that tree, not any of the trees grown during cross-validation.
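One quick way to see this cloning behaviour is to fingerprint the fitted tree before and after calling cross_val_score. Below is a minimal sketch using a synthetic dataset from make_classification instead of your car_rep_sep_20 frame; the variable names are illustrative, only the cloning behaviour comes from scikit-learn:
from sklearn.datasets import make_classification
from sklearn.model_selection import KFold, cross_val_score
from sklearn.tree import DecisionTreeClassifier

X_demo, y_demo = make_classification(n_samples=200, n_features=5, random_state=0)

clf_demo = DecisionTreeClassifier(max_depth=4, criterion='entropy', random_state=0)
clf_demo.fit(X_demo, y_demo)               # fit once on the full data
nodes_before = clf_demo.tree_.node_count   # fingerprint of the fitted tree

kf_demo = KFold(n_splits=5, shuffle=True, random_state=0)
cross_val_score(estimator=clf_demo, X=X_demo, y=y_demo, cv=kf_demo)  # fits clones internally

# clf_demo itself was never refit, so its tree structure is untouched
print(clf_demo.tree_.node_count == nodes_before)  # True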
To get what you want, I think you can use cross_validate with the option return_estimator=True. You also shouldn't need the loop if your cv object already has the desired number of splits:
from sklearn.model_selection import cross_validate  # not imported above

kf = KFold(n_splits=5, shuffle=True)
cv_results = cross_validate(
    estimator=clf, X=X, y=y, n_jobs=4, cv=kf, return_estimator=True,
)
print("%0.2f accuracy with a standard deviation of %0.2f" % (
    cv_results['test_score'].mean(), cv_results['test_score'].std(),
))
for est in cv_results['estimator']:
    tree.plot_tree(est, filled=True, feature_names=feature_cols, class_names=[val_1, val_2])
    plt.show()
Alternatively, loop over the folds (or any other cv iterator) manually, fitting a model and plotting its tree inside the loop.
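A minimal sketch of that manual variant, assuming X, y, feature_cols, val_1 and val_2 are the same objects built inside the function above (X a DataFrame, y the binarized labels):
import matplotlib.pyplot as plt
from sklearn.model_selection import KFold
from sklearn.tree import DecisionTreeClassifier, plot_tree

kf = KFold(n_splits=5, shuffle=True)
for fold, (train_idx, test_idx) in enumerate(kf.split(X)):
    # Fit a fresh tree on this fold's training rows
    # (LabelBinarizer returns a column vector, so ravel to 1-D to avoid shape warnings)
    fold_clf = DecisionTreeClassifier(max_depth=4, criterion='entropy')
    fold_clf.fit(X.iloc[train_idx], y[train_idx].ravel())
    print("fold %d accuracy: %0.2f" % (fold, fold_clf.score(X.iloc[test_idx], y[test_idx].ravel())))
    # Each fold gets its own figure and its own (generally different) tree
    plt.figure(figsize=(14, 7))
    plot_tree(fold_clf, filled=True, feature_names=feature_cols, class_names=[val_1, val_2])
    plt.show()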