决策树回归模型以最高的精度获得模型的max_depth值

问题描述

使用默认参数从X_train集和Y_train标签构建决策树回归模型。将模型命名为dt_reg。

在训练数据集上评估模型准确性并打印其分数。

在测试数据集上评估模型的准确性并打印其分数。

预测X_test集合的前两个样本的房价并打印出来。(提示:使用predict()函数)

在X_train数据和Y_train标签上安装多个决策树回归变量,其max_depth参数值从2更改为5。

根据测试数据集评估每个模型的准确性。

提示:利用for循环

以最高的精度打印模型的max_depth值。

import sklearn.datasets as datasets
from sklearn.model_selection import train_test_split 
from sklearn.tree import DecisionTreeRegressor
import numpy as np
np.random.seed(100) 
boston = datasets.load_boston()
X_train,X_test,Y_train,Y_test = train_test_split(boston.data,boston.target,random_state=30)
print(X_train.shape)
print(X_test.shape)

dt_reg = DecisionTreeRegressor()   
dt_reg = dt_reg.fit(X_train,Y_train) 
print(dt_reg.score(X_train,Y_train))
print(dt_reg.score(X_test,Y_test))
y_pred=dt_reg.predict(X_test[:2])
print(y_pred)

我想以最高的精度打印模型的max_depth值。但是壁画没有提交让我知道什么是错误。

max_reg = None
max_score = 0  
t=()
for m in range(2,6) :
    rf_reg = DecisionTreeRegressor(max_depth=m)
    rf_reg = rf_reg.fit(X_train,Y_train) 
    rf_reg_score = rf_reg.score(X_test,Y_test)
    print (m,rf_reg_score,max_score) 
    if rf_reg_score > max_score :
        max_score = rf_reg_score
        max_reg = rf_reg
        t = (m,max_score) 
print (t)

解决方法

如果您希望继续使用循环,可以创建另一个名为“ best_max_depth”的变量,如果满足if语句条件,则将其值替换为dt_reg.max_depth(这是最佳模型)远)。

但是,我建议您研究GridSearchCV以从最佳模型中提取参数并遍历不同的参数值。

max_reg = None
max_score = 0  
best_max_depth = None
t=()
for m in range(2,6) :
    rf_reg = DecisionTreeRegressor(max_depth=m)
    rf_reg = rf_reg.fit(X_train,Y_train) 
    rf_reg_score = rf_reg.score(X_test,Y_test)
    print (m,rf_reg_score,max_score) 
    if rf_reg_score > max_score :
        max_score = rf_reg_score
        max_reg = rf_reg
        
        best_max_depth = rf_reg.max_depth
        
        t = (m,max_score) 
print (t)
,

试试这个代码 -

myList = list(range(2,6))
scores =[]
for i in myList:
  dt_reg = DecisionTreeRegressor(max_depth=i)
  dt_reg.fit(X_train,Y_train)
  scores.append(dt_reg.score(X_test,Y_test))
print(myList[scores.index(max(scores))])

相关问答

依赖报错 idea导入项目后依赖报错,解决方案:https://blog....
错误1:代码生成器依赖和mybatis依赖冲突 启动项目时报错如下...
错误1:gradle项目控制台输出为乱码 # 解决方案:https://bl...
错误还原:在查询的过程中,传入的workType为0时,该条件不起...
报错如下,gcc版本太低 ^ server.c:5346:31: 错误:‘struct...