Policy Gradient Within MXNet手记

Stella981
• 阅读 603

Preface

从AIC回来后,做了一些Policy Gradient的工作。主要是觉得RL是一个有意思的领域,而深度网络逼近对这个问题提供了良好的可预期的解决方案。本来想找些程序先做个参考,发现都是些打游戏场景;之前看textbook的时候,对Figure 17.1(Artificial Intelligence: A Modern Approach)的例子印象深刻,觉得是个良好的案例,于是用这个场景;另外,做的过程中开始转gluon,在这个例子上试了下,发现还挺顺手的,就把之前的symbol接口全改了。

Code

程序是github上的Task1

Policy Gradient

正式的说明自有大牛阐述,这里只是为了保持结构完整。RL的两种途径分别是_Value-Based_和_Policy-Based_,Policy的好处在于其简洁性,直接将state映射到action

Works

而Ploicy Gradient,在做的过程中,给我一种和以前的label-data learning极为相似的感觉,相当于都是通过试错法来了解数据,只不过label-data的对应结构关系不同。 所以,最开始时,对backward那部分的想法是,通过在最末接入一个_SoftmaxCrossEntropyLoss_,将选出来的_action_作为_label_,通过将_outgrad_作为_reward_的函数,控制优化的方向和幅度。最后的版本里面,Loss_被删除,直接用_sotmax_作激活函数得到_action_的概率分布,然后用具有one_hot型编码的_outgrad_进行_backward(这要感谢_gluon_的功能)。

另一个需要说明的是exploration-exploitation,之前在看一些非正式的介绍时,看到这一段,总会习惯性地认为需要采取类似于遗传算法的概率操作:设定一个阈值,确定是否放弃系统给出的_action_,如果放弃,再随机选一个_action_出来。 将这两种想法结合起来,发现最后没有收敛到理想方向(第一个commit便是)。于是,在后面debug时,Loss_那部分就被砍掉了。但给我感觉,对系统影响最大的应该是_exploration-exploitation_那部分,后面查看Reinforcement Learning: An Introduction第13章,发现应该是进行采样(实际上,如果按照统计学习的惯性思维,也应该是使用采样的样本代替期望)。多亏16年夏天啃了会_PRML(不然什么叫sample都不知道 (⊙﹏⊙)b),用了一个rejection sampling对_softmax_的输出进行采样:

import numpy as np
N=100
net_out = np.array([1,5,3,1])
net_out = 1.*net_out/net_out.sum()
hist=[]
for i in xrange(1000):
    act_list = np.random.randint(0, 4, (N,))
    prob_list = np.random.uniform(size=(N,))
    idx = list(net_out[act_list] > prob_list).index(True) #np.where(net_out[act_list] > prob_list)
    hist.append(act_list[idx])
h=np.histogram(hist, bins=4)[0]
print h*1./h.sum()

Jul-18,2018 最近要做一个负采样的迭代器,手写的采样太慢,扩充为多线程后问题依然严重。又查了下,发现早就有内建函数:1, 2

# method I
import numpy as np
numpy.random.choice(numpy.arange(1, 7), p=[0.1, 0.05, 0.05, 0.2, 0.4, 0.2])

# method II
from scipy.stats import rv_discrete
numbers = (1,2,3)
distribution = (1./6, 2./6, 3./6)
random_variable = rv_discrete(values=(numbers,distribution))
random_variable.rvs(size=10)

速度蹭蹭的!


Result

后面把结果打出来,发现收敛得还行,对比了一下_R=-0.04_时的结果,textbook上的结果是这样的:

Figure 1. Ground Truth
系统的决策结果:

1

2

3

4

+1

Wall

-1

Tab 1. Net Prediction

可以看到,存在一些出入,但这个结果和*-0.4278<R<-0.0850_的情况下的ground truth 是相同的,这似乎就是说的_Policy Gradient*容易陷入local optima的情况。


2018.3.25 局部最优的情况参考 Sutton Reinforcement Learning: An Introduction中提及的蒙特卡洛方法导致的方差较大问题。

点赞
收藏
评论区
推荐文章
blmius blmius
3年前
MySQL:[Err] 1292 - Incorrect datetime value: ‘0000-00-00 00:00:00‘ for column ‘CREATE_TIME‘ at row 1
文章目录问题用navicat导入数据时,报错:原因这是因为当前的MySQL不支持datetime为0的情况。解决修改sql\mode:sql\mode:SQLMode定义了MySQL应支持的SQL语法、数据校验等,这样可以更容易地在不同的环境中使用MySQL。全局s
待兔 待兔
6个月前
手写Java HashMap源码
HashMap的使用教程HashMap的使用教程HashMap的使用教程HashMap的使用教程HashMap的使用教程22
Jacquelyn38 Jacquelyn38
3年前
2020年前端实用代码段,为你的工作保驾护航
有空的时候,自己总结了几个代码段,在开发中也经常使用,谢谢。1、使用解构获取json数据let jsonData  id: 1,status: "OK",data: 'a', 'b';let  id, status, data: number   jsonData;console.log(id, status, number )
Souleigh ✨ Souleigh ✨
3年前
前端性能优化 - 雅虎军规
无论是在工作中,还是在面试中,web前端性能的优化都是很重要的,那么我们进行优化需要从哪些方面入手呢?可以遵循雅虎的前端优化35条军规,这样对于优化有一个比较清晰的方向.35条军规1.尽量减少HTTP请求个数——须权衡2.使用CDN(内容分发网络)3.为文件头指定Expires或CacheControl,使内容具有缓存性。4.避免空的
Stella981 Stella981
3年前
Android So动态加载 优雅实现与原理分析
背景:漫品Android客户端集成适配转换功能(基于目标识别(So库35M)和人脸识别库(5M)),导致apk体积50M左右,为优化客户端体验,决定实现So文件动态加载.!(https://oscimg.oschina.net/oscnet/00d1ff90e4b34869664fef59e3ec3fdd20b.png)点击上方“蓝字”关注我
Easter79 Easter79
3年前
Twitter的分布式自增ID算法snowflake (Java版)
概述分布式系统中,有一些需要使用全局唯一ID的场景,这种时候为了防止ID冲突可以使用36位的UUID,但是UUID有一些缺点,首先他相对比较长,另外UUID一般是无序的。有些时候我们希望能使用一种简单一些的ID,并且希望ID能够按照时间有序生成。而twitter的snowflake解决了这种需求,最初Twitter把存储系统从MySQL迁移
Wesley13 Wesley13
3年前
35岁,真的是程序员的一道坎吗?
“程序员35岁是道坎”,“程序员35岁被裁”……这些话咱们可能都听腻了,但每当触及还是会感到丝丝焦虑,毕竟每个人都会到35岁。而国内互联网环境确实对35岁以上的程序员不太友好:薪资要得高,却不如年轻人加班猛;虽说经验丰富,但大部分公司并不需要太资深的程序员。但35岁危机并不是不可避免的,比如你可以不断精进技术,将来做技术管理或者
Wesley13 Wesley13
3年前
35岁是技术人的天花板吗?
35岁是技术人的天花板吗?我非常不认同“35岁现象”,人类没有那么脆弱,人类的智力不会说是35岁之后就停止发展,更不是说35岁之后就没有机会了。马云35岁还在教书,任正非35岁还在工厂上班。为什么技术人员到35岁就应该退役了呢?所以35岁根本就不是一个问题,我今年已经37岁了,我发现我才刚刚找到自己的节奏,刚刚上路。
Wesley13 Wesley13
3年前
MySQL部分从库上面因为大量的临时表tmp_table造成慢查询
背景描述Time:20190124T00:08:14.70572408:00User@Host:@Id:Schema:sentrymetaLast_errno:0Killed:0Query_time:0.315758Lock_
Python进阶者 Python进阶者
1年前
Excel中这日期老是出来00:00:00,怎么用Pandas把这个去除
大家好,我是皮皮。一、前言前几天在Python白银交流群【上海新年人】问了一个Pandas数据筛选的问题。问题如下:这日期老是出来00:00:00,怎么把这个去除。二、实现过程后来【论草莓如何成为冻干莓】给了一个思路和代码如下:pd.toexcel之前把这