决策树算法详解
目录
1. 什么是决策树
决策树(Decision Tree)是一种直观且强大的监督学习算法,可用于分类和回归任务。它的核心思想是通过一系列"是/否"问题将数据逐步划分,最终到达一个叶节点得出预测结果。你可以把它想象成一个倒置的树:根节点是第一个问题,每个分支代表一个答案,叶节点是最终的决策。
以一个简单的"是否出去打球"的例子来理解:
这棵简单的决策树只用了3个特征(天气、湿度、风力)就能做出预测。真实世界的决策树可能有数百个节点,但基本原理完全相同:在每个节点选择最优特征进行分裂,直到满足停止条件。
2. 决策树工作原理
决策树的构建过程是一个递归分裂(Recursive Splitting)的过程,核心步骤如下:
2.1 递归分裂
从根节点开始,算法遍历所有可用特征和可能的分裂点,选择能最大程度降低"不纯度"(impurity)的特征和阈值进行分裂。对分裂后的每个子集重复此过程。
2.2 特征选择
在每个节点,算法需要回答一个关键问题:选择哪个特征来分裂? 答案是选择使得子节点"最纯"的特征。衡量"纯度"的指标有信息增益(Information Gain)、增益率(Gain Ratio)和基尼不纯度(Gini Impurity),它们分别对应 ID3、C4.5 和 CART 三种算法。
2.3 叶节点与停止条件
当满足以下条件之一时,节点变为叶节点,不再继续分裂:
2. 达到最大深度限制(max_depth)
3. 节点中样本数小于最小分裂阈值(min_samples_split)
4. 没有可用特征进行分裂
5. 分裂带来的信息增益小于阈值
分类树的叶节点输出该节点中出现最多的类别(多数表决);回归树的叶节点输出该节点中目标值的均值。
3. 分裂准则
3.1 信息增益 — ID3 算法
信息熵 (Entropy)
H(S) = -Σ pᵢ log₂(pᵢ)
其中 pᵢ 是类别 i 在集合 S 中的比例。熵越大,数据越混乱;熵为 0 表示完全纯净。
信息增益 (Information Gain)
IG(S, A) = H(S) - Σ (|Sᵥ| / |S|) * H(Sᵥ)
信息增益 = 分裂前的熵 - 分裂后各子集的加权熵之和。IG 越大,说明按该特征分裂后数据变得越"纯"。
父节点熵: H(S) = -(9/14)log₂(9/14) - (5/14)log₂(5/14) = 0.940
按特征A分裂为两个子集:
S₁: 6个样本 (4+, 2-) → H(S₁) = -(4/6)log₂(4/6) - (2/6)log₂(2/6) = 0.918
S₂: 8个样本 (5+, 3-) → H(S₂) = -(5/8)log₂(5/8) - (3/8)log₂(3/8) = 0.954
信息增益: IG = 0.940 - (6/14)*0.918 - (8/14)*0.954 = 0.940 - 0.393 - 0.545 = 0.002
按特征B分裂为两个子集:
S₁: 7个样本 (7+, 0-) → H(S₁) = 0 (纯节点!)
S₂: 7个样本 (2+, 5-) → H(S₂) = -(2/7)log₂(2/7) - (5/7)log₂(5/7) = 0.863
信息增益: IG = 0.940 - (7/14)*0 - (7/14)*0.863 = 0.940 - 0 - 0.431 = 0.509
特征B的信息增益(0.509)远大于特征A(0.002),所以选择特征B分裂。
3.2 增益率 — C4.5 算法
分裂信息 (Split Information)
SplitInfo(S, A) = -Σ (|Sᵥ| / |S|) * log₂(|Sᵥ| / |S|)
SplitInfo 衡量按特征A分裂后子集大小的分布均匀程度。特征取值越多,SplitInfo 越大。
增益率 (Gain Ratio)
GR(S, A) = IG(S, A) / SplitInfo(S, A)
增益率通过除以 SplitInfo 来惩罚取值过多的特征,解决了 ID3 偏好多值特征的问题。
SplitInfo(S, B) = -(7/14)log₂(7/14) - (7/14)log₂(7/14) = 1.0
GR(S, B) = 0.509 / 1.0 = 0.509
假设特征C有5个取值,每个分到约2-3个样本,IG=0.52但SplitInfo=2.3:
GR(S, C) = 0.52 / 2.3 = 0.226
虽然特征C的信息增益略高(0.52 vs 0.509),但增益率更低(0.226 vs 0.509)。C4.5会选择特征B,因为它不会过度偏好多值特征。
3.3 基尼不纯度 — CART 算法
基尼不纯度 (Gini Impurity)
Gini(S) = 1 - Σ pᵢ²
基尼系数衡量从集合中随机抽取两个样本,类别不一致的概率。范围 [0, 0.5](二分类),0 表示完全纯净。
加权基尼 (Weighted Gini after split)
Gini_split = Σ (|Sᵥ| / |S|) * Gini(Sᵥ)
CART 选择使加权基尼值最小的分裂方式。注意:CART 只做二叉分裂(每次分成两个子集)。
父节点基尼: Gini = 1 - (9/14)² - (5/14)² = 1 - 0.413 - 0.128 = 0.459
按特征B分裂(阈值=x):
左子集: 7个样本 (7+, 0-) → Gini = 1 - 1² - 0² = 0.000
右子集: 7个样本 (2+, 5-) → Gini = 1 - (2/7)² - (5/7)² = 1 - 0.082 - 0.510 = 0.408
加权基尼: (7/14)*0.000 + (7/14)*0.408 = 0.204
加权基尼从0.459降到0.204,下降显著。CART会在所有可能的分裂中选择加权基尼最小的那个。
4. ID3 vs C4.5 vs CART 对比
| 特性 | ID3 | C4.5 | CART |
|---|---|---|---|
| 分裂准则 | 信息增益 | 增益率 | 基尼不纯度 |
| 树结构 | 多叉树 | 多叉树 | 严格二叉树 |
| 连续特征 | 不支持 | 支持(二分法离散化) | 支持(选最优分裂点) |
| 缺失值处理 | 不支持 | 支持(加权分配) | 支持(代理分裂) |
| 剪枝策略 | 无 | 悲观错误剪枝(PEP) | 代价复杂度剪枝(CCP) |
| 任务类型 | 仅分类 | 仅分类 | 分类 + 回归 |
| 偏好 | 偏好多值特征 | 修正了多值偏好 | 无明显偏好 |
| 提出年份 | 1986 (Quinlan) | 1993 (Quinlan) | 1984 (Breiman) |
| sklearn实现 | 无 | 无 | DecisionTreeClassifier |
实践中,sklearn的决策树实现基于优化的CART算法。如果你使用Python做机器学习,默认就是CART。ID3和C4.5更多出现在学术研究和面试中。
5. Python 从零实现决策树
下面用纯Python实现一个基于信息熵的简单决策树分类器,不依赖任何第三方库(仅用math模块)。
6. Sklearn 决策树实战
6.1 分类树 (DecisionTreeClassifier)
6.2 回归树 (DecisionTreeRegressor)
6.3 特征重要性
7. 剪枝策略
未剪枝的决策树容易过拟合——它可以完美拟合训练集(每个叶节点只有1个样本),但泛化能力很差。剪枝是解决过拟合的关键手段。
7.1 预剪枝 (Pre-pruning)
在树的构建过程中提前停止生长。sklearn 支持的预剪枝参数:
min_samples_split: 节点分裂所需最少样本数,默认2。增大可防止过拟合。
min_samples_leaf: 叶节点最少样本数,默认1。增大可让叶节点更稳定。
max_features: 每次分裂考虑的最大特征数。'sqrt' 或 'log2' 可增加随机性。
max_leaf_nodes: 最大叶节点数。限制树的复杂度。
7.2 后剪枝 — 代价复杂度剪枝 (Cost-Complexity Pruning, CCP)
先让树完全生长,然后从底部开始逐步"剪掉"对预测帮助不大的子树。sklearn 中通过 ccp_alpha 参数控制。
代价复杂度公式
R_alpha(T) = R(T) + alpha * |T|
R(T) 是树的训练误差,|T| 是叶节点数量,alpha 是惩罚系数。alpha 越大,惩罚越重,最终树越小。
8. 决策树可视化
8.1 使用 sklearn plot_tree
8.2 使用 Graphviz
8.3 文本形式输出
9. 优点与缺点
1. 直观易解释,可可视化
2. 无需特征缩放/标准化
3. 能处理数值和类别特征
4. 能捕获非线性关系和特征交互
5. 可处理缺失值(CART)
6. 训练和预测速度快
7. 可直接输出特征重要性
1. 容易过拟合(特别是深树)
2. 不稳定:数据小变化可能导致完全不同的树
3. 对类别不平衡敏感
4. 贪心算法,不保证全局最优
5. 外推能力差(超出训练范围的值)
6. 处理高维稀疏数据效果差
7. 单棵树精度通常不如集成方法
10. 何时使用决策树
| 场景 | 推荐算法 | 原因 |
|---|---|---|
| 需要模型可解释性 | 决策树 | 可直接可视化决策规则,满足合规要求 |
| 快速建立基线模型 | 决策树 | 训练快、无需特征工程、不易调参 |
| 追求最高精度 | 随机森林 / XGBoost | 集成方法组合多棵树,精度更高更稳定 |
| 高维线性可分数据 | SVM / 逻辑回归 | 决策树对高维稀疏数据效果差 |
| 小数据集(<100样本) | 决策树(浅树) / KNN | 简单模型防过拟合 |
| 大规模数据(100万+) | XGBoost / LightGBM | 优化算法效率高,支持分布式 |
| 需要特征重要性分析 | 决策树 / 随机森林 | 内置特征重要性输出 |
| 数据含缺失值 | CART决策树 / XGBoost | 原生支持缺失值处理 |
经验法则:先用决策树理解数据和特征,再用随机森林或梯度提升树提升精度。决策树是集成方法的基础——理解了决策树,就理解了随机森林和XGBoost的核心。
11. 相关工具与指南
12. 常见问题 (FAQ)
sklearn 的 DecisionTreeClassifier 和 DecisionTreeRegressor 均基于优化的 CART 算法,仅构建二叉树。可通过 criterion 参数选择 'gini'(基尼系数)或 'entropy'(信息熵),但树结构始终是二叉的。如需 ID3 或 C4.5 的多叉树,需自行实现或使用其他库。
随机森林(Random Forest)是由多棵决策树组成的集成模型。它通过两种随机化机制减少过拟合:1) Bagging — 每棵树在随机采样的数据子集上训练;2) 特征随机 — 每次分裂只考虑随机子集的特征。最终预测是所有树的投票/平均结果。随机森林牺牲了可解释性,但大幅提升了精度和稳定性。
最常用的方法是交叉验证网格搜索:设定一个候选范围(如3到15),用 GridSearchCV 测试每个深度的验证集表现,选择测试集得分最高的深度。也可以使用 CCP 后剪枝自动找到最优复杂度。经验上,大多数场景下深度 4-8 就足够了。
不需要。决策树根据特征值的大小关系进行分裂(例如"身高 > 170cm?"),分裂结果不受特征尺度的影响。无论身高用厘米还是米表示,分裂效果完全一样。这是决策树和树集成方法相比 SVM、KNN 等算法的一大优势。
可以。sklearn 的决策树原生支持多分类(不需要 OvR/OvO),因为叶节点可以直接投票得到多个类别之一。也支持多输出(Multi-output),即同时预测多个目标变量——只需将 y 设为多列矩阵即可。