智能决策分析

| 2022-09-22

智能决策

智能决策支持系统是人工智能(AI,Artificial Intelligence)和DSS相结合,应用专家系统(ES,Expert System)技术,使DSS能够更充分地应用人类的知识,如关于决策问题的描述性知识,决策过程中的过程性知识,求解问题的推理性知识,通过逻辑推理来帮助解决复杂的决策问题的辅助决策系统。

决策树

在现实生活中,我们会遇到各种选择,不论是选择男女朋友,还是挑选水果,都是基于以往的经验来做判断。如果把判断背后的逻辑整理成一个结构图,你会发现它实际上是一个树状图,这就是决策树。

工作原理

决策树基本上就是把我们以前的经验总结出来。如果我们要出门打篮球,一般会根据“天气”、“温度”、“湿度”、“刮风”这几个条件来判断,最后得到结果:去打篮球?还是不去?

智能决策

上面这个图就是一棵典型的决策树。我们在做决策树的时候,会经历两个阶段:构造和剪枝。

构造

构造就是生成一棵完整的决策树。简单来说,构造的过程就是选择什么属性作为节点的过程,那么在构造过程中,会存在三种节点:

1.根节点:就是树的最顶端,最开始的那个节点。在上图中,“天气”就是一个根节点;

2.内部节点:就是树中间的那些节点,比如说“温度”、“湿度”、“刮风”;

3.叶节点:就是树最底部的节点,也就是决策结果。

节点之间存在父子关系。比如根节点会有子节点,子节点会有子子节点,但是到了叶节点就停止了,叶节点不存在子节点。那么在构造过程中,你要解决三个重要的问题:

1.选择哪个属性作为根节点;

2.选择哪些属性作为子节点;

3.什么时候停止并得到目标状态,即叶节点。

剪枝

剪枝就是给决策树瘦身,这一步想实现的目标就是,不需要太多的判断,同样可以得到不错的结果。之所以这么做,是为了防止“过拟合”(Overfitting)现象的发生。

过拟合:指的是模型的训练结果“太好了”,以至于在实际应用的过程中,会存在“死板”的情况,导致分类错误。

欠拟合:指的是模型的训练结果不理想。

智能决策

造成过拟合的原因:

一是因为训练集中样本量较小。如果决策树选择的属性过多,构造出来的决策树一定能够“完美”地把训练集中的样本分类,但是这样就会把训练集中一些数据的特点当成所有数据的特点,但这个特点不一定是全部数据的特点,这就使得这个决策树在真实的数据分类中出现错误,也就是模型的“泛化能力”差。

泛化能力:指的分类器是通过训练集抽象出来的分类能力,你也可以理解是举一反三的能力。如果我们太依赖于训练集的数据,那么得到的决策树容错率就会比较低,泛化能力差。因为训练集只是全部数据的抽样,并不能体现全部数据的特点。

剪枝的方法:

预剪枝:在决策树构造时就进行剪枝。方法是,在构造的过程中对节点进行评估,如果对某个节点进行划分,在验证集中不能带来准确性的提升,那么对这个节点进行划分就没有意义,这时就会把当前节点作为叶节点,不对其进行划分。

后剪枝:在生成决策树之后再进行剪枝。通常会从决策树的叶节点开始,逐层向上对每个节点进行评估。如果剪掉这个节点子树,与保留该节点子树在分类准确性上差别不大,或者剪掉该节点子树,能在验证集中带来准确性的提升,那么就可以把该节点子树进行剪枝。方法是:用这个节点子树的叶子节点来替代该节点,类标记为这个节点子树中最频繁的那个类。

如何判断要不要去打篮球?

智能决策

我们该如何构造一个判断是否去打篮球的决策树呢?再回顾一下决策树的构造原理,在决策过程中有三个重要的问题:将哪个属性作为根节点?选择哪些属性作为后继节点?什么时候停止并得到目标值?

显然将哪个属性(天气、温度、湿度、刮风)作为根节点是个关键问题,在这里我们先介绍两个指标:纯度和信息熵。

纯度

决策树的构建是基于样本概率和纯度来进行的,判断数据集是否“纯”可以通过三个公式进行判断:Gini系数、熵(Entropy)、错误率。

三个公式的值越大,表示数据越不纯。值越小,表示数据越纯。

例:偿还贷款的能力。P(1) = 7/10 = 0.7;可以偿还的概率;P(2) = 3/10 = 0.3;无法偿还的概率;

智能决策

智能决策

Error = 1 - max {p(i)} (i =1 ~ n) = 1 - 0.7 = 0.3

如果只有两种分类情况,随着两种情况发生的概率的改变,最后根据三种公式的计算所得:

智能决策

可以发现,三种公式的效果差不多,一般情况使用熵公式。

信息增益

信息增益指的就是划分可以带来纯度的提高,信息熵的下降。它的计算公式,是父亲节点的信息熵减去所有子节点的信息熵。在计算的过程中,我们会计算每个子节点的归一化信息熵,即按照每个子节点在父节点中出现的概率,来计算这些子节点的信息熵。所以信息增益的公式可以表示为:

智能决策

公式中 D 是父亲节点,Di 是子节点,Gain(D,a)中的 a 作为 D 节点的属性选择。

如何计算信息增益呢?

import numpy as npimport pandas as pdfrom collections import Counterimport mathfrom math import log# 熵# print(-(1 / 3) * log(1 / 3, 2) - (2 / 3) * log(2 / 3, 2))def calc_ent(datasets):    data_length = len(datasets)    label_count = {}    for i in range(data_length):        label = datasets[i][-1]        if label not in label_count:            label_count[label] = 0        label_count[label] += 1    ent = -sum([(p / data_length) * log(p / data_length, 2)                for p in label_count.values()])    # print(ent)    return ent# 经验条件熵def cond_ent(datasets, axis=0):    data_length = len(datasets)    feature_sets = {}    for i in range(data_length):        feature = datasets[i][axis]        if feature not in feature_sets:            feature_sets[feature] = []        feature_sets[feature].append(datasets[i])    cond_ent = sum([(len(p) / data_length) * calc_ent(p)                    for p in feature_sets.values()])    print(cond_ent)    return cond_ent# 信息增益def info_gain(ent, cond_ent):    return ent - cond_entdef info_gain_train(datasets):    count = len(datasets[0]) - 1    print(count)    ent = calc_ent(datasets)    print(ent)    best_feature = []    for c in range(count):        c_info_gain = info_gain(ent, cond_ent(datasets, axis=c))        best_feature.append((c, c_info_gain))        print('特征({}) - info_gain - {:.3f}'.format(labels[c], c_info_gain))    # 比较大小    best_ = max(best_feature, key=lambda x: x[-1])    return '特征({})的信息增益最大,选择为根节点特征'.format(labels[best_[0]])# labels = ["天气", "温度", "湿度", "刮风", '类别']# datasets = pd.DataFrame([#     ["晴", "高", "中", "否", '否'],#     ["晴", "高", "中", "是", '否'],#     ["阴天", "高", "高", "否", '是'],#     ["雨", "高", "高", "否", '是'],#     ["雨", "低", "高", "否", '否'],#     ["晴", "中", "中", "是", '是'],#     ["阴天", "中", "高", "是", '否'],# ])labels = ["天气", "湿度", "刮风", '类别']datasets = pd.DataFrame([    ["晴", "中", "否", '否'],    ["晴", "中", "是", '否'],    ["阴天", "高", "否", '是'],    ["雨", "高", "否", '是']])print(datasets)print(info_gain_train(np.array(datasets)))