数据分析教程-决策树实战二

uwb定位| 2022-09-14

在学习了上一节决策树的原理之后,你有没有想动手实践下的冲动呢,今天我们就来用决策树进行项目实战。

决策树的应用场景是非常广泛的,在各行各业都有应用,并且有非常良好的表现。金融行业的风险贷款评估,医疗行业的疾病诊断,电商行业的销售预测等等。

sklearn 中的决策树

首先我们先来了解下如何在 sklearn 中使用决策树模型。

在 sklearn 中,可以使用如下方式来构建决策树分类器

fromsklearn.treeimportDecisionTreeClassifier
clf=DecisionTreeClassifier(criterion='entropy')

其中的 criterion 参数,就是决策树算法,可以选择 entropy,就是基于信息熵;而 gini,就是基于基尼系数。

构建决策树的一些重要参数,整理如下:

参数名 作用
criterion 特征选择标准,可选参数,默认是 gini,可以设置为 entropy。ID3 算法使用的是 entropy,CART 算法使用的则是 gini。
splitter 特征划分点选择标准,可选参数,默认是 best,可以设置为 random。best 参数是根据算法选择最佳的切分特征,例如 gini、entropy。random 随机的在部分划分点中找局部最优的划分点。默认的"best"适合样本量不大的时候,而如果样本数据量非常大,此时决策树构建推荐"random"。
max_features 在划分数据集时考虑的最多的特征值数量,为 int 或 float 类型。一般情况如果特征不是很多,少于50时,使用默认值 None 即可。
max_depth 决策树最大深,可选参数,默认是 None。
min_samples_split 内部节点再划分所需最小样本数,可选参数,默认是2。当节点的样本数少于 min_samples_split 时,不再继续分裂。
min_samples_leaf 叶子节点最少样本数,可选参数,默认是1。如果某叶子节点数目小于这个阈值,就会和兄弟节点一起被裁剪掉。叶结点需要最少的样本数,也就是最后到叶子节点,需要多少个样本才能算一个叶子节点。
min_weight_fraction_leaf 叶子节点最小的样本权重和,可选参数,默认是0。这个值限制了叶子节点所有样本权重和的最小值,如果小于这个值,则会和兄弟节点一起被剪枝。
max_leaf_nodes 最大叶子节点数,默认是 None。可以通过设置最大叶子节点数,防止过拟合。特征不多时,不用考虑该参数。
class_weight 类别权重,可选参数,默认是 None,也可以字典、字典列表、balanced。指定样本各类别的的权重,主要是为了防止训练集某些类别的样本过多,导致训练的决策树过于偏向这些类别。如果使用 balanced,则算法会自己计算权重,样本量少的类别所对应的样本权重会高。
random_state 随机数种子,可选参数,默认是None。如果设置了随机数种子,那么相同随机数种子,不同时刻产生的随机数也是相同的。
min_impurity_split 信息增益的阈值,如果信息增益小于这个值,则决策树不再增长,该节点不再生成节点,即为叶子节点。
presort 数据是否预排序,可选参数,默认为 False。如果样本量少或者限制了一个深度很小的决策树,设置为 True 可以让划分点选择更加快,决策树建立的更加快。

泰坦尼克预测

在了解了 sklearn 中构建决策树的方式和相关参数后,我们就可以进行真正的决策树构建了,并解决实际问题。

首先我们先使用最为经典的泰坦尼克数据集来预测下乘客的生存情况,你应该还记得,我们在数据清洗章节已经讲解过该数据集是如何清洗的,现在我们继续使用清洗之后的数据,用决策树的方式预测结果。

数据清洗

具体的清洗思路,可以回顾下第4节

importnumpyasnp
importpandasaspddf=pd.read_csv('titanic_data.csv')
data=df.copy()
data['age'].fillna(df['age'].median(skipna=True),inplace=True)
data.drop(columns=['cabin'],inplace=True)
data['embarked'].fillna(df['embarked'].value_counts().idxmax(),inplace=True)
data.dropna(axis=0,how='any',inplace=True)
data.isnull().sum()#查看是否还有空值
data['alone']=np.where((data["sibsp"]+data["parch"])>0,0,1)
data.drop('sibsp',axis=1,inplace=True)
data.drop('parch',axis=1,inplace=True)
data=pd.get_dummies(data,columns=["embarked","sex"])
data.drop('name',axis=1,inplace=True)
data.drop('ticket',axis=1,inplace=True)

使用决策树做预测

导入需要的库

fromsklearn.treeimportDecisionTreeClassifier
fromsklearn.model_selectionimporttrain_test_split
fromsklearn.metricsimportaccuracy_score

特征选择

feature=["age","fare","alone","pclass","embarked_C","embarked_S","embarked_Q","sex_male","sex_female"]

创建特征(X)和类别标签(y)

X=data[feature]
y=data['survived']

划分训练集和测试集

X_train,X_test,y_train,y_test=train_test_split(X,y,test_size=0.2,random_state=2)

构建决策树模型并训练

clf=DecisionTreeClassifier(criterion='entropy')
clf.fit(X_train,y_train)
>>>
DecisionTreeClassifier(class_weight=None,criterion='entropy',max_depth=None,
max_features=None,max_leaf_nodes=None,
min_impurity_decrease=0.0,min_impurity_split=None,
min_samples_leaf=1,min_samples_split=2,
min_weight_fraction_leaf=0.0,presort=False,random_state=None,
splitter='best')

此时打印出来的就是我们构建好的决策树及相关参数

决策树预测

pred_labels=clf.predict(X_test)

查看分类器准确率

print("Accuracy:",accuracy_score(y_test,pred_labels))
>>>
Accuracy:0.7595419847328244

决策树可视化

为了做决策树的可视化,我们需要另外两个库,pydotplus 和 Grsphviz。

可以直接使用 pip 安装 pydotplus

pipinstallpydotplus

对于 Graphviz,需要下载对应的软件

https://www.graphviz.org/download/

你可以选择对应的操作系统版本,安装好之后,需要添加环境变量,之后就可以使用了。

此处以 windows 为例

下载并安装好软件

设置环境变量

按快捷键 win+r,在出现的运行对话框中输入sysdm.cpl,点击确定,出现如下对话框:

依次选择“高级”,“环境变量”,在系统变量中选择 Path 编辑

选择新建,并把 Graphviz 的安装路径填入

至此,已经可以正常只用 Graphviz 了。

代码实现

fromsklearn.treeimportexport_graphviz
fromsklearn.externals.siximportStringIO
importpydotplusdot_data=StringIO()
export_graphviz(clf,out_file=dot_data,
filled=True,rounded=True,
special_characters=True,feature_names=feature,class_names=['0','1'])
graph=pydotplus.graph_from_dot_data(dot_data.getvalue())
graph.write_png('diabetes.png')

现在在当前目录下就成功生成了一个叫做 diabetes.png 的图片了,内容如下:

由于树比较大,更加高清的图片可以在这里下载

https://github.com/zhouwei713/DataAnalyse/tree/master/Decision_tree

我们简单看下根节点吧,这棵决策树在根节点选择的是 sex_male 这个特征来做分类的,小于0.5的数据有636个,分到了左边,即 class 为1,大于0.5的数据有410个,分到了右边,即 class 为0。后面的分类都是以此类推的