ND4J自动微分

Wesley13
• 阅读 476

一、前言

    ND4J从beta2开始就开始支持自动微分,不过直到beta4版本为止,自动微分还只支持CPU,GPU版本将在后续版本中实现。

    本篇博客中,我们将用ND4J来构建一个函数,利用ND4J SameDiff构建函数求函数值和求函数每个变量的偏微分值。

二、构建函数

    构建函数和分别手动求偏导数

    ND4J自动微分

    给定一个点(2,3)手动求函数值和偏导,计算如下:

    f=2+3*4+3=17,f对x的偏导:1+2*2*3=13,f对y的偏导:4+1=5

三、通过ND4J自动微分来求

    完整代码

package org.nd4j.samediff;

import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.linalg.factory.Nd4j;

/**
 * 
 * x+y*x2+y
 *
 */
public class Function {

    public static void main(String[] args) {
        //构建SameDiff实例
        SameDiff sd=SameDiff.create();
        //创建变量x、y
        SDVariable x= sd.var("x");
        SDVariable y=sd.var("y");
        
        //定义函数
        SDVariable f=x.add(y.mul(sd.math().pow(x, 2)));
        f.add("addY",y);
        
        //给变量x、y绑定具体值
        x.setArray(Nd4j.create(new double[]{2}));
        y.setArray(Nd4j.create(new double[]{3}));
        //前向计算函数的值
        System.out.println(sd.exec(null, "addY").get("addY"));
        //后向计算求梯度
        sd.execBackwards(null);
        //打印x在(2,3)处的导数
        System.out.println(sd.getGradForVariable("x").getArr());
        //x.getGradient().getArr()和sd.getGradForVariable("x").getArr()等效
        System.out.println(x.getGradient().getArr());
        //打印y在(2,3)处的导数
        System.out.println(sd.getGradForVariable("y").getArr());
    }
}

    四、运行结果

o.n.l.f.Nd4jBackend - Loaded [CpuBackend] backend
o.n.n.NativeOpsHolder - Number of threads used for NativeOps: 4
o.n.n.Nd4jBlas - Number of threads used for BLAS: 4
o.n.l.a.o.e.DefaultOpExecutioner - Backend used: [CPU]; OS: [Windows 10]
o.n.l.a.o.e.DefaultOpExecutioner - Cores: [8]; Memory: [3.2GB];
o.n.l.a.o.e.DefaultOpExecutioner - Blas vendor: [MKL]
17.0000
o.n.a.s.SameDiff - Inferring output "addY" as loss variable as none were previously set. Use SameDiff.setLossVariables() to override
13.0000
13.0000
5.0000

    结果为17、13、5和手动求出的结果完全一致。

    自动微分屏蔽了deeplearning在求微分过程中的很多细节,特别是矩阵求导、矩阵范数求导等等,是非常麻烦的,用自动微分,可以轻松实现各式各样的网络结构。

快乐源于分享。

此博客乃作者原创, 转载请注明出处

点赞
收藏
评论区
推荐文章
blmius blmius
3年前
MySQL:[Err] 1292 - Incorrect datetime value: ‘0000-00-00 00:00:00‘ for column ‘CREATE_TIME‘ at row 1
文章目录问题用navicat导入数据时,报错:原因这是因为当前的MySQL不支持datetime为0的情况。解决修改sql\mode:sql\mode:SQLMode定义了MySQL应支持的SQL语法、数据校验等,这样可以更容易地在不同的环境中使用MySQL。全局s
皕杰报表之UUID
​在我们用皕杰报表工具设计填报报表时,如何在新增行里自动增加id呢?能新增整数排序id吗?目前可以在新增行里自动增加id,但只能用uuid函数增加UUID编码,不能新增整数排序id。uuid函数说明:获取一个UUID,可以在填报表中用来创建数据ID语法:uuid()或uuid(sep)参数说明:sep布尔值,生成的uuid中是否包含分隔符'',缺省为
待兔 待兔
5个月前
手写Java HashMap源码
HashMap的使用教程HashMap的使用教程HashMap的使用教程HashMap的使用教程HashMap的使用教程22
Stella981 Stella981
3年前
Lua基础(对象)
:和.区别.   stu{id100,name"Tom",age21}成员变量   function stu.toString()成员函数    return stu.id .. stu.name .. stu.age   endprint(stu
Stella981 Stella981
3年前
JS 对象数组Array 根据对象object key的值排序sort,很风骚哦
有个js对象数组varary\{id:1,name:"b"},{id:2,name:"b"}\需求是根据name或者id的值来排序,这里有个风骚的函数函数定义:function keysrt(key,desc) {  return function(a,b){    return desc ? ~~(ak
Stella981 Stella981
3年前
HIVE 时间操作函数
日期函数UNIX时间戳转日期函数: from\_unixtime语法:   from\_unixtime(bigint unixtime\, string format\)返回值: string说明: 转化UNIX时间戳(从19700101 00:00:00 UTC到指定时间的秒数)到当前时区的时间格式举例:hive   selec
Wesley13 Wesley13
3年前
00:Java简单了解
浅谈Java之概述Java是SUN(StanfordUniversityNetwork),斯坦福大学网络公司)1995年推出的一门高级编程语言。Java是一种面向Internet的编程语言。随着Java技术在web方面的不断成熟,已经成为Web应用程序的首选开发语言。Java是简单易学,完全面向对象,安全可靠,与平台无关的编程语言。
混世魔王 混世魔王
1年前
皕杰报表之数据集函数
所谓数据集函数就是与数据集相关,从数据集取数的函数。这些函数不仅可以将数据直接从数据集取出,而且可以将取出的数据分组、求和、求最大值最小值、求第一条数据和最后一条数据、求前n条数据以及对取出的数据进行按段分割,还能对列和记录进行统计。在皕杰报表中共提供了1
Python进阶者 Python进阶者
11个月前
Excel中这日期老是出来00:00:00,怎么用Pandas把这个去除
大家好,我是皮皮。一、前言前几天在Python白银交流群【上海新年人】问了一个Pandas数据筛选的问题。问题如下:这日期老是出来00:00:00,怎么把这个去除。二、实现过程后来【论草莓如何成为冻干莓】给了一个思路和代码如下:pd.toexcel之前把这
小万哥 小万哥
6个月前
NumPy 双曲函数与集合操作详解
NumPy概览:使用numpy.sinh(),numpy.cosh(),numpy.tanh()计算双曲函数;示例包括求弧度值的双曲正弦、余弦。此外,numpy.arcsinh(),numpy.arccosh(),numpy.arctanh()用于求反函数。同时,NumPy提供集合操作如numpy.unique()构建唯一元素数组,numpy.union1d()求并集,numpy.intersect1d()求交集,numpy.setdiff1d()求差集,numpy.setxor1d()求对称差。