Copyright notice: this is an original article by the blogger. Reposting is welcome; please credit the source. Contact: 460356155@qq.com
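In the previous article, MNIST deep learning recognition: implementing and comparing a Python fully connected neural network and a PyTorch LeNet CNN (Part 1), we built a fully connected neural network model directly in Python and trained it. Doing it by hand gives a fairly deep understanding of how neural networks work.
In practice, however, AI projects are almost always built on a deep learning framework. The code below uses PyTorch to implement the fully connected network (784-300-10) from the previous article.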
# -*- coding:utf-8 -*-

u"""Train a PyTorch LineNet (fully connected) neural network on MNIST"""

__author__ = 'zhengbiqing 460356155@qq.com'


import datetime

import matplotlib.pyplot as plt
import torch as t
import torch.nn as nn
import torchvision as tv
import torchvision.transforms as transforms


# Whether to train the network (False: load the saved parameters and only test)
TRAIN = True

# Whether to save the trained network parameters
SAVE_PARA = False

# Learning rate and number of training epochs
LR = 0.05
EPOCH = 10

# Number of samples per training batch
BATCH_SZ = 16

# Number of worker processes used to load samples
WORKERS = 4

# File name for the saved network parameters
PARAS_FN = 'minist_linenet_params.pkl'

# Location of the MNIST data
ROOT = '/home/zbq/pytorch/minist'

# Run on the GPU when one is available, otherwise fall back to the CPU
DEVICE = t.device('cuda' if t.cuda.is_available() else 'cpu')


# Network model: 784-300-10 fully connected
class LineNet(nn.Module):
    def __init__(self):
        super(LineNet, self).__init__()

        self.fc = nn.Sequential(
            nn.Linear(28*28, 300),
            nn.ReLU(),
            nn.Linear(300, 10)
        )

    def forward(self, x):
        # x has shape (batch, 1, 28, 28); flatten each image into a 784-element vector
        x = x.view(x.size(0), -1)
        x = self.fc(x)
        return x


def net_train(net, train_data_load, test_data_load, epochs, save):
    '''
    Train the network and test it after every epoch
    net: network model
    train_data_load: training data loader
    test_data_load: test data loader
    epochs: number of training epochs
    save: whether to save the trained parameters
    '''
    start_time = datetime.datetime.now()

    loss_list = []

    for epoch in range(epochs):
        # net_test switches to eval mode, so switch back at the start of every epoch
        net.train()

        for i, (img, label) in enumerate(train_data_load):
            img, label = img.to(DEVICE), label.to(DEVICE)

            optimizer.zero_grad()

            pre = net(img)
            loss = loss_func(pre, label)
            loss.backward()

            optimizer.step()

            # Show how the loss value changes (loss of the current batch only)
            loss_data = loss.item()
            if i % 1000 == 999:
                print('epoch:{epoch} i:{i} loss:{loss}'.format(epoch=epoch, i=i, loss=loss_data))

            if i % 100 == 99:
                loss_list.append(loss_data)

        # After each epoch, check the recognition accuracy on the test set
        net_test(net, epoch, test_data_load)

    print('MINIST pytorch LineNet Train: EPOCH:{epochs}, BATCH_SZ:{batch_sz}, LR:{lr}'.format(epochs=epochs, batch_sz=BATCH_SZ, lr=LR))
    print('train spend time: ', datetime.datetime.now() - start_time)

    if save:
        t.save(net.state_dict(), PARAS_FN)

    # Plot the loss curve
    plt.plot(loss_list)
    plt.show()


def net_test(net, epoch, test_data_load):
    '''
    Check the accuracy on the test set
    '''
    net.eval()
    ok = 0

    # Gradients are not needed for evaluation
    with t.no_grad():
        for img, label in test_data_load:
            img, label = img.to(DEVICE), label.to(DEVICE)

            outs = net(img)
            pre = outs.argmax(1)
            ok += (pre == label).sum().item()

    # Accuracy in percent over all test samples (the last batch may hold fewer than BATCH_SZ)
    acc = ok * 100 / len(test_data_load.dataset)

    print('EPOCH:{epoch}, ACC:{acc}\n'.format(epoch=epoch, acc=acc))


# The main guard keeps DataLoader worker processes (num_workers > 0) from re-running this code
if __name__ == '__main__':
    # Image value conversion. From the ToTensor source docstring:
    #   Convert a ``PIL Image`` or ``numpy.ndarray`` to tensor.
    #   Converts a PIL Image or numpy.ndarray (H x W x C) in the range
    #   [0, 255] to a torch.FloatTensor of shape (C x H x W) in the range [0.0, 1.0].
    # Normalize then maps [0.0, 1.0] to [-1, 1]: ([0, 1] - 0.5) / 0.5 = [-1, 1]
    transform = tv.transforms.Compose([transforms.ToTensor(), transforms.Normalize([0.5], [0.5])])

    # Define the datasets
    train_data = tv.datasets.MNIST(root=ROOT, train=True, download=True, transform=transform)
    test_data = tv.datasets.MNIST(root=ROOT, train=False, download=False, transform=transform)

    train_load = t.utils.data.DataLoader(train_data, batch_size=BATCH_SZ, shuffle=True, num_workers=WORKERS)
    test_load = t.utils.data.DataLoader(test_data, batch_size=BATCH_SZ, shuffle=False, num_workers=WORKERS)

    print('train data num:', len(train_data), ', test data num:', len(test_data))

    net = LineNet().to(DEVICE)

    loss_func = nn.CrossEntropyLoss()
    optimizer = t.optim.SGD(net.parameters(), lr=LR)

    if TRAIN:
        net_train(net, train_load, test_load, EPOCH, SAVE_PARA)
    else:
        net.load_state_dict(t.load(PARAS_FN, map_location=DEVICE))
        net_test(net, 0, test_load)
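LineNet's last layer is a plain nn.Linear(300, 10) with no softmax. This works because nn.CrossEntropyLoss expects raw scores (logits) and applies log-softmax plus the negative log-likelihood internally. The pure-Python sketch below needs no PyTorch and shows that per-sample computation. The names cross_entropy and mean_cross_entropy are illustrative, not PyTorch functions. It assumes the loss's default settings: mean reduction, no class weights, no label smoothing.

```python
import math


def cross_entropy(logits, target):
    """Illustrative helper (not a PyTorch API): cross-entropy for one sample.

    logits: raw scores for each class; target: index of the true class.
    loss = log(sum(exp(z) for z in logits)) - logits[target],
    computed with the max-subtraction trick so exp() cannot overflow.
    """
    m = max(logits)
    log_sum_exp = m + math.log(sum(math.exp(z - m) for z in logits))
    return log_sum_exp - logits[target]


def mean_cross_entropy(batch_logits, targets):
    """Average over a batch, like CrossEntropyLoss's default reduction='mean'."""
    losses = [cross_entropy(z, y) for z, y in zip(batch_logits, targets)]
    return sum(losses) / len(losses)


# Two classes with equal scores: the loss is log(2)
print(cross_entropy([0.0, 0.0], 0))  # ≈ 0.6931
```

Adding a softmax layer to LineNet would therefore apply softmax twice. The network should output raw scores and leave the normalization to the loss.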
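After training, the network reaches an accuracy of about 97% to 98%. This is on par with the fully connected network of the same structure in the previous article, MNIST deep learning recognition: implementing and comparing a Python fully connected neural network and a PyTorch LeNet CNN (Part 1). Because the computation here runs on the GPU, training takes only about 1/8 as long.
PyTorch also makes the code considerably simpler.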
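The whole model is just two nn.Linear layers, so its size is easy to check by hand. Each layer holds an n_out x n_in weight matrix plus n_out biases, and ReLU has no parameters. The short sketch below does the arithmetic in plain Python. linear_params is an illustrative name, not a library function.

```python
def linear_params(n_in, n_out):
    """Illustrative helper: weights (n_out x n_in) plus biases (n_out) of one linear layer."""
    return n_in * n_out + n_out


layers = [28 * 28, 300, 10]  # the 784-300-10 structure of LineNet
total = sum(linear_params(a, b) for a, b in zip(layers, layers[1:]))
print(total)  # 238510
```

In PyTorch, sum(p.numel() for p in net.parameters()) on a LineNet instance should give the same count.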
The output is as follows:
train data num: 60000 , test data num: 10000
epoch:0 i:999 loss:0.3457891643047333
epoch:0 i:1999 loss:0.09639787673950195
epoch:0 i:2999 loss:0.27898865938186646
EPOCH:0, ACC:94.84
epoch:1 i:999 loss:0.33745211362838745
epoch:1 i:1999 loss:0.11106520891189575
epoch:1 i:2999 loss:0.21725007891654968
EPOCH:1, ACC:96.42
epoch:2 i:999 loss:0.3825737535953522
epoch:2 i:1999 loss:0.02866300940513611
epoch:2 i:2999 loss:0.11832481622695923
EPOCH:2, ACC:96.77
epoch:3 i:999 loss:0.11886310577392578
epoch:3 i:1999 loss:0.012149035930633545
epoch:3 i:2999 loss:0.030409961938858032
EPOCH:3, ACC:97.2
epoch:4 i:999 loss:0.008915185928344727
epoch:4 i:1999 loss:0.008089780807495117
epoch:4 i:2999 loss:0.0005310177803039551
EPOCH:4, ACC:97.6
epoch:5 i:999 loss:0.02993696928024292
epoch:5 i:1999 loss:0.01784616708755493
epoch:5 i:2999 loss:0.10544028878211975
EPOCH:5, ACC:97.6
epoch:6 i:999 loss:0.008486062288284302
epoch:6 i:1999 loss:0.0334945023059845
epoch:6 i:2999 loss:0.00291365385055542
EPOCH:6, ACC:97.37
epoch:7 i:999 loss:0.0062919557094573975
epoch:7 i:1999 loss:0.0003241896629333496
epoch:7 i:2999 loss:0.0006818175315856934
EPOCH:7, ACC:97.23
epoch:8 i:999 loss:0.0007421970367431641
epoch:8 i:1999 loss:0.005641639232635498
epoch:8 i:2999 loss:0.005949795246124268
EPOCH:8, ACC:97.7
epoch:9 i:999 loss:0.024028539657592773
epoch:9 i:1999 loss:0.005388796329498291
epoch:9 i:2999 loss:0.0029097795486450195
EPOCH:9, ACC:97.39
MINIST pytorch LineNet Train: EPOCH:10, BATCH_SZ:16, LR:0.05
train spend time: 0:00:43.183836
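With 60000 training images and BATCH_SZ = 16, each epoch runs 3750 batches. That explains why the log above prints only at i = 999, 1999 and 2999, and it fixes how many points the loss curve contains. The sketch below mirrors the two `i % ...` checks in net_train. TRAIN_SAMPLES is a hypothetical name for the "train data num" value in the log. Each printed or plotted value is the loss of a single batch, not an average, which is why consecutive values jump around (for example, 0.38 and then 0.029 in epoch 2).

```python
import math

TRAIN_SAMPLES = 60000  # hypothetical name for "train data num" in the log
BATCH_SZ = 16
EPOCH = 10

batches_per_epoch = math.ceil(TRAIN_SAMPLES / BATCH_SZ)
# i counts from 0, so "i % 1000 == 999" prints at these batch indices
printed = [i for i in range(batches_per_epoch) if i % 1000 == 999]
# one single-batch loss value is appended to loss_list every 100 batches
points_per_epoch = sum(1 for i in range(batches_per_epoch) if i % 100 == 99)

print(batches_per_epoch, printed, points_per_epoch, points_per_epoch * EPOCH)
# 3750 [999, 1999, 2999] 37 370
```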
The curve of the loss value during training: