Converting Python export_text decision rules to SAS IF THEN DO; END; code

Problem description

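I am trying to convert the output of Python's sklearn.tree export_text into SAS conditions. A Python solution to this problem has already been given here (at the bottom of the page): How to extract the decision rules from scikit-learn decision-tree?

I tried to adapt that code so that it generates SAS code, but I am having trouble handling the nested DO; END; blocks.

Here is my code (it creates a data step):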

def get_sas_from_text(tree,tree_id,features,text,spacing=2):
    # tree is a decision tree from a RandomForestClassifier for instance
    # tree id is a number I use for naming the table I create
    # features is a list of features names
    # text is the output of the export_text function from sklearn.tree
    # spacing is used to handle the file size
    skip,dash = ' '*spacing,'-'*(spacing-1)
    code = 'data decision_tree_' + str(tree_id) + ';'
    code += ' set input_data; '
    n_end_last = 0 # Number of 'END;' to add at the end of the data step
    splitted_text = text.split('\n')
    text_list = []
    for i,line in enumerate(splitted_text):
        line = line.rstrip().replace('|',' ')

        # Handling rows for IF conditions
        if '<' in line or '>' in line:
            line,val = line.rsplit(maxsplit=1)
            line = line.replace(' ' + dash,'if')
            line = '{} {:g} THEN DO;'.format(line,float(val))
            n_end_last += 1 # need to add an END;
            if i > 0 and 'PREDICTED_VALUE' in text_list[i-1]: # If there's a PREDICTED_VALUE in line above,then condition is ELSE DO
                line = "ELSE DO; " + line
                n_end_last += 1 # add another END
        # Handling rows for PREDICTED_VALUE
        else:
            line = line.replace(' {} class:'.format(dash),'PREDICTED_VALUE =')
            line += ';'
            line += '\n end;' # Immediately add END after PREDICTED_VALUE = .. ;
            n_end_last -= 1
        text_list.append(line)
        code += skip + line + '\n'
    code = code[:-1] 
    code += 'end;\n' * n_end_last # add the remaining END; statements
    code += 'run;'
    return code
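
As a rough illustration of what the condition branch of the loop above does to a single line, the sketch below applies the same string operations to one made-up export_text line. The helper name condition_to_sas and the sample line are assumptions for illustration only; spacing=2 mirrors the function's default.

```python
def condition_to_sas(raw_line, spacing=2):
    """Apply the condition-branch string steps of get_sas_from_text to one line."""
    dash = '-' * (spacing - 1)
    line = raw_line.rstrip().replace('|', ' ')   # tree guide bars become spaces
    line, val = line.rsplit(maxsplit=1)          # peel the threshold off the end
    line = line.replace(' ' + dash, 'if')        # the branch marker becomes 'if'
    return '{} {:g} THEN DO;'.format(line, float(val))

# Hypothetical second-level line, shaped like export_text output with spacing=1
sample_line = '| |- sepal_width__cm_ <= 2.700000'
converted = condition_to_sas(sample_line)
print(repr(converted))  # '  if sepal_width__cm_ <= 2.7 THEN DO;'
```

After this step the only remaining trace of the node's depth is the leading whitespace, so the position of each closing END; has to be derived from that depth rather than from the text of the neighbouring line.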

Here is an example on the iris dataset:

import pandas as pd
from sklearn import datasets
from sklearn.model_selection import train_test_split
from sklearn.ensemble import RandomForestClassifier
from sklearn.tree import _tree
import string
from sklearn.tree import export_text

iris = datasets.load_iris()

data=pd.DataFrame({
    'sepal length':iris.data[:,0],'sepal width':iris.data[:,1],'petal length':iris.data[:,2],'petal width':iris.data[:,3],'species':iris.target
})

X=data[['sepal length','sepal width','petal length','petal width']]  # Features
y=data['species']  # Labels

# Split dataset into training set and test set
X_train,X_test,y_train,y_test = train_test_split(X,y,test_size=0.3) # 70% training and 30% test

clf=RandomForestClassifier(n_estimators=100)

# Train the model using the training set
clf.fit(X_train,y_train)

# Function to export the tree rules' code: in Python (works) and in SAS (issue with DO; END;)
def export_code(tree,feature_names,tree_id=0,max_depth=100,spacing=2):
    # tree_id is assumed to default to 0 so that a call that omits it still works
    if spacing < 2:
        raise ValueError('spacing must be > 1')

    # Clean up feature names (for correctness)
    nums = string.digits
    alnums = string.ascii_letters + nums
    clean = lambda s: ''.join(c if c in alnums else '_' for c in s)
    features = [clean(x) for x in feature_names]
    features = ['_'+x if x[0] in nums else x for x in features if x]
    if len(set(features)) != len(feature_names):
        raise ValueError('invalid feature names')

    # First: export tree to text
    res = export_text(tree,feature_names=features,max_depth=max_depth,decimals=6,spacing=spacing-1)

    # Both converters take the same arguments as the get_sas_from_text signature
    code_sas = get_sas_from_text(tree,tree_id,features,res,spacing)
    code_py = get_py_from_text(tree,tree_id,features,res,spacing)
    return res,code_sas,code_py # to take a look at the different code outputs

# Python function
def get_py_from_text(tree,tree_id,features,text,spacing):
    skip,dash = ' '*spacing,'-'*(spacing-1)
    code = 'def decision_tree_'+ str(tree_id) + '({}):\n'.format(','.join(features))
    for line in repr(tree).split('\n'):
        code += skip + "# " + line + '\n'
    for line in text.split('\n'):
        line = line.rstrip().replace('|',' ')
        if '<' in line or '>' in line:
            # Split off the threshold, then turn the branch marker into 'if'
            line,val = line.rsplit(maxsplit=1)
            line = line.replace(' ' + dash,'if')
            line = '{} {:g}:'.format(line,float(val))
        else:
            line = line.replace(' {} class:'.format(dash),'return')
        code += skip + line + '\n'

    return code
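
The name clean-up step is what turns a raw iris name such as `sepal length (cm)` into an identifier that both Python and SAS will accept. Below is a small stand-alone sketch of that same logic; the function name clean_names and the third sample name are hypothetical and only there to show the leading-digit case.

```python
import string

def clean_names(feature_names):
    """Replace non-alphanumeric characters with '_' and prefix names that start with a digit."""
    alnums = string.ascii_letters + string.digits
    cleaned = [''.join(c if c in alnums else '_' for c in name) for name in feature_names]
    cleaned = ['_' + name if name[0] in string.digits else name for name in cleaned if name]
    if len(set(cleaned)) != len(feature_names):
        raise ValueError('invalid feature names')  # an empty name or two names that collide
    return cleaned

names = clean_names(['sepal length (cm)', 'petal width (cm)', '2nd feature'])
print(names)  # ['sepal_length__cm_', 'petal_width__cm_', '_2nd_feature']
```

Under SAS's default naming rules variable names are limited to 32 characters, so very long cleaned names may still need shortening before the generated data step will run.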

I then retrieve the generated SAS and Python code for the tree's decision rules:

# Rules for first decision tree (there are 100 of them)
exported_text,sas_text,py_text = export_code(clf[0],iris.feature_names)

Here are the decision rules of the first tree in Python:

def decision_tree_0(sepal_length__cm_,sepal_width__cm_,petal_length__cm_,petal_width__cm_):
  # DecisionTreeClassifier(max_features='auto',random_state=1087864992)
  if sepal_length__cm_ <= 5.35:
    if sepal_width__cm_ <= 2.7:
      return 1.0
    if sepal_width__cm_ > 2.7:
      return 0.0
  if sepal_length__cm_ > 5.35:
    if petal_width__cm_ <= 1.75:
      if petal_length__cm_ <= 2.5:
        return 0.0
      if petal_length__cm_ > 2.5:
        if sepal_length__cm_ <= 7.1:
          if petal_width__cm_ <= 1.45:
            if petal_length__cm_ <= 5.15:
              return 1.0
            if petal_length__cm_ > 5.15:
              return 2.0
          if petal_width__cm_ > 1.45:
            return 1.0
        if sepal_length__cm_ > 7.1:
          return 2.0
    if petal_width__cm_ > 1.75:
      return 2.0

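The problem with my get_sas_from_text function is that the END; statements are not placed correctly in the generated code, so the result does not match the decision rules given as input (compare with the corresponding Python function):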

data decision_tree_0;
  set input_data;
  if sepal_length__cm_ <= 5.35 THEN
    DO;
      if sepal_width__cm_ <= 2.7 THEN
        DO;
          PREDICTED_VALUE = 1.0;
        end;
      ELSE
        DO;
          if sepal_width__cm_ > 2.7 THEN
            DO;
              PREDICTED_VALUE = 0.0;
            end;
          ELSE
            DO;
              if sepal_length__cm_ > 5.35 THEN
                DO;
                  if petal_width__cm_ <= 1.75 THEN
                    DO;
                      if petal_length__cm_ <= 2.5 THEN
                        DO;
                          PREDICTED_VALUE = 0.0;
                        end;
                      ELSE
                        DO;
                          if petal_length__cm_ > 2.5 THEN
                            DO;
                              if sepal_length__cm_ <= 7.1 THEN
                                DO;
                                  if petal_width__cm_ <= 1.45 THEN
                                    DO;
                                      if petal_length__cm_ <= 5.15 THEN
                                        DO;
                                          PREDICTED_VALUE = 1.0;
                                        end;
                                      ELSE
                                        DO;
                                          if petal_length__cm_ > 5.15 THEN
                                            DO;
                                              PREDICTED_VALUE = 2.0;
                                            end;
                                          ELSE
                                            DO;
                                              if petal_width__cm_ > 1.45 THEN
                                                DO;
                                                  PREDICTED_VALUE = 1.0;
                                                end;
                                              ELSE
                                                DO;
                                                  if sepal_length__cm_ > 7.1 THEN
                                                    DO;
                                                      PREDICTED_VALUE = 2.0;
                                                    end;
                                                  ELSE
                                                    DO;
                                                      if petal_width__cm_ > 1.75 THEN
                                                        DO;
                                                          PREDICTED_VALUE = 2.0;
                                                        end;
                                                      ;
                                                    end;
                                                end;
                                            end;
                                        end;
                                    end;
                                end;
                            end;
                        end;
                    end;
                end;
            end;
        end;
    end;
run;

Do you know how I can place the END; statements correctly in the generated data step? Thank you very much.
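
One possible way to place the END; statements, offered as a sketch under stated assumptions and not as a verified answer: treat the number of `|` characters on each export_text line as that line's depth, keep a counter of DO blocks that are still open, and before emitting any line at depth d close every block at depth d or deeper. The function name sas_from_export_text, its parameters, and the sample text are hypothetical. The sketch assumes a classifier with numeric class labels, cleaned feature names without spaces or `|`, and it writes the `>` branch as `else if`, which relies on export_text always printing the `<=` branch first.

```python
def sas_from_export_text(text, table='decision_tree_0', source='input_data',
                         target='PREDICTED_VALUE', indent='  '):
    """Sketch: turn sklearn export_text output into a SAS data step."""
    out = ['data {};'.format(table), indent + 'set {};'.format(source)]
    open_blocks = 0  # DO blocks still waiting for their END;

    def close_down_to(level):
        nonlocal open_blocks
        while open_blocks > level:
            out.append(indent * open_blocks + 'end;')
            open_blocks -= 1

    for raw in text.splitlines():
        if not raw.strip():
            continue
        depth = raw.count('|')                              # 1 for the root split
        body = raw[raw.rfind('|') + 1:].lstrip('-').strip()
        close_down_to(depth - 1)                            # close deeper and sibling blocks first
        if body.startswith('class:'):
            value = body.split(':', 1)[1].strip()
            out.append(indent * depth + '{} = {};'.format(target, value))
        else:
            name, op, threshold = body.split()              # fails loudly on unexpected lines
            keyword = 'else if' if op == '>' else 'if'
            out.append(indent * depth + '{} {} {} {:g} then do;'.format(
                keyword, name, op, float(threshold)))
            open_blocks += 1
    close_down_to(0)
    out.append('run;')
    return '\n'.join(out)

# Hypothetical three-leaf tree, shaped like export_text output with spacing=1
sample = (
    '|- a <= 5.350000\n'
    '| |- b <= 2.700000\n'
    '| | |- class: 1\n'
    '| |- b >  2.700000\n'
    '| | |- class: 0\n'
    '|- a >  5.350000\n'
    '| |- class: 2\n'
)
sas = sas_from_export_text(sample)
print(sas)
```

Because the closing is driven by depth alone, the generator never needs to look at the previous line's content, which is where the 'PREDICTED_VALUE in the line above' test in get_sas_from_text goes wrong: a condition that follows a leaf may sit several levels higher, and each of those levels needs its own END; first.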
