Python机器学习基础教程

Stella981
• 阅读 1032

前言

本系列教程基本就是摘抄《Python机器学习基础教程》中的例子内容。

为了便于跟踪和学习,本系列教程在Github上提供了jupyter notebook 版本:

Github仓库:https://github.com/Holy-Shine/Introduciton-2-ML-with-Python-notebook

系列教程总目录 Python机器学习基础教程

引子

先导入必要的包

import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
import mglearn
import os
%matplotlib inline 

决策树是广泛用于分类和回归任务的模型。本质上,它从一层层的 if/else 问题中进行学 习,并得出结论。

这些问题类似于你在“20 Questions”游戏中可能会问的问题。想象一下,你想要区分下面这四种动物:熊、鹰、企鹅和海豚。你的目标是通过提出尽可能少的 if/else 问题来得到正确答案。你可能首先会问:这种动物有没有羽毛,这个问题会将可能的动物减少到只有两种。如果答案是“有”,你可以问下一个问题,帮你区分鹰和企鹅。例如,你可以问这种动物会不会飞。如果这种动物没有羽毛,那么可能是海豚或熊,所以你需要问一个问题来区分这两种动物——比如问这种动物有没有鳍。

这一系列问题可以表示为一棵决策树,如图 2-22 所示。

mglearn.plots.plot_animal_tree()
Python机器学习基础教程
图 2-22:区分几种动物的决策树

在这张图中,树的每个结点代表一个问题或一个包含答案的终结点(也叫叶结点)。树的边将问题的答案与将问的下一个问题连接起来。

用机器学习的语言来说就是,为了区分四类动物(鹰、企鹅、海豚和熊),我们利用三个特征(“有没有羽毛”“会不会飞”和“有没有鳍”)来构建一个模型。我们可以利用监督学习从数据中学习模型,而无需人为构建模型。

1. 构造决策树

我们在图 2-23 所示的二维分类数据集上构造决策树。这个数据集由 2 个半月形组成,每个类别都包含 50 个数据点。我们将这个数据集称为 two_moons 。

学习决策树,就是学习一系列 if/else 问题,使我们能够以最快的速度得到正确答案。在机器学习中,这些问题叫作测试(不要与测试集弄混,测试集是用来测试模型泛化性能的数据)。数据通常并不是像动物的例子那样具有二元特征(是 / 否)的形式,而是表示为连续特征,比如图 2-23 所示的二维数据集。用于连续数据的测试形式是:“特征 i 的值是否大于 a ?”

Python机器学习基础教程
图 2-23:用于构造决策树的 two\_moons 数据集

为了构造决策树,算法搜遍所有可能的测试,找出对目标变量来说信息量最大的那一个。图 2-24 展示了选出的第一个测试。将数据集在 x[1]=0.0596 处垂直划分可以得到最多信息,它在最大程度上将类别 0 中的点与类别 1 中的点进行区分。顶结点(也叫根结点)表示整个数据集,包含属于类别 0 的 50 个点和属于类别 1 的 50 个点。通过测试 x[1] <=0.0596 的真假来对数据集进行划分,在图中表示为一条黑线。如果测试结果为真,那么将这个点分配给左结点,左结点里包含属于类别 0 的 2 个点和属于类别 1 的 32 个点。否则将这个点分配给右结点,右结点里包含属于类别 0 的 48 个点和属于类别 1 的 18 个点。这两个结点对应于图 2-24 中的顶部区域和底部区域。尽管第一次划分已经对两个类别做了很好的区分,但底部区域仍包含属于类别 0 的点,顶部区域也仍包含属于类别 1 的点。我们可以在两个区域中重复寻找最佳测试的过程,从而构建出更准确的模型。图 2-25 展示了信息量最大的下一次划分,这次划分是基于 x[0] 做出的,分为左右两个区域。

Python机器学习基础教程
图 2-24:深度为 1 的树的决策边界(左)与相应的树(右)
Python机器学习基础教程
图 2-25:深度为 2 的树的决策边界(左)与相应的树(右)

这一递归过程生成一棵二元决策树,其中每个结点都包含一个测试。或者你可以将每个测试看成沿着一条轴对当前数据进行划分。这是一种将算法看作分层划分的观点。由于每个测试仅关注一个特征,所以划分后的区域边界始终与坐标轴平行。

对数据反复进行递归划分,直到划分后的每个区域(决策树的每个叶结点)只包含单一目标值(单一类别或单一回归值)。如果树中某个叶结点所包含数据点的目标值都相同,那么这个叶结点就是纯的(pure)。这个数据集的最终划分结果见图 2-26。

Python机器学习基础教程
图 2-26:深度为 9 的树的决策边界(左)与相应的树的一部分(右);完整的决策树非常大,很难可视化

想要对新数据点进行预测,首先要查看这个点位于特征空间划分的哪个区域,然后将该区域的多数目标值(如果是纯的叶结点,就是单一目标值)作为预测结果。从根结点开始对树进行遍历就可以找到这一区域,每一步向左还是向右取决于是否满足相应的测试。

决策树也可以用于回归任务,使用的方法完全相同。预测的方法是,基于每个结点的测试对树进行遍历,最终找到新数据点所属的叶结点。这一数据点的输出即为此叶结点中所有训练点的平均目标值。

2. 控制决策树的复杂度

通常来说,构造决策树直到所有叶结点都是纯的叶结点,这会导致模型非常复杂,并且对训练数据高度过拟合。纯叶结点的存在说明这棵树在训练集上的精度是 100%。训练集中的每个数据点都位于分类正确的叶结点中。在图 2-26 的左图中可以看出过拟合。你可以看到,在所有属于类别 0 的点中间有一块属于类别 1 的区域。另一方面,有一小条属于类别 0 的区域,包围着最右侧属于类别 0 的那个点。这并不是人们想象中决策边界的样子,这个决策边界过于关注远离同类别其他点的单个异常点。

防止过拟合有两种常见的策略:一种是及早停止树的生长,也叫预剪枝(pre-pruning);另一种是先构造树,但随后删除或折叠信息量很少的结点,也叫后剪枝(post-pruning)或剪枝(pruning)。预剪枝的限制条件可能包括限制树的最大深度、限制叶结点的最大数目,或者规定一个结点中数据点的最小数目来防止继续划分。

scikit-learn 的决策树在 DecisionTreeRegressor 类和 DecisionTreeClassifier 类中实现。scikit-learn 只实现了预剪枝,没有实现后剪枝。

我们在乳腺癌数据集上更详细地看一下预剪枝的效果。和前面一样,我们导入数据集并将其分为训练集和测试集。然后利用默认设置来构建模型,默认将树完全展开(树不断分支,直到所有叶结点都是纯的)。我们固定树的 random_state ,用于在内部解决平局问题:

from sklearn.tree import DecisionTreeClassifier
from sklearn.datasets import load_breast_cancer
from sklearn.model_selection import train_test_split

cancer = load_breast_cancer()
X_train, X_test, y_train, y_test = train_test_split(
    cancer.data, cancer.target, stratify=cancer.target, random_state=42)
tree = DecisionTreeClassifier(random_state=0)
tree.fit(X_train, y_train)
print("Accuracy on training set: {:.3f}".format(tree.score(X_train, y_train)))
print("Accuracy on test set: {:.3f}".format(tree.score(X_test, y_test)))

[out]:

Accuracy on training set: 1.000 Accuracy on test set: 0.937

不出所料,训练集上的精度是 100%,这是因为叶结点都是纯的,树的深度很大,足以完美地记住训练数据的所有标签。测试集精度比之前讲过的线性模型略低,线性模型的精度约为 95%。

如果我们不限制决策树的深度,它的深度和复杂度都可以变得特别大。因此,未剪枝的树容易过拟合,对新数据的泛化性能不佳。现在我们将预剪枝应用在决策树上,这可以在完美拟合训练数据之前阻止树的展开。一种选择是在到达一定深度后停止树的展开。这里我们设置 max_depth=4 ,这意味着只可以连续问 4 个问题(参见图 2-24 和图 2-26)。限制树的深度可以减少过拟合。这会降低训练集的精度,但可以提高测试集的精度:

tree=DecisionTreeClassifier(max_depth=4, random_state=0)
tree.fit(X_train, y_train)
print("Accuracy on training set: {:.3f}".format(tree.score(X_train, y_train)))
print("Accuracy on test set: {:.3f}".format(tree.score(X_test, y_test)))

[out]:

Accuracy on training set: 0.988 Accuracy on test set: 0.951

3. 分析决策树

我们可以利用 tree 模块的 export_graphviz 函数来将树可视化。这个函数会生成一个 .dot 格式的文件,这是一种用于保存图形的文本文件格式。我们设置为结点添加颜色的选项,颜色表示每个结点中的多数类别,同时传入类别名称和特征名称,这样可以对树正确标记:

from sklearn.tree import export_graphviz
export_graphviz(tree, out_file="tree.dot", class_names=["malignant","benign"],
    feature_names=cancer.feature_names, impurity=False, filled=True)

我们可以利用 graphviz 模块读取这个文件并将其可视化(你也可以使用任何能够读取 .dot文件的程序),见图 2-27:

import graphviz
with open("tree.dot") as f:
    dot_graph = f.read()
    graphviz.Source(dot_graph)
Python机器学习基础教程
图 2-27:基于乳腺癌数据集构造的决策树的可视化

树的可视化有助于深入理解算法是如何进行预测的,也是易于向非专家解释的机器学习算法的优秀示例。不过,即使这里树的深度只有 4 层,也有点太大了。深度更大的树(深度为 10 并不罕见)更加难以理解。一种观察树的方法可能有用,就是找出大部分数据的实际路径。图 2-27 中每个结点的 samples 给出了该结点中的样本个数, values 给出的是每个类别的样本个数。观察 worst radius <= 16.795 分支右侧的子结点,我们发现它只包含8 个良性样本,但有 134 个恶性样本。树的这一侧的其余分支只是利用一些更精细的区别将这 8 个良性样本分离出来。在第一次划分右侧的 142 个样本中,几乎所有样本(132 个)最后都进入最右侧的叶结点中。

再来看一下根结点的左侧子结点,对于 worst radius > 16.795 ,我们得到 25 个恶性样本和 259 个良性样本。几乎所有良性样本最终都进入左数第二个叶结点中,大部分其他叶结点都只包含很少的样本。

4. 树的特征重要性

查看整个树可能非常费劲,除此之外,我还可以利用一些有用的属性来总结树的工作原理。其中最常用的是特征重要性(feature importance),它为每个特征对树的决策的重要性进行排序。对于每个特征来说,它都是一个介于 0 和 1 之间的数字,其中 0 表示“根本没用到”,1 表示“完美预测目标值”。特征重要性的求和始终为 1:

print("Feature importances:\n{}".format(tree.feature_importances_))

[out]:

Feature importances: [ 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.01 0.048 0. 0. 0.002 0. 0. 0. 0. 0. 0.727 0.046

    1. 0.014 0. 0.018 0.122 0.012 0. ]

我们可以将特征重要性可视化,与我们将线性模型的系数可视化的方法类似(图 2-28):

def plot_feature_importance_cancer(model):
    n_features = cancer.data.shape[1]
    plt.barh(range(n_features), model.feature_importances_, align='center')
    plt.yticks(np.arange(n_features), cancer.feature_names)
    plt.xlabel("Feature importance")
    plt.ylabel("Feature")

plot_feature_importance_cancer(tree)
Python机器学习基础教程
图 2-28:在乳腺癌数据集上学到的决策树的特征重要性

这里我们看到,顶部划分用到的特征(“worst radius”)是最重要的特征。这也证实了我们在分析树时的观察结论,即第一层划分已经将两个类别区分得很好。

但是,如果某个特征的 feature_importance_ 很小,并不能说明这个特征没有提供任何信息。这只能说明该特征没有被树选中,可能是因为另一个特征也包含了同样的信息。

与线性模型的系数不同,特征重要性始终为正数,也不能说明该特征对应哪个类别。特征重要性告诉我们“worst radius”(最大半径)特征很重要,但并没有告诉我们半径大表示样本是良性还是恶性。事实上,在特征和类别之间可能没有这样简单的关系,你可以在下面的例子中看出这一点(图 2-29 和图 2-30):

tree = mglearn.plots.plot_tree_not_monotone()
Python机器学习基础教程
图 2-29:一个二维数据集( y 轴上的特征与类别标签是非单调的关系)与决策树给出的决策边界
Python机器学习基础教程
图 2-30:从图 2-29 的数据中学到的决策树

该图显示的是有两个特征和两个类别的数据集。这里所有信息都包含在 X[1] 中,没有用到X[0] 。但 X[1] 和输出类别之间并不是单调关系,即我们不能这么说:“较大的 X[1] 对应类别 0,较小的 X[1] 对应类别 1”(反之亦然)。

虽然我们主要讨论的是用于分类的决策树,但对用于回归的决策树来说,所有内容都是类似的,在 DecisionTreeRegressor 中实现。回归树的用法和分析与分类树非常类似。但在将基于树的模型用于回归时,我们想要指出它的一个特殊性质。 DecisionTreeRegressor(以及其他所有基于树的回归模型)不能外推(extrapolate),也不能在训练数据范围之外进行预测。

我们利用计算机内存(RAM)历史价格的数据集来更详细地研究这一点。图 2-31 给出了这个数据集的图像,x 轴为日期,y 轴为那一年 1 兆字节(MB)RAM 的价格:

import pandas as pd
ram_prices = pd.read_csv("data/ram_price.csv")

plt.semilogy(ram_prices.date, ram_prices.price)
plt.xlabel('Year')
plt.ylabel("Pirce in $/Mbyte")
Python机器学习基础教程
图 2-31:用对数坐标绘制 RAM 价格的历史发展

注意 y 轴的对数刻度。在用对数坐标绘图时,二者的线性关系看起来非常好,所以预测应该相对比较容易,除了一些不平滑之处之外。

我们将利用 2000 年前的历史数据来预测 2000 年后的价格,只用日期作为特征。我们将对比两个简单的模型: DecisionTreeRegressor 和 LinearRegression 。我们对价格取对数,使得二者关系的线性相对更好。这对 DecisionTreeRegressor 不会产生什么影响,但对LinearRegression 的影响却很大(我们将在第 4 章中进一步讨论)。训练模型并做出预测之后,我们应用指数映射来做对数变换的逆运算。为了便于可视化,我们这里对整个数据集进行预测,但如果是为了定量评估,我们将只考虑测试数据集:

from sklearn.tree import DecisionTreeRegressor
from sklearn.linear_model import LinearRegression
# 利用历史数据预测2000年后的价格
data_train = ram_prices[ram_prices.date < 2000]
data_test = ram_prices[ram_prices.date >= 2000]

# 基于日期来预测价格
X_train = data_train.date[:, np.newaxis]
# 我们利用对数变换得到数据和目标之间更简单的关系
y_train = np.log(data_train.price)

tree = DecisionTreeRegressor().fit(X_train, y_train)
linear_reg = LinearRegression().fit(X_train, y_train)

# 对所有数据进行预测
X_all = ram_prices.date[:, np.newaxis]

pred_tree = tree.predict(X_all)
pred_lr = linear_reg.predict(X_all)

# 对数变换逆运算
price_tree = np.exp(pred_tree)
price_lr = np.exp(pred_lr)

这里创建的图 2-32 将决策树和线性回归模型的预测结果与真实值进行对比:

plt.semilogy(data_train.date, data_train.price, label="Training data")
plt.semilogy(data_test.date, data_test.price, label="Test data")
plt.semilogy(ram_prices.date, price_tree, label="Tree prediction")
plt.semilogy(ram_prices.date, price_lr, label="Linear prediction")
plt.legend()
Python机器学习基础教程
图 2-32:线性模型和回归树对 RAM 价格数据的预测结果对比

两个模型之间的差异非常明显。线性模型用一条直线对数据做近似,这是我们所知道的。这条线对测试数据(2000 年后的价格)给出了相当好的预测,不过忽略了训练数据和测试数据中一些更细微的变化。与之相反,树模型完美预测了训练数据。由于我们没有限制树的复杂度,因此它记住了整个数据集。但是,一旦输入超出了模型训练数据的范围,模型就只能持续预测最后一个已知数据点。树不能在训练数据的范围之外生成“新的”响应。所有基于树的模型都有这个缺点。

5. 优点、缺点和参数

如前所述,控制决策树模型复杂度的参数是预剪枝参数,它在树完全展开之前停止树的构造。通常来说,选择一种预剪枝策略(设置 max_depth 、 max_leaf_nodes 或 min_samples_leaf )足以防止过拟合。

与前面讨论过的许多算法相比,决策树有两个优点:一是得到的模型很容易可视化,非专家也很容易理解(至少对于较小的树而言)二是算法完全不受数据缩放的影响。由于每个特征被单独处理,而且数据的划分也不依赖于缩放,因此决策树算法不需要特征预处理,比如归一化或标准化。特别是特征的尺度完全不一样时或者二元特征和连续特征同时存在时,决策树的效果很好。决策树的主要缺点在于,即使做了预剪枝,它也经常会过拟合,泛化性能很差。因此,在大多数应用中,往往使用下面介绍的集成方法来替代单棵决策树。

点赞
收藏
评论区
推荐文章
blmius blmius
3年前
MySQL:[Err] 1292 - Incorrect datetime value: ‘0000-00-00 00:00:00‘ for column ‘CREATE_TIME‘ at row 1
文章目录问题用navicat导入数据时,报错:原因这是因为当前的MySQL不支持datetime为0的情况。解决修改sql\mode:sql\mode:SQLMode定义了MySQL应支持的SQL语法、数据校验等,这样可以更容易地在不同的环境中使用MySQL。全局s
皕杰报表之UUID
​在我们用皕杰报表工具设计填报报表时,如何在新增行里自动增加id呢?能新增整数排序id吗?目前可以在新增行里自动增加id,但只能用uuid函数增加UUID编码,不能新增整数排序id。uuid函数说明:获取一个UUID,可以在填报表中用来创建数据ID语法:uuid()或uuid(sep)参数说明:sep布尔值,生成的uuid中是否包含分隔符'',缺省为
待兔 待兔
3个月前
手写Java HashMap源码
HashMap的使用教程HashMap的使用教程HashMap的使用教程HashMap的使用教程HashMap的使用教程22
Jacquelyn38 Jacquelyn38
3年前
2020年前端实用代码段,为你的工作保驾护航
有空的时候,自己总结了几个代码段,在开发中也经常使用,谢谢。1、使用解构获取json数据let jsonData  id: 1,status: "OK",data: 'a', 'b';let  id, status, data: number   jsonData;console.log(id, status, number )
Stella981 Stella981
3年前
Python之time模块的时间戳、时间字符串格式化与转换
Python处理时间和时间戳的内置模块就有time,和datetime两个,本文先说time模块。关于时间戳的几个概念时间戳,根据1970年1月1日00:00:00开始按秒计算的偏移量。时间元组(struct_time),包含9个元素。 time.struct_time(tm_y
Wesley13 Wesley13
3年前
mysql设置时区
mysql设置时区mysql\_query("SETtime\_zone'8:00'")ordie('时区设置失败,请联系管理员!');中国在东8区所以加8方法二:selectcount(user\_id)asdevice,CONVERT\_TZ(FROM\_UNIXTIME(reg\_time),'08:00','0
Wesley13 Wesley13
3年前
00:Java简单了解
浅谈Java之概述Java是SUN(StanfordUniversityNetwork),斯坦福大学网络公司)1995年推出的一门高级编程语言。Java是一种面向Internet的编程语言。随着Java技术在web方面的不断成熟,已经成为Web应用程序的首选开发语言。Java是简单易学,完全面向对象,安全可靠,与平台无关的编程语言。
Stella981 Stella981
3年前
Django中Admin中的一些参数配置
设置在列表中显示的字段,id为django模型默认的主键list_display('id','name','sex','profession','email','qq','phone','status','create_time')设置在列表可编辑字段list_editable
Wesley13 Wesley13
3年前
MySQL部分从库上面因为大量的临时表tmp_table造成慢查询
背景描述Time:20190124T00:08:14.70572408:00User@Host:@Id:Schema:sentrymetaLast_errno:0Killed:0Query_time:0.315758Lock_
Python进阶者 Python进阶者
9个月前
Excel中这日期老是出来00:00:00,怎么用Pandas把这个去除
大家好,我是皮皮。一、前言前几天在Python白银交流群【上海新年人】问了一个Pandas数据筛选的问题。问题如下:这日期老是出来00:00:00,怎么把这个去除。二、实现过程后来【论草莓如何成为冻干莓】给了一个思路和代码如下:pd.toexcel之前把这