算法应用

【例子】sklearn决策树后剪枝

作者 : 老饼 发表日期 : 2022-06-26 09:50:49 更新日期 : 2024-06-30 09:53:58
本站原创文章,转载请说明来自《老饼讲解-机器学习》www.bbbdata.com



剪枝是决策树预防模型过拟合的措施,剪枝分为预剪枝和后剪枝方法

1. 预剪枝:树构建过程,达到一定条件就停止生长     

2. 后剪枝:等树完全构建后,再剪掉一些节点         

本文讲述后剪枝,预剪枝请参考《sklearn决策树预剪枝》





  01. 决策树-CCP后剪枝简介  



 

本节简单介绍什么是决策树的CCP后剪枝




     什么是决策树的CCP后剪枝     


后剪枝一般指的是CCP代价复杂度剪枝法(Cost Complexity Pruning),
在树构建完成后,对树进行剪枝简化,使以下损失函数最小化:
 
 :叶子节点个数                     
  :所有样本个数                      
  :第 i 个叶子节点上的样本数 
 : 第i个叶子节点的损失函数  
 :待定系数,用于惩罚节点个数,引导模型用更少的节点 
 
损失函数既考虑了代价,又考虑了树的复杂度,所以叫代价复杂度剪枝法,
实质就是在树的复杂度与准确性之间取得一个平衡点。
✍️备注:在sklearn中,如果criterion设为GINI,则Li是每个叶子节点的GINI系数,如果设为entropy,则是熵






  02. 决策树-后剪枝操作过程  




决策树使用CCP后剪枝的具体操作过程如下:




    (1)  查看CCP路径   


 计算CCP路径,查看alpha与树质量的关系:
 
构建好树后,我们可以通过clf.cost_complexity_pruning_path(X, y) 查看树的CCP路径
 Demo代码如下:
# -*- coding: utf-8 -*-
from sklearn.datasets import load_iris
from sklearn import tree
import numpy as np
#----------------数据准备----------------------------
iris = load_iris()                          # 加载数据
X = iris.data
y = iris.target
#---------------模型训练---------------------------------
clf = tree.DecisionTreeClassifier(min_samples_split=10,ccp_alpha=0)        
clf = clf.fit(X, y)     
#-------计算ccp路径-----------------------
pruning_path = clf.cost_complexity_pruning_path(X, y)
#-------打印结果---------------------------    
print("\n====CCP路径=================")
print("ccp_alphas:",pruning_path['ccp_alphas'])
print("impurities:",pruning_path['impurities']) 
运行结果如下:
 

  它的意思是:
 
时,树的不纯度为 0.02666            
 时,树的不纯度为 0.03082
 时,树的不纯度为 0.04387
........
其中,树的不纯度指的是损失函数的前部分
也即所有叶子的不纯度(gini或者熵)加权和.
 ✍️ 小贴士 :ccp_path只提供树的不纯度,如果还需要alpha对应的其它信息,
则可以将alpha代入模型中训练,从训练好的模型中获取




  (2) 根据CCP路径剪树   


根据树的质量,选定alpha进行剪树
 
我们选择一个可以接受的树不纯度,找到对应的alpha,然后重新训练决策树
 例如,我们可接受的树不纯度为0.0735,则alpha可设为0.1(在0.02966与0.25979之间)
对模型重新以参数ccp_alpha=0.1进行训练,即可得到剪枝后的决策树
 
完整代码如下:
 # -*- coding: utf-8 -*-
from sklearn.datasets import load_iris
from sklearn import tree
import numpy as np

#--------数据准备-----------------------------------
iris = load_iris()                          # 加载数据
X = iris.data
y = iris.target
#-------模型训练---------------------------------
clf = tree.DecisionTreeClassifier(min_samples_split=10,random_state=0,ccp_alpha=0)        
clf = clf.fit(X, y)     
#-------计算ccp路径------------------------------
pruning_path = clf.cost_complexity_pruning_path(X, y)

#-------打印结果---------------------------------   
print("\n====CCP路径=================")
print("ccp_alphas:",pruning_path['ccp_alphas'])
print("impurities:",pruning_path['impurities'])    

#------设置alpha对树后剪枝-----------------------
clf = tree.DecisionTreeClassifier(min_samples_split=10,random_state=0,ccp_alpha=0.1)        
clf = clf.fit(X, y) 
#------自行计算树纯度以验证-----------------------
is_leaf =clf.tree_.children_left ==-1
tree_impurities = (clf.tree_.impurity[is_leaf]* clf.tree_.n_node_samples[is_leaf]/len(y)).sum()
#-------打印结果--------------------------- 
print("\n==设置alpha=0.1剪枝后的树纯度:=========\n",tree_impurities)
运行结果如下:
 





     关于CCP路径的计算过程   


对于CCP路径的计算过程,本文不再重复讲解,可参考:
1.《决策树后剪枝原理:CCP剪枝法》                      
2.《决策树(sklearn)中CCP路径计算的实现方式.py》







 End 





联系老饼