机器学习——十大数据挖掘之一的决策树CART算法
今天是机器学习专题的第23篇文章,我们今天分享的内容是十大数据挖掘算法之一的CART算法。
CART算法全称是Classification and regression tree,也就是分类回归树的意思。和之前介绍的ID3和C4.5一样,CART算法同样是决策树模型的一种经典的实现。决策树这个模型一共有三种实现方式,前面我们已经介绍了ID3和C4.5两种,今天刚好补齐这最后一种。
算法特点
CART称为分类回归树,从名字上我们也看得出来,它既能支持分类又可以支持回归。的确如此,决策树的确支持回归操作,但是我们一般不会用决策树来进行回归。这里面的原因很多,除了树模型拟合能力有限效果不一定好之外,还与特征的模式有关系,树回归模型受到特征的影响非常大。这个部分我们不做太多深入,之后会在回归树的文章当中详细探讨。
正因为回归树模型效果表现都不太理想,所以CART算法实现决策树基本都是用来做分类问题。那么在分类问题上,它与之前的ID3算法和C4.5算法又有什么不同呢?
主要细究起来大约有两点,第一点是CART算法使用Gini指数而不是信息增益来作为划分子树的依据,第二点是CART算法每次在划分数据的时候,固定将整份数据拆分成两个部分,而不是多个部分。由于CART每次将数据拆分成两个部分,所以它对于拆分的次数没有限制,而C4.5算法对特征进行了限制,限制了每个特征最多只能使用一次。因为这一点,同样CART对于剪枝的要求更高,因为不剪枝的话很有可能导致树过度膨胀,以至于过拟合。
Gini指数
在ID3和C4.5算法当中,在拆分数据的时候用的是信息增益和信息增益比,这两者都是基于信息熵模型。信息熵模型本身并没有问题,也是非常常用的模型。唯一的问题是,在计算熵的时候需要涉及到log运算,相比于四则运算来说,计算log要多耗时很多。
Gini指数本质上也是基于信息熵模型,只是我们在计算的时候做了一些转化,从而避免了使用log进行计算,加速了计算的过程。两者的内在逻辑是一样的。那怎么实现的加速计算呢?这里用到了高等数学当中的泰勒展开,我们将log运算通过泰勒公式展开,转化成多项式的计算,从而加速信息熵的计算。
我们来做一个简单的推导:
\ln(x) \approx \ln(x_0) + (x-x_0)\ln'(x_0) + o(x)
\end{aligned}
\]
我们把\(x_0 =1\)代入,可以得到:\(\ln(x)=x – 1 + o(x)\),其中o(x)是关于x的高阶无穷小。我们把这个式子套入信息熵的公式当中:
H(x) &= -\sum_{i=1}^k p_i\ln p_i \\
&\approx \sum_{i=1}^k p_i(1-p_i)
\end{aligned}
\]
这个就是Gini指数的计算公式,这里的pi表示类别i的概率,其实就是类别i的样本占全体样本的比例。那么上面的式子也可以看成是从数据集当中抽取两条样本,它们类别不一致的概率。
因此Gini指数越小,说明数据集越集中,也就是纯度越高。它的概念等价于信息熵,熵越小说明信息越集中,两者的概念是非常近似的。所以当我们使用Gini指数来作为划分依据的时候,选择的是切分之后Gini指数尽量小的切分方法,而不是尽量大的。
从上面的公式当中,我们可以发现相比于信息熵的log运算,Gini指数只需要简单地计算比例和基础运算就可以得到结果了,显然运算速度要快得多。并且由于是通过泰勒展开逼近的,整体的性能也并不差,我们可以看下下面这张经典的图感受一下:
从上图当中可以看出来,Gini指数和信息熵的效果非常接近,一样可以非常好地反应数据划分的纯度。
拆分与剪枝
刚才我们介绍CART算法特性的时候提到过,CART算法每次拆分数据都是二分的,这点和C4.5处理连续性特征的逻辑很像。但有两点不同,第一点是CART对于离散型和连续性特征都如此操作,另外一点是,CART算法当中一个特征可以重复使用。
举个例子,在之前的算法当中,比如说西瓜的直径是一个特征。那么当我们判断过西瓜的直径小于10cm之后,西瓜的直径这个特征就会从数据当中移除,之后再也不会用到。但是在CART算法当中不是如此,比如当我们先后根据西瓜的直径以及西瓜是否有藤这两个特征对数据进行拆分之后,对于ID3和C4.5算法来说,西瓜的直径这个特征已经不可以再用来作为划分的依据了,但是CART算法当中可以,我们仍然可以继续使用之前已经用过的特征。
我们用一张图来展示,大概是下面这个样子:
我们观察一下最左侧的子树,直径这个特征出现了不止一次,这其实是很合理的。然而这也会有一个问题,就是由于没有了特征只能用一次这个限制,这样会导致这棵树无限膨胀,尤其是在连续性特征很多的情况下,很容易陷入过拟合。为了放置过拟合,增加模型的泛化能力,我们需要对生成的这棵树进行剪枝。
剪枝的方案主流的有两种,一种是预剪枝,一种是后剪枝。所谓的预剪枝,即是在生成树的时候就对树的生长进行限制,防止过度拟合。而后剪枝则是在树已经生成之后,对过拟合的部分进行修剪。其中预剪枝比较容易理解,比如我们可以限制决策树在训练的时候每个节点的数据只有在达到一定数量的情况下才会进行分裂,否则就成为叶子节点保留。或者我们可以限制数据的比例,当节点中某个类别的占比超过阈值的时候,也可以停止生长。
后剪枝相对来说复杂一些,需要我们在生成树之后通过一些机制寻找可以剪枝的部分,对整棵树进行修剪。比如在CART算法当中常用的剪枝策略是CCP,它的英文全写是Cost-Complexity Pruning,即代价复杂度剪枝。这个策略设计了一个指标来衡量一棵子树的复杂度代价,我们可以对这个代价设置阈值来进行剪枝。
这个策略的精髓在于下面这个式子:
\]
这个式子当中的c就是指的剪枝带来的代价,t代表剪枝之后的子树,\(T_t\)表示剪枝之前的子树。R(t)表示剪枝之后的误差代价,\(R(T_t)\)表示剪枝之前的误差代价。其中误差代价的定义是:\(R(t) = r(t) * p(t)\),r(t)是节点t的误差率,p(t)是t上数据占所有数据的比例。
我们来看个例子:
假设我们知道所有数据一共有100条,那么我们代入公式算一下,可以得到\(R(t) = r(t) * p(t) = \frac{11}{23} * \frac{23}{100} = \frac{11}{100}\)
子树的误差代价是:
\]
所以可以得到\(c=\frac{11/100 – 4/100}{3 – 1}=\frac{7}{200}\)
c越大说明剪枝带来的偏差越大,也就是说越不能剪,相反c很小说明偏差不大,可以减掉。我们只需要设置阈值,然后计算每一棵子树的c来判断是否能够剪枝即可。
代码实现
我们之前已经实现过了C4.5算法,再来实现CART可以说是非常简单了,因为它相比于C4.5还少了离散类型这种情况,可以全部当做是连续型类型来处理。
我们只需要把之前的信息增益比改成Gini指数即可:
from collections import Counter
def gini_index(dataset):
dataset = np.array(dataset)
n = dataset.shape[0]
if n == 0:
return 0
# sigma p(1-p) = 1 - sigma p^2
counter = Counter(dataset[:, -1])
ret = 1.0
for k, v in counter.items():
ret -= (v / n) ** 2
return ret
def split_gini(dataset, idx, threshold):
left, right = [], []
n = dataset.shape[0]
# 根据阈值拆分,拆分之后计算新的Gini指数
for data in dataset:
if data[idx] < threshold:
left.append(data)
else:
right.append(data)
left, right = np.array(left), np.array(right)
# 拆分成两半之后,乘上所占的比例
return left.shape[0] / n * gini_index(left) + right.shape[0] / n * gini_index(right)
然后选择拆分的函数稍微调整一下,因为Gini指数越小越好,之前的信息增益和信息增益比都是越大越好。代码的框架基本上也没有变动,只是做了一些微调:
def choose_feature_to_split(dataset):
n = len(dataset[0])-1
m = len(dataset)
# 记录最佳Gini,特征和阈值
bestGini = 1.0
feature = -1
thred = None
for i in range(n):
threds = get_thresholds(dataset, i)
for t in threds:
# 遍历所有的阈值,计算每个阈值的信息增益比
ratio = split_gini(dataset, i, t)
if ratio < bestGini:
bestGini, feature, thred = ratio, i, t
return feature, thred
建树和预测的部分都和之前C4.5算法基本一致,只需要去掉离散类型的判断即可,大家可以参考一下之前文章当中的代码。
总结
到这里,我们关于决策树模型的内容就算是结束了,我们从基本的决策树原理,再到ID3、C4.5以及CART算法,都已经囊括了。这些知识储备足以应对面试当中关于决策树模型的问题了。
虽然在实际的生产过程当中,我们已经用不到决策树了,还不是基本用不到,几乎是完全用不到。但是它的思想非常重要,是后续很多模型的基础,比如随机森林、GBDT等模型,都是在决策树的基础上建立起来的。所以我们深入理解决策树的原理对于我们后续的进阶学习非常重要。
最后, 我把完整的代码发在了paste.ubuntu上,需要的同学可以在公众号后台回复“决策树”获取。
如果喜欢本文,可以的话,请点个关注,给我一点鼓励,也方便获取更多文章。