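Problem description
I'm working through a classification exercise on the Iris dataset, but I've reached a point where I'm not entirely sure what is happening. I think I'm passing the hypothetical measurements of a new flower into the model, and it outputs a prediction of which species it thinks the flower is. I'm not certain, though.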
species_id = clfr.predict([[1,5,4,6]])
iris.target_names[species_id]
print(iris.target_names[species_id])
Here is all of my code:
# Importing required libraries
import numpy as np
import pandas as pd
import sklearn
from sklearn.model_selection import train_test_split
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import accuracy_score
from sklearn.metrics import confusion_matrix
from sklearn.datasets import load_iris
import sklearn.metrics as metrics
# Loading datasets
iris = load_iris()
# Convert to pandas dataframe
iris_data = pd.DataFrame({
'sepal length':iris.data[:,0],'sepal width':iris.data[:,1],'petal length':iris.data[:,2],'petal width':iris.data[:,3],'species':iris.target
})
iris_data.head()
# printing categories (setosa,versicolor,virginica)
print(iris.target_names)
# print flower features
print(iris.feature_names)
# setting independent (X) and dependent (Y) variables
X = iris_data[['sepal length','sepal width','petal length','petal width']] # Features
Y = iris_data['species'] # Labels
# printing feature data
print(X[0:5])
# printing dependent variable values (0 = setosa, 1 = versicolor, 2 = virginica)
print(Y)
# splitting into train and test sets
X_train,X_test,y_train,y_test = train_test_split(X,Y,test_size = 0.3,random_state = 100)
# defining random forest classifier
clfr = RandomForestClassifier(random_state = 100)
clfr.fit(X_train,y_train)
# making prediction
Y_pred = clfr.predict(X_test)
# checking model accuracy
print("Accuracy:",metrics.accuracy_score(y_test,Y_pred))
cm = np.array(confusion_matrix(y_test,Y_pred))
print(cm)
# making predictions on new data
species_id = clfr.predict([[1,5,4,6]])
iris.target_names[species_id]
print(iris.target_names[species_id])
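Solution
Your understanding is correct: the input to the predict() method is the set of feature values of a new, "unknown" flower, and your model predicts its species.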
You trained your model on these features:
['sepal length','sepal width','petal length','petal width']
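Your input mirrors this: in [1,5,4,6], each number is the value of one feature of the new flower, in the same order as the training columns.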
So:
- 'sepal length' = 1
- 'sepal width' = 5
- 'petal length' = 4
- 'petal width' = 6
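A sketch of the same single-flower prediction, wrapping the values in a one-row DataFrame that uses the training column names so each number is explicitly tied to its feature. This is one possible approach, not the only one: a plain nested list such as [[1,5,4,6]] also works, but recent scikit-learn versions may emit a warning when a model fitted on a DataFrame receives input without feature names. The name `new_flower` is illustrative, and the model is fitted on the full dataset for brevity.

```python
import pandas as pd
from sklearn.datasets import load_iris
from sklearn.ensemble import RandomForestClassifier

columns = ['sepal length', 'sepal width', 'petal length', 'petal width']
iris = load_iris()
X = pd.DataFrame(iris.data, columns=columns)
clfr = RandomForestClassifier(random_state=100).fit(X, iris.target)

# One new flower: each value sits under the column it belongs to
new_flower = pd.DataFrame([[1, 5, 4, 6]], columns=columns)

species_id = clfr.predict(new_flower)   # NumPy array holding one class index (0, 1 or 2)
print(species_id, iris.target_names[species_id])
```

Because predict() always returns an array, `iris.target_names[species_id]` is also an array (with one species name here) rather than a bare string.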
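The reason your input is two-dimensional (i.e. a list of lists: [[1,5,4,6]]) is that you can pass several new flowers at once and get several predictions back: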
If you pass [[1,5,4,6],[2,3,1,3]], you will get 2 species predictions (one per new flower).
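The batch case can be sketched as follows, again fitting on the full dataset for brevity; `new_flowers` and `probs` are illustrative names. It also shows predict_proba(), which returns one row of class probabilities per flower, with columns ordered as `clfr.classes_`.

```python
import pandas as pd
from sklearn.datasets import load_iris
from sklearn.ensemble import RandomForestClassifier

columns = ['sepal length', 'sepal width', 'petal length', 'petal width']
iris = load_iris()
clfr = RandomForestClassifier(random_state=100).fit(
    pd.DataFrame(iris.data, columns=columns), iris.target)

# Two new flowers -> two rows -> two predictions
new_flowers = pd.DataFrame([[1, 5, 4, 6], [2, 3, 1, 3]], columns=columns)
species_ids = clfr.predict(new_flowers)   # shape (2,): one class index per row
names = iris.target_names[species_ids]    # NumPy fancy indexing -> one name per row

for row, name in zip(new_flowers.values.tolist(), names):
    print(row, '->', name)

# Per-class probabilities; columns follow clfr.classes_ (0, 1, 2)
probs = clfr.predict_proba(new_flowers)
print(probs)
```

The measurements here are far outside the ranges seen in the Iris data, so treat the resulting labels as an illustration of the mechanics rather than a meaningful classification.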