决策树又叫CART
CART = Classification And Regression Tree
根据某一个维度v和某一个域值d进行二分,这样得到的树就叫CART
scikit-learn创建决策树的方法就叫CART。
还有其它创建决策树的方法,例如ID3、C4.5、C5.0
创建好决策树好,预测一个样本的时间复杂度为O(log(m))
训练模型的时间复杂度为:O(nmlog(m))
是非参数学习算法,非常容易产生过拟合。
实际使用决策树时都要进行剪枝,1.降低复杂度2.解决过拟合
决策树使用不同的超参数的结果对比
生成数据
import numpy as np
import matplotlib.pyplot as plt
from sklearn import datasets
X, y = datasets.make_moons(noise=0.25, random_state=666)
plt.scatter(X[y==0,0],X[y==0,1])
plt.scatter(X[y==1,0],X[y==1,1])
plt.show()
使用默认参数训练
from sklearn.tree import DecisionTreeClassifier
# 默认情况下使用gini系数,且一直向下划分到所有gini系数都是0为止
dt_clf = DecisionTreeClassifier()
dt_clf.fit(X, y)
def plot_decision_boundary(model, axis):
x0, x1 = np.meshgrid(
np.linspace(axis[0], axis[1], int((axis[1]-axis[0])*100)).reshape(-1,1),
np.linspace(axis[2], axis[3], int((axis[3]-axis[2])*100)).reshape(-1,1)
)
X_new = np.c_[x0.ravel(), x1.ravel()]
y_predict = model.predict(X_new)
zz = y_predict.reshape(x0.shape)
from matplotlib.colors import ListedColormap
custom_cmap = ListedColormap(['#EF9A9A','#FFF59D','#90CAF9'])
plt.contourf(x0, x1, zz, cmap=custom_cmap)
plot_decision_boundary(dt_clf, axis=[-1.5, 2.5, -1.0, 1.5])
plt.scatter(X[y==0,0],X[y==0,1])
plt.scatter(X[y==1,0],X[y==1,1])
plt.scatter(X[y==2,0],X[y==2,1])
plt.show()
典型的过拟合
使用不同的超参数的结果对比
default | max_depth=2 | min_samples_split=10 | min_samples_leaf=6 | max_leaf_nodes=4 |
---|---|---|---|---|
一直向下划分到所有gini系数都是0为止 | 最大深度 | 对于一个节点至少有多个样本才继续划分 | 叶子节点至少有几个样本 | 最多有多少个叶子 |
使用这些参数时:
- 避免欠拟合
- 使用风格搜索来组合这些参数