算法应用

【例子】sklearn决策树预剪枝

作者 : 老饼 发表日期 : 2022-06-26 09:50:33 更新日期 : 2024-06-30 03:47:28
本站原创文章,转载请说明来自《老饼讲解-机器学习》www.bbbdata.com



剪枝是决策树预防模型过拟合的措施,剪枝分为预剪枝和后剪枝方法

1. 预剪枝: 树构建过程,达到一定条件就停止生长    

2. 后剪枝:等树完全构建后,再剪掉一些节点        

本文讲述预剪枝,预剪枝请参考《sklearn决策树后剪枝》




  01. sklearn中决策树如何预剪枝  



本节讲解在sklearn中如何对决策树进行预剪枝




     sklearn中如何预剪枝     


预剪枝是树构建过程,达到一定条件就停止生长,
在sklearn中,实际就是调参,通过设置树的生长参数,来达到预剪枝的效果
相关训练参数如下:
min_samples_leaf               :
叶子节点最小样本数       
 min_samples_split              :节点分枝最小样本个数     
 max_depth                           :树分枝的最大深度            
 min_weight_fraction_leaf   :叶子节点最小权重和         
 min_impurity_decrease      :节点分枝最小纯度增长量   
 max_leaf_nodes                  :最大叶子节点数               
 一般来说,只调这三个:max_depth,min_samples_leaf,min_samples_split    








  02. 决策树预剪枝-实例讲解  




本节通过例子讲解决策树预剪枝时的思路与操作




     预剪枝使用的思路    


预剪枝的整体思路如下:
1. 先用默认值,让树完整生长                                                                     
2. 参考完全生长的决策树的信息,分析树有没有容易过拟合的表现              
3. 通过相关参数,对过分生长的节点作出限制,以新参数重新训练决策树    




      决策树预剪枝-例子解说      


(1) 先用默认值预观察完整生长的树
 
 Demo代码如下:
from sklearn.datasets import load_iris
from sklearn import tree
import numpy as np
import pandas as pd
#--------数据加载-----------------------------------
iris = load_iris()                          # 加载数据
X = iris.data
y = iris.target
#-------用最优参数训练模型------------------
clf = tree.DecisionTreeClassifier(random_state=0)
clf = clf.fit(X, y)  
depth = clf.get_depth()
leaf_node = clf.apply(X)
#-----观察各个叶子节点上的样本个数---------
df  = pd.DataFrame({"leaf_node":leaf_node,"num":np.ones(len(leaf_node)).astype(int)})
df  = df.groupby(["leaf_node"]).sum().reset_index(drop=False)
df  = df.sort_values(by='num').reset_index(drop=True)
print("\n==== 树深度:",depth," ============")
print("==各个叶子节点上的样本个数:==")
print(df)
运行结果如下:
 
==== 树深度: 5  ============
==各个叶子节点上的样本个数:==
        leaf_node   num
0          6         1
1         11        1
2         15        1
3         10        2
4         14        2
5          8         3
 6         16       43
  7          5        47
  8          1        50
(2) 通过参数限制节点过分生长
 
我们可以看到,有很多叶子节点只有一两个样本,这样很容易过拟合,
因此我们把min_samples_leaf 调为3:
#-------用新调整的参数训练模型------------------
clf = tree.DecisionTreeClassifier(random_state=0,max_depth=4,min_samples_leaf=10)
clf = clf.fit(X, y)  
depth = clf.get_depth()
leaf_node = clf.apply(X)
#-----观察各个叶子节点上的样本个数---------
df  = pd.DataFrame({"leaf_node":leaf_node,"num":np.ones(len(leaf_node)).astype(int)})
df  = df.groupby(["leaf_node"]).sum().reset_index(drop=False)
df  = df.sort_values(by='num').reset_index(drop=True)
print("\n==== 树深度:",depth," ============")
print("==各个叶子节点上的样本个数:==")
print(df)
运行结果如下:
 
==== 树深度: 4  ============
==各个叶子节点上的样本个数:==
   leaf_node  num
0          6         11
1          9         11
2          7         14
3          5        29
 4          10       35
 5          1         50
可以看到,最少的一个叶子,也有11个样本了,这样的决策树泛化能力更加好




本文仅讲预剪枝的基本操作,在实际中,需要更灵活的思路,








 End 








联系老饼