ValueError: y must be an integer array. Found object. Try passing the array as y.astype(np.integer)

Problem description

Here is my code:

import pandas as pd
from sklearn.model_selection import train_test_split,cross_val_score
from sklearn.feature_extraction.text import CountVectorizer
from sklearn.feature_extraction.text import TfidfTransformer
from sklearn.tree import DecisionTreeClassifier,plot_tree,export_graphviz,export_text
from sklearn.pipeline import Pipeline
from sklearn.metrics import confusion_matrix,classification_report,accuracy_score,roc_curve,auc,f1_score,roc_auc_score
import warnings; warnings.simplefilter('ignore')

data_files = 'dataset_for_learning_decision_tree.xlsx'

data = pd.read_excel(data_files)
train_data = data[['title','category','processed_title']]

categories=train_data['category']
labels=list(set(categories))

X_train,X_test,y_train,y_test = train_test_split(train_data['processed_title'],train_data['category'],test_size=0.2,random_state=57)

vectorizer = CountVectorizer()
X = vectorizer.fit_transform(X_train)
decisiontree=DecisionTreeClassifier()

model = Pipeline([('vect',vectorizer),('tfidf',TfidfTransformer()),('clf',decisiontree),])
model.fit(X_train,y_train)

predicted = model.predict(X_test)
confusion_matrix(y_test,predicted)
print('accuracy_score',accuracy_score(y_test,predicted))
print('Reporting...')
print(classification_report(y_test,predicted))

import numpy as np
from mlxtend.plotting import plot_decision_regions

X=np.array(X_train)
y=np.array(y_train)
plot_decision_regions(X=X,y=y,clf=model.named_steps['clf'])

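I want to draw a plot with plot_decision_regions. However, when I run this code, I get the error shown in the title. When I run it with y=y.astype(np.integer) instead, I get a different error: ValueError: invalid literal for int() with base 10: 'depression'. How can I solve this?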

Solution

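The labels are strings such as 'depression', so astype cannot cast them to integers. First convert the class labels to integers explicitly: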

import numpy as np
from mlxtend.plotting import plot_decision_regions

X = np.array(X_train)
# Map each string class label to an integer code
d = {'addiction':0,'depression':1,'normal':2}
# Build y as an integer NumPy array rather than a plain Python list
y = np.array([d[label] for label in y_train])
plot_decision_regions(X=X,y=y,clf=model.named_steps['clf'])
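If the class names are not known in advance, the mapping can be derived from the data instead of being written by hand. The following is a minimal sketch that uses only NumPy; the array y_raw is a made-up stand-in for y_train, and the names classes, y_int and label_to_int are illustrative, not part of the original code.

```python
import numpy as np

# Hypothetical labels standing in for y_train (assumption, not the real data)
y_raw = np.array(['depression', 'normal', 'addiction', 'normal', 'depression'])

# np.unique returns the sorted distinct labels and, with return_inverse=True,
# the index of each original label within that sorted array
classes, y_int = np.unique(y_raw, return_inverse=True)
y_int = y_int.astype(int)

# Keep the mapping so integer codes can be translated back to class names
label_to_int = {str(label): code for code, label in enumerate(classes)}

print(label_to_int)
print(y_int)
```

The resulting y_int is an integer array that can be passed as y. Note that, as far as I know, plot_decision_regions also expects X to be a numeric 2D array, so raw text in X_train would likely need to be turned into numeric features before plotting.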
