简单的GRU实例代码

风花雪月
• 阅读 309
import numpy as np
# 定义sigmoid函数
def sigmoid(x):
    return 1 / (1 + np.exp(-x))
class RNN:
    def __init__(self, input_size, hidden_size, output_size):
        # 设定超参数
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.output_size = output_size

        # 初始化权重和偏置
        self.Wxh = np.random.randn(hidden_size, input_size) * 0.01  # 输入到隐藏层的权重
        self.Whh = np.random.randn(hidden_size, hidden_size) * 0.01  # 隐藏层到隐藏层的权重
        self.Why = np.random.randn(output_size, hidden_size) * 0.01  # 隐藏层到输出层的权重
        self.bh = np.zeros((hidden_size, 1))  # 隐藏层的偏置
        self.by = np.zeros((output_size, 1))  # 输出层的偏置

    def forward(self, inputs):
        # 初始化隐藏状态和输出
        self.h = np.zeros((self.hidden_size, 1))
        self.outputs = []
        for x in inputs:
            # 更新隐藏状态
            self.h = np.tanh(np.dot(self.Wxh, x) + np.dot(self.Whh, self.h) + self.bh)
            # 计算输出
            y = np.dot(self.Why, self.h) + self.by
            # 应用sigmoid激活函数
            output = sigmoid(y)
            self.outputs.append(output)
        return self.outputs

    def backward(self, inputs, targets, learning_rate=0.1):
        # 初始化梯度
        dWxh = np.zeros_like(self.Wxh)
        dWhh = np.zeros_like(self.Whh)
        dWhy = np.zeros_like(self.Why)
        dbh = np.zeros_like(self.bh)
        dby = np.zeros_like(self.by)
        dh_next = np.zeros_like(self.h)

        for i in reversed(range(len(inputs))):
            # 计算输出误差
            dy = self.outputs[i] - targets[i]
            # 计算输出层的梯度
            dWhy += np.dot(dy, self.h.T)
            dby += dy
            # 计算隐藏层的误差
            dh = np.dot(self.Why.T, dy) + dh_next
            # 应用tanh的导数
            dh_raw = (1 - self.h ** 2) * dh
            # 计算隐藏层的梯度
            dWxh += np.dot(dh_raw, inputs[i].T)
            dWhh += np.dot(dh_raw, self.h.T)
            dbh += dh_raw
            # 更新dh_next
            dh_next = np.dot(self.Whh.T, dh_raw)

        for dparam in [dWxh, dWhh, dWhy, dbh, dby]:
            np.clip(dparam, -5, 5, out=dparam)  # 防止梯度爆炸

        # 更新权重和偏置
        self.Wxh -= learning_rate * dWxh
        self.Whh -= learning_rate * dWhh
        self.Why -= learning_rate * dWhy
        self.bh -= learning_rate * dbh
        self.by -= learning_rate * dby


# 测试代码
# 定义数据和标签
inputs = [np.array([[1], [0], [1]]), np.array([[0], [1], [0]])]
targets = [np.array([[1]]), np.array([[0]])]

input_size = inputs[0].shape[0]
hidden_size = 25
output_size = targets[0].shape[0]

# 创建RNN模型,并进行训练
rnn = RNN(input_size, hidden_size, output_size)
for epoch in range(1000):
    outputs = rnn.forward(inputs)
    loss = np.mean((np.array(outputs)-np.array(targets)) ** 2)
    rnn.backward(inputs, targets)
    if (epoch + 1) % 100 == 0:
        print("次数:", epoch + 1, "误差:", loss)

# 在新数据上进行预测
new_input = np.array([[1], [1], [1]])
output = rnn.forward([new_input])
print("输入:", new_input.flatten())
print("输出:", output)
点赞
收藏
评论区
推荐文章
Irene181 Irene181
3年前
一篇文章带教会你Python访问限制那些事儿
一、前言在Class内部,可以有属性和方法,而外部代码可以通过直接调用实例变量的方法来操作数据,这样,就隐藏了内部的复杂逻辑。二、案例分析以Teacher类的定义来看,外部代码还是可以自由地修改一个实例的name、score属性。classTeacher(object):definit(self,name,score):s
Wesley13 Wesley13
3年前
UIWebView长按保存图片和识别图片二维码的实现方案(使用缓存)
0x00需求:长按识别UIWebView中的二维码,如下图长按识别二维码0x01方案1:给UIWebView增加一个长按手势,激活长按手势时获取当前UIWebView的截图,分析是否包含二维码。核心代码:略优点:流程简单,可以快速实现。不足:无法实现保存UIWebView中图片,如果当前We
Stella981 Stella981
3年前
MacOS VSCode 安装 GO 插件失败问题解决
0x00问题重现Installinggolang.org/x/tools/cmd/guruFAILEDInstallinggolang.org/x/tools/cmd/gorenameFAILEDInstallinggolang.org/x/lint/golintFAILEDInst
Stella981 Stella981
3年前
Scapy 从入门到放弃
0x00前言最近闲的没事,抽空了解下地表最强的嗅探和收发包的工具:scapy。scapy是一个python模块,使用简单,并且能灵活地构造各种数据包,是进行网络安全审计的好帮手。0x01安装因为2020年python官方便不再支持python2,所以使用python3安装。!(https://oscimg.oschina.net/os
Wesley13 Wesley13
3年前
FLV文件格式
1.        FLV文件对齐方式FLV文件以大端对齐方式存放多字节整型。如存放数字无符号16位的数字300(0x012C),那么在FLV文件中存放的顺序是:|0x01|0x2C|。如果是无符号32位数字300(0x0000012C),那么在FLV文件中的存放顺序是:|0x00|0x00|0x00|0x01|0x2C。2.  
Stella981 Stella981
3年前
SpringBoot整合Redis乱码原因及解决方案
问题描述:springboot使用springdataredis存储数据时乱码rediskey/value出现\\xAC\\xED\\x00\\x05t\\x00\\x05问题分析:查看RedisTemplate类!(https://oscimg.oschina.net/oscnet/0a85565fa
Stella981 Stella981
3年前
ELK学习笔记之配置logstash消费kafka多个topic并分别生成索引
0x00 filebeat配置多个topicfilebeat.prospectors:input_type:logencoding:GB2312fields_under_root:truefields:添加字段
Easter79 Easter79
3年前
SpringBoot整合Redis乱码原因及解决方案
问题描述:springboot使用springdataredis存储数据时乱码rediskey/value出现\\xAC\\xED\\x00\\x05t\\x00\\x05问题分析:查看RedisTemplate类!(https://oscimg.oschina.net/oscnet/0a85565fa
Stella981 Stella981
3年前
JavaScript常用函数
1\.字符串长度截取functioncutstr(str,len){vartemp,icount0,patrn/^\x00\xff/,strre"";for(vari
Stella981 Stella981
3年前
JVM 字节码指令表
字节码助记符指令含义0x00nop什么都不做0x01aconst\_null将null推送至栈顶0x02iconst\_m1将int型1推送至栈顶0x03iconst\_0将int型0推送至栈顶0x04iconst\_1将int型1推送至栈顶0x05ic