MINIST深度学习识别:python全连接神经网络和pytorch LeNet CNN网络训练实现及比较(二)

Stella981
• 阅读 997

版权声明:本文为博主原创文章,欢迎转载,并请注明出处。联系方式:460356155@qq.com

在前一篇文章MINIST深度学习识别:python全连接神经网络和pytorch LeNet CNN网络训练实现及比较(一)中,直接用python创建全连接神经网络模型进行深度学习训练,这样可以对神经网络有较为深刻的认识。

但是在实际应用中,一般都是采用各种深度学习框架来开展人工智能项目,以下就采用pytorch来实现前一篇文章中的全连接神经网络(784-300-10)。

  1 # -*- coding:utf-8 -*-
  2 
  3 u"""pytorch LineNet神经网络训练学习MINIST"""
  4 
  5 __author__ = 'zhengbiqing 460356155@qq.com'
  6 
  7 
  8 import torch as t
  9 import torchvision as tv
 10 import torch.nn as nn
 11 import torch.nn.functional as F
 12 import torchvision.transforms as transforms
 13 from torch.autograd import Variable
 14 import matplotlib.pyplot as plt
 15 import datetime
 16 
 17 
 18 #是否训练网络
 19 TRAIN = True
 20 
 21 #是否保存网络
 22 SAVE_PARA = False
 23 
 24 #学习率和训练次数
 25 LR = 0.05
 26 EPOCH = 10
 27 
 28 #训练每批次的样本数
 29 BATCH_SZ = 16
 30 
 31 #样本读取线程数
 32 WORKERS = 4
 33 
 34 #网络参赛保存文件名
 35 PARAS_FN = 'minist_linenet_params.pkl'
 36 
 37 #minist数据存放位置
 38 ROOT = '/home/zbq/pytorch/minist'
 39 
 40 
 41 #定义网络模型
 42 class LineNet(nn.Module):
 43     def __init__(self):
 44         super(LineNet, self).__init__()
 45 
 46         self.fc = nn.Sequential(
 47             nn.Linear(28*28, 300),
 48             nn.ReLU(),
 49             nn.Linear(300, 10)
 50         )
 51 
 52     def forward(self, x):
 53         #x是2维tensor,转换为1维向量
 54         x = x.view(x.size()[0], -1)
 55         x = self.fc(x)
 56         return x
 57 
 58 
 59 '''
 60 训练并测试网络
 61 net:网络模型
 62 train_data_load:训练数据集
 63 test_data_load:测试数据集
 64 epochs:训练迭代次数
 65 save:是否保存训练结果
 66 '''
 67 def net_train(net, train_data_load, test_data_load, epochs, save):
 68     start_time = datetime.datetime.now()
 69 
 70     loss_list = []
 71 
 72     for epoch in range(epochs):
 73         for i, data in enumerate(train_data_load, 0):
 74             img, label = data
 75             img, label = Variable(img), Variable(label)
 76             img, label = img.cuda(), label.cuda()
 77 
 78             optimizer.zero_grad()
 79 
 80             pre = net(img)
 81             loss = loss_func(pre, label)
 82             loss.backward()
 83 
 84             optimizer.step()
 85 
 86             #显示损失函数值的变化
 87             loss_data = loss.data.item()
 88             if i % 1000 == 999:
 89                 print('epoch:{epoch} i:{i} loss:{loss}'.format(epoch=epoch, i=i, loss=loss_data))
 90 
 91             if i % 100 == 99:
 92                 loss_list.append(loss_data)
 93 
 94         # 每个epoch结束后用测试集检查识别准确度
 95         net_test(epoch, test_data_load)
 96 
 97     print('MINIST pytorch LineNet Train: EPOCH:{epochs}, BATCH_SZ:{batch_sz}, LR:{lr}'.format(epochs=epochs, batch_sz=BATCH_SZ, lr=LR))
 98     print('train spend time: ', datetime.datetime.now() - start_time)
 99 
100     if save:
101         t.save(net.state_dict(), PARAS_FN)
102 
103     #显示目标函数值的变化曲线
104     plt.plot(loss_list)
105     plt.show()
106 
107 
108 '''
109 用测试集检查准确率
110 '''
111 def net_test(epoch, test_data_load):
112     ok = 0
113 
114     for i, data in enumerate(test_data_load):
115         img, label = data
116         img, label = Variable(img), Variable(label)
117         img, label = img.cuda(), label.cuda()
118 
119         outs = net(img)
120         _, pre = t.max(outs.data, 1)
121         ok += (pre == label).sum()
122 
123     acc = ok.item() * 100 / (len(test_data_load) * BATCH_SZ)
124 
125     print('EPOCH:{epoch}, ACC:{acc}\n'.format(epoch=epoch, acc=acc))
126 
127 
128 #图像数值转换,ToTensor源码注释
129 """Convert a ``PIL Image`` or ``numpy.ndarray`` to tensor.
130     Converts a PIL Image or numpy.ndarray (H x W x C) in the range
131     [0, 255] to a torch.FloatTensor of shape (C x H x W) in the range [0.0, 1.0].
132     """
133 #归一化,把[0.0, 1.0]变换为[-1,1], ([0, 1] - 0.5) / 0.5 = [-1, 1]
134 transform = tv.transforms.Compose([transforms.ToTensor(), transforms.Normalize([0.5], [0.5])])
135 
136 #定义数据集
137 train_data = tv.datasets.MNIST(root=ROOT, train=True, download=True, transform=transform)
138 test_data = tv.datasets.MNIST(root=ROOT, train=False, download=False, transform=transform)
139 
140 train_load = t.utils.data.DataLoader(train_data, batch_size=BATCH_SZ, shuffle=True, num_workers=WORKERS)
141 test_load = t.utils.data.DataLoader(test_data, batch_size=BATCH_SZ, shuffle=False, num_workers=WORKERS)
142 
143 print('train data num:', len(train_data), ', test data num:', len(test_data))
144 
145 
146 net = LineNet()
147 net.cuda()
148 
149 loss_func = nn.CrossEntropyLoss()
150 optimizer = t.optim.SGD(net.parameters(), lr=LR)
151 
152 if TRAIN:
153     net_train(net, train_load, test_load, EPOCH, SAVE_PARA)
154 else:
155     net.load_state_dict(t.load(PARAS_FN))
156     net_test(0, test_load)

网络训练结果准确率基本在97%~98%,和前一篇MINIST深度学习识别:python全连接神经网络和pytorch LeNet CNN网络训练实现及比较(一)相同网络结构的全连接神经网络相当,但是因为这里采用GPU运算,训练时间降低到1/8。

此外,借助pytorch,代码更简单。

运行结果如下:

train data num: 60000 , test data num: 10000
epoch:0 i:999 loss:0.3457891643047333
epoch:0 i:1999 loss:0.09639787673950195
epoch:0 i:2999 loss:0.27898865938186646
EPOCH:0, ACC:94.84

epoch:1 i:999 loss:0.33745211362838745
epoch:1 i:1999 loss:0.11106520891189575
epoch:1 i:2999 loss:0.21725007891654968
EPOCH:1, ACC:96.42

epoch:2 i:999 loss:0.3825737535953522
epoch:2 i:1999 loss:0.02866300940513611
epoch:2 i:2999 loss:0.11832481622695923
EPOCH:2, ACC:96.77

epoch:3 i:999 loss:0.11886310577392578
epoch:3 i:1999 loss:0.012149035930633545
epoch:3 i:2999 loss:0.030409961938858032
EPOCH:3, ACC:97.2

epoch:4 i:999 loss:0.008915185928344727
epoch:4 i:1999 loss:0.008089780807495117
epoch:4 i:2999 loss:0.0005310177803039551
EPOCH:4, ACC:97.6

epoch:5 i:999 loss:0.02993696928024292
epoch:5 i:1999 loss:0.01784616708755493
epoch:5 i:2999 loss:0.10544028878211975
EPOCH:5, ACC:97.6

epoch:6 i:999 loss:0.008486062288284302
epoch:6 i:1999 loss:0.0334945023059845
epoch:6 i:2999 loss:0.00291365385055542
EPOCH:6, ACC:97.37

epoch:7 i:999 loss:0.0062919557094573975
epoch:7 i:1999 loss:0.0003241896629333496
epoch:7 i:2999 loss:0.0006818175315856934
EPOCH:7, ACC:97.23

epoch:8 i:999 loss:0.0007421970367431641
epoch:8 i:1999 loss:0.005641639232635498
epoch:8 i:2999 loss:0.005949795246124268
EPOCH:8, ACC:97.7

epoch:9 i:999 loss:0.024028539657592773
epoch:9 i:1999 loss:0.005388796329498291
epoch:9 i:2999 loss:0.0029097795486450195
EPOCH:9, ACC:97.39

MINIST pytorch LineNet Train: EPOCH:10, BATCH_SZ:16, LR:0.05
train spend time:  0:00:43.183836

损失函数值变化曲线为:

MINIST深度学习识别:python全连接神经网络和pytorch LeNet CNN网络训练实现及比较(二)

点赞
收藏
评论区
推荐文章
不是海碗 不是海碗
1年前
银行卡识别OCR:解放金融业务处理效率的黑科技!
银行卡识别OCR是通过光学字符识别(OCR)技术实现的。它基于深度学习算法,通过卷积神经网络(CNN)对银行卡图片进行特征提取和分析,从而识别出银行卡上的各个字段。
Wesley13 Wesley13
3年前
DAO与DAL的区别
版权声明:本文为博主原创文章,遵循CC4.0BYSA(https://www.oschina.net/action/GoToLink?urlhttp%3A%2F%2Fcreativecommons.org%2Flicenses%2Fbysa%2F4.0%2F)版权协议,转载请附上原文出处链接和本声明。本文链接:https://blo
Wesley13 Wesley13
3年前
GO富集分析示例【华为云技术分享】
版权声明:本文为博主原创文章,遵循CC4.0BYSA(https://www.oschina.net/action/GoToLink?urlhttp%3A%2F%2Fcreativecommons.org%2Flicenses%2Fbysa%2F4.0%2F)版权协议,转载请附上原文出处链接和本声明。本文链接:https://blo
Stella981 Stella981
3年前
Docker之Mysql安装及配置
原文:Docker之Mysql安装及配置(https://www.oschina.net/action/GoToLink?urlhttps%3A%2F%2Fblog.csdn.net%2Fzhaobw831%2Farticle%2Fdetails%2F80141633)版权声明:本文为博主原创文章,未经博主允许不得转载。https://blog
Stella981 Stella981
3年前
Linux下源码包安装Swoole及基本使用 转
版权声明:本文为博主原创文章,遵循CC4.0BYSA(https://www.oschina.net/action/GoToLink?urlhttp%3A%2F%2Fcreativecommons.org%2Flicenses%2Fbysa%2F4.0%2F)版权协议,转载请附上原文出处链接和本声明。本文链接:https://blo
Wesley13 Wesley13
3年前
Github项目解析(九)
版权声明:本文为博主原创文章,未经博主允许不得转载。转载请标明出处:一片枫叶的专栏(https://www.oschina.net/action/GoToLink?urlhttp%3A%2F%2Fblog.csdn.net%2Fqq_23547831%2Farticle%2Fdetails%2F51821159)上一篇文章中我们讲解了在Ac
深度学习与图神经网络学习分享:CNN经典网络之-ResNet
深度学习与图神经网络学习分享:CNN经典网络之ResNetresnet又叫深度残差网络图像识别准确率很高,主要作者是国人哦深度网络的退化问题深度网络难以训练,梯度消失,梯度爆炸,老生常谈,不多说!深度网络的退化问题(htt
人工智能人才培养
No.1第一天一、机器学习简介与经典机器学习算法介绍什么是机器学习?机器学习框架与基本组成机器学习的训练步骤机器学习问题的分类经典机器学习算法介绍章节目标:机器学习是人工智能的重要技术之一,详细了解机器学习的原理、机制和方法,为学习深度学习与迁移学习打下坚实的基础。二、深度学习简介与经典网络结构介绍神经网络简介神经网络组件简介神经网络训练方法卷积神经网络介
四儿 四儿
1年前
深度学习在语音识别中的应用及挑战
一、引言随着深度学习技术的快速发展,其在语音识别领域的应用也日益广泛。深度学习技术可以有效地提高语音识别的精度和效率,并且被广泛应用于各种应用场景。本文将探讨深度学习在语音识别中的应用及所面临的挑战。二、深度学习在语音识别中的应用1.基于深度神经网络的语音