决策树又叫CART
CART = Classification And Regression Tree
根据某一个维度v和某一个域值d进行二分,这样得到的树就叫CART
scikit-learn创建决策树的方法就叫CART。
还有其它创建决策树的方法,例如ID3、C4.5、C5.0

创建好决策树好,预测一个样本的时间复杂度为O(log(m))
训练模型的时间复杂度为:O(nmlog(m))
是非参数学习算法,非常容易产生过拟合。
实际使用决策树时都要进行剪枝,1.降低复杂度2.解决过拟合

决策树使用不同的超参数的结果对比

生成数据

import numpy as np import matplotlib.pyplot as plt from sklearn import datasets X, y = datasets.make_moons(noise=0.25, random_state=666) plt.scatter(X[y==0,0],X[y==0,1]) plt.scatter(X[y==1,0],X[y==1,1]) plt.show()

使用默认参数训练

from sklearn.tree import DecisionTreeClassifier # 默认情况下使用gini系数,且一直向下划分到所有gini系数都是0为止 dt_clf = DecisionTreeClassifier() dt_clf.fit(X, y) def plot_decision_boundary(model, axis): x0, x1 = np.meshgrid( np.linspace(axis[0], axis[1], int((axis[1]-axis[0])*100)).reshape(-1,1), np.linspace(axis[2], axis[3], int((axis[3]-axis[2])*100)).reshape(-1,1) ) X_new = np.c_[x0.ravel(), x1.ravel()] y_predict = model.predict(X_new) zz = y_predict.reshape(x0.shape) from matplotlib.colors import ListedColormap custom_cmap = ListedColormap(['#EF9A9A','#FFF59D','#90CAF9']) plt.contourf(x0, x1, zz, cmap=custom_cmap) plot_decision_boundary(dt_clf, axis=[-1.5, 2.5, -1.0, 1.5]) plt.scatter(X[y==0,0],X[y==0,1]) plt.scatter(X[y==1,0],X[y==1,1]) plt.scatter(X[y==2,0],X[y==2,1]) plt.show()

典型的过拟合

使用不同的超参数的结果对比

defaultmax_depth=2min_samples_split=10min_samples_leaf=6max_leaf_nodes=4
一直向下划分到所有gini系数都是0为止最大深度对于一个节点至少有多个样本才继续划分叶子节点至少有几个样本最多有多少个叶子

使用这些参数时:

  1. 避免欠拟合
  2. 使用风格搜索来组合这些参数