Python数据科学:神经网络

Stella981
• 阅读 744

Python数据科学:神经网络

(Artificial Neural Network,ANN)人工神经网络模型,以数学和物理的方法对人脑神经网络进行简化、抽象和模拟。

本次只是一个简单的神经网络入门,涉及神经元模型和BP神经网络。

这里简单了解一下机器学习的三要素,分别是模型、策略与算法。

模型包括非随机效应部分(被解释变量和解释变量之间的关系,多为函数关系)和随机效应部分(扰动项)。

策略是指如何设定最优化的目标函数,常见的目标函数有线性回归的残差平方和、逻辑回归的似然函数、SVM中的合页函数等。

算法是对目标函数求参的方法,比如通过求导的方法计算,或者使用数值计算领域的算法求解。

其中神经网络就是采用数值算法求解参数,这就意味着每次计算得到的模型参数都会是不同的。

/ 01 / 神经网络

01 神经元模型

神经网络中最基本的成分是神经元模型。

每个神经元都是一个多输入单输出的信息处理单元,输入信号通过带权重的连接传递,和阈值对比后得到总输入值,再通过激活函数的处理产生单个输出

神经元的输出,是对激活函数套用输入加权和的结果。

神经元的激活函数使得神经元具有不同的信息处理特性,反映了神经元输出与其激活状态之间的关系。

本次涉及到的激活函数有阈值函数(阶跃函数)、sigmoid函数(S型函数)。

02 单层感知器

感知器是一种具有单层计算单元的神经网络,只能用来解决线性可分的二分类问题。

无法运用到多层感知器中,无法确定隐藏层的期望输出。

它的结构类似之前的神经元模型。

激活函数采用单极性(或双极性)阈值函数。

03 BP神经网络

采用误差反向传播算法(有监督学习算法)训练的多层神经网络称为BP神经网络。

属于多层前馈型神经网络,模型的学习过程由信号的正向传播误差反向传播两个过程组成。

进行正向传播时信号从输入层计算各层加权和,经由各隐层最终传递到输出层,得到输出结果,比较输出结果与期望结果(监督信号),得到输出误差。

误差反向传播是依照梯度下降算法将误差沿着隐藏层到输入层逐层反向传播,将误差分摊给各层的所有单元,从而得到各个单元的误差信号(学习信号),据此修改各单元权值。

这两个信号传播过程不断循环以更新权值,最终根据判定条件判断是否结束循环。

其网络结构普遍为单隐层网络,包括输入层隐层输出层

激活函数多采用sigmoid函数或线性函数,这里隐层和输出层均采用sigmoid函数。

/ 02/ Python实现

神经网络在有明确的训练样本后,网络的输入层结点数(解释变量个数)和输出层结点数(被解释变量的个数)便已确定。

需要考虑的则是隐含层的个数和每个隐含层的结点个数。

下面利用书中的数据进行实战一波,一份移动离网数据。

移动通讯用户消费特征数据,目标字段为是否流失,具有两个分类水平(是与否)。

自变量包含了用户的基本信息、消费的产品信息以及用户的消费特征。

读取数据。

import pandas as pdfrom sklearn import metricsimport matplotlib.pyplot as pltfrom sklearn.preprocessing import MinMaxScalerfrom sklearn.neural_network import MLPClassifierfrom sklearn.model_selection import GridSearchCVfrom sklearn.model_selection import train_test_split# 设置最大显示行数pd.set_option('display.max_rows', 10)# 设置最大显示列数pd.set_option('display.max_columns', 10)# 设置显示宽度为1000,这样就不会在IDE中换行了pd.set_option('display.width', 1000)# 读取数据,skipinitialspace:忽略分隔符后的空白churn = pd.read_csv('telecom_churn.csv', skipinitialspace=True)print(churn)

输出数据概况,包含3000多个用户数据。

Python数据科学:神经网络

使用scikit-learn中的函数将数据集划分为训练集和测试集。

# 选取自变量数据data = churn.iloc[:, 2:]# 选取因变量数据target = churn['churn']# 使用scikit-learn将数据集划分为训练集和测试集train_data, test_data, train_target, test_target = train_test_split(data, target, test_size=0.4, train_size=0.6, random_state=1234)

神经网络需要对数据进行极值标准化

需要对连续变量进行极值标准化,分类变量需要转变为虚拟变量。

其中多分类名义变量必须转变为虚拟变量,而等级变量和二分类变量则可以选择不转变,当做连续变量处理即可。

本次数据中,教育等级和套餐类型是等级变量,性别等变量为二分类变量,这些都可以作为连续变量进行处理。

这也就意味着本次的数据集中不存在多分类名义变量,都可作为连续变量进行处理。

# 极值标准化处理scaler = MinMaxScaler()scaler.fit(train_data)scaled_train_data = scaler.transform(train_data)scaler_test_data = scaler.transform(test_data)

建立多层感知器模型。

# 设置多层感知器对应的模型mlp = MLPClassifier(hidden_layer_sizes=(10,), activation='logistic', alpha=0.1, max_iter=1000)# 对训练集进行模型训练mlp.fit(scaled_train_data, train_target)# 输出神经网络模型信息print(mlp)

输出模型信息如下。

Python数据科学:神经网络

接下来使用经过训练集训练的模型,对训练集及测试集进行预测。

# 使用模型进行预测train_predict = mlp.predict(scaled_train_data)test_predict = mlp.predict(scaler_test_data)

输出预测概率,用户流失的概率。

# 输出模型预测概率(为1的情况)train_proba = mlp.predict_proba(scaled_train_data)[:, 1]test_proba = mlp.predict_proba(scaler_test_data)[:, 1]

对模型进行评估,输出评估数据。

# 根据预测信息输出模型评估结果print(metrics.confusion_matrix(test_target, test_predict, labels=[0, 1]))print(metrics.classification_report(test_target, test_predict))

输出如下。

Python数据科学:神经网络

模型对流失用户的f1-score(精确率和召回率的调和平均数)值为0.81,效果不错。

此外对流失用户的灵敏度recall为0.83,模型能识别出83%的流失用户,说明模型识别流失用户的能力还可以。

输出模型预测的平均准确度。

# 使用指定数据集输出模型预测的平均准确度print(mlp.score(scaler_test_data, test_target))# 输出值为0.8282828282828283

平均准确度值为0.8282。

计算模型的ROC下面积。

# 绘制ROC曲线fpr_test, tpr_test, th_test = metrics.roc_curve(test_target, test_proba)fpr_train, tpr_train, th_train = metrics.roc_curve(train_target, train_proba)plt.figure(figsize=[3, 3])plt.plot(fpr_test, tpr_test, 'b--')plt.plot(fpr_train, tpr_train, 'r-')plt.title('ROC curve')plt.show()# 计算AUC值print(metrics.roc_auc_score(test_target, test_proba))# 输出值为0.9149632415075206

ROC曲线图如下。

Python数据科学:神经网络

训练集和测试集的曲线很接近,没有过拟合现象。

AUC值为0.9149,说明模型效果非常好。

对模型进行最优参数搜索,并且对最优参数下的模型进行训练。

# 使用GridSearchCV进行最优参数搜索param_grid = {    # 模型隐层数量    'hidden_layer_sizes': [(10, ), (15, ), (20, ), (5, 5)],    # 激活函数    'activation': ['logistic', 'tanh', 'relu'],    # 正则化系数    'alpha': [0.001, 0.01, 0.1, 0.2, 0.4, 1, 10]}mlp = MLPClassifier(max_iter=1000)# 选择roc_auc作为评判标准,4折交叉验证,n_jobs=-1使用多核CPU的全部线程gcv = GridSearchCV(estimator=mlp, param_grid=param_grid,                   scoring='roc_auc', cv=4, n_jobs=-1)gcv.fit(scaled_train_data, train_target)

输出最优参数的模型的情况。

# 输出最优参数下模型的得分print(gcv.best_score_)# 输出值为0.9258018987136855# 输出最优参数下模型的参数print(gcv.best_params_)# 输出参数值为{'alpha': 0.01, 'activation': 'tanh', 'hidden_layer_sizes': (5, 5)}# 使用指定数据集输出最优模型预测的平均准确度print(gcv.score(scaler_test_data, test_target))# 输出值为0.9169384823390232

模型的roc_auc最高得分为0.92,即该模型下的ROC曲线下面积为0.92。

较之前的0.9149,提高了一点点。

模型的最优参数,激活函数为relu类型,alpha为0.01,隐藏层节点数为15个。

模型的预测平均准确率为0.9169,较之前的0.8282,提高了不少。

相关资料获取,请点击阅读原文。

推荐阅读

Python数据科学:神经网络

Python数据科学:神经网络

Python数据科学:神经网络

···  END  ···

Python数据科学:神经网络

欢迎大家点赞,留言,转发,转载,****感谢大家的相伴与支持

想加入Python学习群请在后台回复【入群

万水千山总是情,点个【在看】行不行

本文分享自微信公众号 - IT共享之家(info-share)。
如有侵权,请联系 support@oschina.cn 删除。
本文参与“OSC源创计划”,欢迎正在阅读的你也加入,一起分享。

点赞
收藏
评论区
推荐文章
blmius blmius
3年前
MySQL:[Err] 1292 - Incorrect datetime value: ‘0000-00-00 00:00:00‘ for column ‘CREATE_TIME‘ at row 1
文章目录问题用navicat导入数据时,报错:原因这是因为当前的MySQL不支持datetime为0的情况。解决修改sql\mode:sql\mode:SQLMode定义了MySQL应支持的SQL语法、数据校验等,这样可以更容易地在不同的环境中使用MySQL。全局s
皕杰报表之UUID
​在我们用皕杰报表工具设计填报报表时,如何在新增行里自动增加id呢?能新增整数排序id吗?目前可以在新增行里自动增加id,但只能用uuid函数增加UUID编码,不能新增整数排序id。uuid函数说明:获取一个UUID,可以在填报表中用来创建数据ID语法:uuid()或uuid(sep)参数说明:sep布尔值,生成的uuid中是否包含分隔符'',缺省为
待兔 待兔
5个月前
手写Java HashMap源码
HashMap的使用教程HashMap的使用教程HashMap的使用教程HashMap的使用教程HashMap的使用教程22
Jacquelyn38 Jacquelyn38
3年前
2020年前端实用代码段,为你的工作保驾护航
有空的时候,自己总结了几个代码段,在开发中也经常使用,谢谢。1、使用解构获取json数据let jsonData  id: 1,status: "OK",data: 'a', 'b';let  id, status, data: number   jsonData;console.log(id, status, number )
Stella981 Stella981
3年前
Python3:sqlalchemy对mysql数据库操作,非sql语句
Python3:sqlalchemy对mysql数据库操作,非sql语句python3authorlizmdatetime2018020110:00:00coding:utf8'''
Wesley13 Wesley13
3年前
mysql设置时区
mysql设置时区mysql\_query("SETtime\_zone'8:00'")ordie('时区设置失败,请联系管理员!');中国在东8区所以加8方法二:selectcount(user\_id)asdevice,CONVERT\_TZ(FROM\_UNIXTIME(reg\_time),'08:00','0
Wesley13 Wesley13
3年前
00:Java简单了解
浅谈Java之概述Java是SUN(StanfordUniversityNetwork),斯坦福大学网络公司)1995年推出的一门高级编程语言。Java是一种面向Internet的编程语言。随着Java技术在web方面的不断成熟,已经成为Web应用程序的首选开发语言。Java是简单易学,完全面向对象,安全可靠,与平台无关的编程语言。
Stella981 Stella981
3年前
Django中Admin中的一些参数配置
设置在列表中显示的字段,id为django模型默认的主键list_display('id','name','sex','profession','email','qq','phone','status','create_time')设置在列表可编辑字段list_editable
Wesley13 Wesley13
3年前
MySQL部分从库上面因为大量的临时表tmp_table造成慢查询
背景描述Time:20190124T00:08:14.70572408:00User@Host:@Id:Schema:sentrymetaLast_errno:0Killed:0Query_time:0.315758Lock_
Python进阶者 Python进阶者
11个月前
Excel中这日期老是出来00:00:00,怎么用Pandas把这个去除
大家好,我是皮皮。一、前言前几天在Python白银交流群【上海新年人】问了一个Pandas数据筛选的问题。问题如下:这日期老是出来00:00:00,怎么把这个去除。二、实现过程后来【论草莓如何成为冻干莓】给了一个思路和代码如下:pd.toexcel之前把这