MXNET:深度学习计算

Wesley13
• 阅读 815

我们将深入讲解模型参数的访问和初始化,以及如何在多个层之间共享同一份参数。 之前我们一直在使用默认的初始函数,net.initialize()。

from mxnet import init, nd
from mxnet.gluon import nn

net = nn.Sequential()
net.add(nn.Dense(256, activation='relu'))
net.add(nn.Dense(10))
net.initialize()

x = nd.random.uniform(shape=(2,20))
y = net(x)

这里我们从 MXNet 中导入了 init 这个包,它包含了多种模型初始化方法。

访问模型参数

我们知道可以通过 [] 来访问 Sequential 类构造出来的网络的特定层。对于带有模型参数的层,我们可以通过 Block 类的 params 属性来得到它包含的所有参数。例如我们查看隐藏层的参数:

net[0].params
# output
dense0_ (
  Parameter dense0_weight (shape=(256, 20), dtype=float32)
  Parameter dense0_bias (shape=(256,), dtype=float32)
)

我们得到了一个由参数名称映射到参数实例的字典。第一个参数的名称为 dense0_weight,它由 net[0] 的名称(dense0_)和自己的变量名(weight)组成。而且可以看到它参数的形状为 (256, 20),且数据类型为 32 位浮点数。

为了访问特定参数,我们既可以通过名字来访问字典里的元素,也可以直接使用它的变量名。下面两种方法是等价的,但通常后者的代码可读性更好。

net[0].params['dense0_weight'], net[0].weight

Gluon 里参数类型为 Parameter 类,其包含参数权重和它对应的梯度,它们可以分别通过 data 和 grad 函数来访问。因为我们随机初始化了权重,所以它是一个由随机数组成的形状为 (256, 20) 的 NDArray.

net[0].weight.data()
# output
[[ 0.06700657 -0.00369488  0.0418822  ..., -0.05517294 -0.01194733
  -0.00369594]
 ...,
 [ 0.00297424 -0.0281784  -0.06881659 ..., -0.04047417  0.00457048
   0.05696651]]
<NDArray 256x20 @cpu(0)>

梯度的形状跟权重一样。但由于我们还没有进行反向传播计算,所以它的值全为 0.

net[0].weight.grad()
# output
[[ 0.  0.  0. ...,  0.  0.  0.]
 ...,
 [ 0.  0.  0. ...,  0.  0.  0.]]
<NDArray 256x20 @cpu(0)>

类似我们可以访问其他的层的参数。例如输出层的偏差权重:

net[1].bias.data()

最后,我们可以 collect_params 函数来获取 net 实例所有嵌套(例如通过 add 函数嵌套)的层所包含的所有参数。它返回的同样是一个参数名称到参数实例的字典。

net.collect_params()
# output
sequential0_ (
  Parameter dense0_weight (shape=(256, 20), dtype=float32)
  Parameter dense0_bias (shape=(256,), dtype=float32)
  Parameter dense1_weight (shape=(10, 256), dtype=float32)
  Parameter dense1_bias (shape=(10,), dtype=float32)
)

初始化模型参数

当使用默认的模型初始化,Gluon 会将权重参数元素初始化为 [-0.07, 0.07] 之间均匀分布的随机数,偏差参数则全为 0. 但经常我们需要使用其他的方法来初始话权重,MXNet 的 init 模块里提供了多种预设的初始化方法。例如下面例子我们将权重参数初始化成均值为 0,标准差为 0.01 的正态分布随机数。

# 非首次对模型初始化需要指定 force_reinit。
net.initialize(init=init.Normal(sigma=0.01), force_reinit=True)
net[0].weight.data()[0]

如果想只对某个特定参数进行初始化,我们可以调用 Paramter 类的 initialize 函数,它的使用跟 Block 类提供的一致。下例中我们对第一个隐藏层的权重使用 Xavier 初始化方法。

net[0].weight.initialize(init=init.Xavier(), force_reinit=True)
net[0].weight.data()[0]

自定义初始化方法

有时候我们需要的初始化方法并没有在 init 模块中提供。这时,我们可以实现一个 Initializer 类的子类使得我们可以跟前面使用 init.Normal 那样使用它。通常,我们只需要实现 _init_weight 这个函数,将其传入的 NDArray 修改成需要的内容。下面例子里我们把权重初始化成 [-10,-5] 和 [5,10] 两个区间里均匀分布的随机数。

class MyInit(init.Initializer):
    def _init_weight(self, name, data):
        print('Init', name, data.shape)
        data[:] = nd.random.uniform(low=-10, high=10, shape=data.shape)
        data *= data.abs() >= 5

net.initialize(MyInit(), force_reinit=True)
net[0].weight.data()[0]

此外,我们还可以通过 Parameter 类的 set_data 函数来直接改写模型参数。例如下例中我们将隐藏层参数在现有的基础上加 1。

net[0].weight.set_data(net[0].weight.data() + 1)
net[0].weight.data()[0]

共享模型参数

在有些情况下,我们希望在多个层之间共享模型参数。我们在 “模型构造” 一节看到了如何在 Block 类里 forward 函数里多次调用同一个类来完成。

这里将介绍另外一个方法,它在构造层的时候指定使用特定的参数。如果不同层使用同一份参数,那么它们不管是在前向计算还是反向传播时都会共享共同的参数。

我们让模型的第二隐藏层和第三隐藏层共享模型参数。

net = nn.Sequential()
shared = nn.Dense(8, activation='relu')
net.add(nn.Dense(8, activation='relu'),
        shared,
        nn.Dense(8, activation='relu', params=shared.params),
        nn.Dense(10))
net.initialize()

x = nd.random.uniform(shape=(2,20))
net(x)

net[1].weight.data()[0] == net[2].weight.data()[0]

# output
[ 1.  1.  1.  1.  1.  1.  1.  1.]
<NDArray 8 @cpu(0)>

我们在构造第三隐藏层时通过 params 来指定它使用第二隐藏层的参数。由于模型参数里包含了梯度,所以在反向传播计算时,第二隐藏层和第三隐藏层的梯度都会被累加在 shared.params.grad() 里。

延后的初始

注意到前面使用 Gluon 的章节里,我们在创建全连接层时都没有指定输入大小。例如在一直使用的多层感知机例子里,我们创建了输出大小为 256 的隐藏层。但是当在调用 initialize 函数的时候,我们并不知道这个层的参数到底有多大,因为它的输入大小仍然是未知。

只有在当我们将形状是 (2,20) 的 x 输入进网络时,我们这时候才知道这一层的参数大小应该是 (256,20)。所以这个时候我们才能真正开始初始化参数。

使用 MyInit 实例来进行初始化:

from mxnet import init, nd
from mxnet.gluon import nn

class MyInit(init.Initializer):
    def _init_weight(self, name, data):
        print('Init', name, data.shape)
        # 实际的初始化逻辑在此省略了。

net = nn.Sequential()
net.add(nn.Dense(256, activation='relu'))
net.add(nn.Dense(10))

net.initialize(init=MyInit())
# 注意到 MyInit 在调用时会打印信息,但当前我们并没有看到相应的日志。下面我们执行前向计算。

x = nd.random.uniform(shape=(2,20))
y = net(x)
# output
Init dense0_weight (256, 20)
Init dense1_weight (10, 256)

我们将这个系统将真正的参数初始化延后到获得了足够信息到时候称之为延后初始化。它可以让模型创建更加简单,因为我们只需要定义每个层的输出大小,而不用去推测它们的的输入大小。这个对于之后将介绍的多达数十甚至数百层的网络尤其有用。

当然延后初始化也可能会造成一定的困解。在调用第一次前向计算之前我们无法直接操作模型参数。例如无法使用 data 和 set_data 函数来获取和改写参数。所以经常我们会额外调用一次 net(x) 来是的参数被真正的初始化。

避免延后初始化

当系统在调用 initialize 函数时能够知道所有参数形状,那么延后初始化就不会发生。我们这里给两个这样的情况。

第一个是模型已经被初始化过,而且我们要对模型进行重新初始化时。因为我们知道参数大小不会变,所以能够立即进行重新初始化。

net.initialize(init=MyInit(), force_reinit=True)

第二种情况是我们在创建层到时候指定了每个层的输入大小,使得系统不需要额外的信息来推测参数形状。下例中我们通过 in_units 来指定每个全连接层的输入大小,使得初始化能够立即进行。

net = nn.Sequential()
net.add(nn.Dense(256, in_units=20, activation='relu'))
net.add(nn.Dense(10, in_units=256))

net.initialize(init=MyInit())
点赞
收藏
评论区
推荐文章
blmius blmius
3年前
MySQL:[Err] 1292 - Incorrect datetime value: ‘0000-00-00 00:00:00‘ for column ‘CREATE_TIME‘ at row 1
文章目录问题用navicat导入数据时,报错:原因这是因为当前的MySQL不支持datetime为0的情况。解决修改sql\mode:sql\mode:SQLMode定义了MySQL应支持的SQL语法、数据校验等,这样可以更容易地在不同的环境中使用MySQL。全局s
皕杰报表之UUID
​在我们用皕杰报表工具设计填报报表时,如何在新增行里自动增加id呢?能新增整数排序id吗?目前可以在新增行里自动增加id,但只能用uuid函数增加UUID编码,不能新增整数排序id。uuid函数说明:获取一个UUID,可以在填报表中用来创建数据ID语法:uuid()或uuid(sep)参数说明:sep布尔值,生成的uuid中是否包含分隔符'',缺省为
待兔 待兔
5个月前
手写Java HashMap源码
HashMap的使用教程HashMap的使用教程HashMap的使用教程HashMap的使用教程HashMap的使用教程22
Jacquelyn38 Jacquelyn38
3年前
2020年前端实用代码段,为你的工作保驾护航
有空的时候,自己总结了几个代码段,在开发中也经常使用,谢谢。1、使用解构获取json数据let jsonData  id: 1,status: "OK",data: 'a', 'b';let  id, status, data: number   jsonData;console.log(id, status, number )
Stella981 Stella981
3年前
KVM调整cpu和内存
一.修改kvm虚拟机的配置1、virsheditcentos7找到“memory”和“vcpu”标签,将<namecentos7</name<uuid2220a6d1a36a4fbb8523e078b3dfe795</uuid
Wesley13 Wesley13
3年前
mysql设置时区
mysql设置时区mysql\_query("SETtime\_zone'8:00'")ordie('时区设置失败,请联系管理员!');中国在东8区所以加8方法二:selectcount(user\_id)asdevice,CONVERT\_TZ(FROM\_UNIXTIME(reg\_time),'08:00','0
Wesley13 Wesley13
3年前
00:Java简单了解
浅谈Java之概述Java是SUN(StanfordUniversityNetwork),斯坦福大学网络公司)1995年推出的一门高级编程语言。Java是一种面向Internet的编程语言。随着Java技术在web方面的不断成熟,已经成为Web应用程序的首选开发语言。Java是简单易学,完全面向对象,安全可靠,与平台无关的编程语言。
Stella981 Stella981
3年前
Django中Admin中的一些参数配置
设置在列表中显示的字段,id为django模型默认的主键list_display('id','name','sex','profession','email','qq','phone','status','create_time')设置在列表可编辑字段list_editable
Wesley13 Wesley13
3年前
MySQL部分从库上面因为大量的临时表tmp_table造成慢查询
背景描述Time:20190124T00:08:14.70572408:00User@Host:@Id:Schema:sentrymetaLast_errno:0Killed:0Query_time:0.315758Lock_
Python进阶者 Python进阶者
11个月前
Excel中这日期老是出来00:00:00,怎么用Pandas把这个去除
大家好,我是皮皮。一、前言前几天在Python白银交流群【上海新年人】问了一个Pandas数据筛选的问题。问题如下:这日期老是出来00:00:00,怎么把这个去除。二、实现过程后来【论草莓如何成为冻干莓】给了一个思路和代码如下:pd.toexcel之前把这