LSTM的变体

一、GRU

1、什么是GRU

门控循环单元(GRU)是一种循环神经网络(RNN)的变体,它通过引入门控机制来控制信息的流动,从而有效地解决了传统RNN中的梯度消失问题。GRU由Cho等人在2014年提出,它简化了LSTM的结构,将遗忘门和输入门合并为一个更新门,并增加了一个重置门,同时合并了单元状态和隐藏状态,使得模型更加简洁,训练速度更快,且在性能上与LSTM相当。

2、GRU的核心

核心在于两个门:更新门(update gate)和重置门(reset gate)。更新门控制着从前一时刻的状态信息中保留多少到当前状态,而重置门决定着前一状态有多少信息被写入到当前的候选集中。这种结构使得GRU在处理长序列数据时能够更好地捕捉长期依赖关系,同时减少了模型参数,提高了计算效率。

3、GRU的应用

应用非常广泛,包括但不限于自然语言处理(NLP)、语音识别、图像处理等领域。在NLP领域,GRU可以用于语言建模、情感分析、机器翻译等任务;在语音识别领域,GRU可以用于语音信号的特征提取和识别;在图像处理领域,GRU可以用于图像分类、目标检测等任务。GRU的简洁性和效率使其在处理大规模序列数据时具有优势。

在选择GRU和LSTM时,通常考虑的因素包括任务的复杂性、数据集的大小以及训练资源。由于GRU参数更少,收敛速度更快,因此在需要快速迭代和实验时,GRU通常是首选。然而,在某些需要对复杂序列依赖关系进行建模的任务中,LSTM可能会表现得更好。

总的来说,GRU是一种强大的循环神经网络架构,它通过引入门控机制来控制信息流,有效地解决了传统RNN的梯度消失问题。GRU的简洁性和效率使其在多种序列建模任务中表现出色,成为了深度学习中处理时序数据的重要工具之一。

4、GRU的工作原理

5、手写代码实现

import numpy as np

class GRU():
    def __init__(self, input_size, hidden_size):
        self.input_size = input_size
        self.hidden_size = hidden_size

        # 初始化参数w和b
        self.W_z = np.random.randn(self.hidden_size, self.input_size + self.hidden_size)
        self.b_z = np.zeros(self.hidden_size)
        # 重置门
        self.W_r = np.random.randn(self.hidden_size, self.input_size + self.hidden_size)
        self.b_r = np.zeros(self.hidden_size)
        # 候选隐藏状态
        self.W_h = np.random.randn(self.hidden_size, self.input_size + self.hidden_size)
        self.b_h = np.zeros(self.hidden_size)

    def tanh(self, x):
        return np.tanh(x)

    def sigmoid(self, x):
        return 1 / (1 + np.exp(-x))

    def forward(self, x):
        h_prev = np.zeros((self.hidden_size,))
        concat_input = np.concatenate([x, h_prev], axis=0)

        z_t = self.sigmoid(np.dot(self.W_z, concat_input) + self.b_z)
        r_t = self.sigmoid(np.dot(self.W_r, concat_input) + self.b_r)

        concat_reset_input = np.concatenate([x, r_t * h_prev], axis=0)

        h_hat_t = self.tanh(np.dot(self.W_h, concat_reset_input) + self.b_h)

        h_t = (1 - z_t) * h_prev + z_t * h_hat_t
        return h_t

二、BiLSTM

1、什么是BiSTM

BiSTM,即双向门控循环单元(Bidirectional Gated Recurrent Unit),是一种循环神经网络(RNN)的变体。它结合了前向和后向的GRU,能够同时处理过去和未来的信息,从而更好地捕捉序列数据中的上下文关系。

在BiSTM中,数据通过两个GRU网络进行处理:一个从左到右(前向),另一个从右到左(后向)。这两个网络的输出然后被拼接或相加,形成最终的特征表示,这个特征表示包含了序列的双向信息。这种结构特别适合于需要理解序列中前后文信息的任务,如文本分类、语音识别、命名实体识别(NER)等。

2、BiSTM的关键特点包括:

  1. 双向信息捕捉:BiSTM能够同时考虑序列中每个元素之前的和之后的上下文信息,这使得它在处理像文本这样的序列数据时非常有效,因为文本中词汇的含义往往受到其前后词汇的影响。

  2. 门控机制:BiSTM继承了GRU的门控机制,包括更新门和重置门,这些门控单元可以控制信息的流动,从而减少无效或噪声信息的干扰,并增强模型对重要信息的记忆能力。

  3. 应用广泛:BiSTM因其强大的序列处理能力而被广泛应用于各种领域,包括自然语言处理(NLP)、语音识别、时间序列分析等。

  4. 模型性能:在某些任务中,BiSTM能够提供比单向GRU或LSTM更好的性能,尤其是在需要捕捉长期依赖关系的任务中。

  5. 模型复杂度:由于BiSTM包含两个GRU网络,其模型参数和计算复杂度相对于单向GRU或LSTM会有所增加,但在很多情况下,这种增加是值得的,因为它能带来更准确的预测结果

 3、手写BiLSTM代码

import torch
import torch.nn as nn
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence

class LSTM(nn.Module):
    def __init__(self, vocab_size, target_size, input_size=512, hidden_size=512):
        super(LSTM, self).__init__()
        self.hidden_size = hidden_size
        self.embedding = nn.Embedding(vocab_size, input_size)
        self.mlp = nn.Sequential(
            nn.Linear(input_size, hidden_size),
            nn.GELU(),
            nn.Linear(hidden_size, hidden_size)
        )
        self.lstm = nn.LSTM(hidden_size, hidden_size * 2, num_layers=3, batch_first=True, dropout=0.5)
        self.avg_lstm = nn.AdaptiveAvgPool1d(1)
        self.avg_linear = nn.AdaptiveAvgPool1d(1)
        self.out_linear = nn.Sequential(
            nn.Linear(hidden_size * 2 + hidden_size, hidden_size),
            nn.GELU(),
            nn.LayerNorm(hidden_size),
            nn.Linear(hidden_size, target_size)
        )
        self.norm = nn.LayerNorm(hidden_size * 2)

    def forward(self, x, lengths):
        x = self.embedding(x)
        mlp = self.mlp(x)

        pached_embed = pack_padded_sequence(mlp, lengths, batch_first=True, enforce_sorted=False)
        lstm_out, _ = self.lstm(pached_embed)
        lstm_out, _ = pad_packed_sequence(lstm_out, batch_first=True)

        lstm_out = self.norm(lstm_out)
        avg_lstm = self.avg_lstm(lstm_out.permute(0, 2, 1)).squeeze(-1)
        avg_linear = self.avg_linear(mlp.permute(0, 2, 1)).squeeze(-1)

        out = torch.cat([avg_lstm, avg_linear], dim=-1)
        return self.out_linear(out)

class BiLSTM(nn.Module):
    def __init__(self, input_size=512, hidden_size=512, output_size=512):
        super(BiLSTM, self).__init__()
        self.hidden_size = hidden_size
        self.lstm_forward = nn.LSTM(input_size, hidden_size, num_layers=1, batch_first=True)
        self.lstm_backward = nn.LSTM(input_size, hidden_size, num_layers=1, batch_first=True)

    def forward(self, x):
        out_forward, _ = self.lstm_forward(x)
        out_backward, _ = self.lstm_backward(torch.flip(x, dims=[1]))
        out_backward = torch.flip(out_backward, dims=[1])
        combined_output = torch.cat([out_forward, out_backward], dim=-1)
        return combined_output

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/889909.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

C语言 | Leetcode C语言题解之第466题统计重复个数

题目&#xff1a; 题解&#xff1a; #include <stdlib.h> #include <stdio.h> #include <stdbool.h> #include <string.h> #include <math.h> #include <limits.h>#define MMAX(a, b) ((a) > (b)? (a) : (b)) #define MMIN(a,…

【项目】五子棋对战测试报告

目录 一、项目背景 二、项目功能 三、测试计划 1、功能测试&#xff1a; &#xff08;1&#xff09;测试用例&#xff1a; &#xff08;2&#xff09;实际执行测试的部分操作/截图 2、自动化测试 3、性能测试 一、项目背景 1、五子棋对战游戏 采用了前后端分离的方法来…

GO网络编程(七):海量用户通信系统5:分层架构

P323开始&#xff08;尚硅谷GO教程&#xff09;老韩又改目录结构了&#xff0c;没办法&#xff0c;和之前一样&#xff0c;先说下目录结构&#xff0c;再给代码&#xff0c;部分代码在之前讲过&#xff0c;还有知识的话由于本人近期很忙&#xff0c;所以这些就不多赘述了&#…

web自动化测试基础(从配置环境到自动化实现登录测试用例的执行,vscode如何导入自己的python包)

接下来的一段时间里我会和大家分享自动化测试相关的一些知识希望大家可以多多支持&#xff0c;一起进步。 一、环境的配置 前提安装好了python解释器并配好了环境&#xff0c;并安装好了VScode 下载的浏览器和浏览器驱动需要一样的版本号(只看大版本)。 1、安装浏览器 Chro…

vue-live2d看板娘集成方案设计使用教程

文章目录 前言v1.1.x版本&#xff1a;vue集成看板娘&#xff08;暂不使用&#xff0c;在v1.2.x已替换&#xff09;集成看板娘实现看板娘拖拽效果方案资源备份存储 当前最新调研&#xff1a;2024.10.2开源方案1&#xff1a;OhMyLive2D&#xff08;推荐&#xff09;开源方案2&…

SpringMVC2~~~

目录 数据格式化 基本数据类型可以和字符串自动转换 特殊数据类型和字符串间的转换 验证及国际化 自定义验证错误信息 细节 数据类型转换校验核心类DataBinder 工作机制 取消某个属性的绑定 中文乱码处理 处理json和HttpMessageConverter 处理Json-ResponseBody 处理…

go开发环境设置-安装与交叉编译(二)

1. 引言 Go语言&#xff0c;又称Golang&#xff0c;是Google开发的一门编程语言&#xff0c;以其高效、简洁和并发编程的优势受到广泛欢迎。作为一门静态类型、编译型语言&#xff0c;Go在构建网络服务器、微服务和命令行工具方面表现突出。 在开发过程中&#xff0c;开发者常…

吸毛效果好的宠物空气净化器分享,希喂、霍尼韦尔、米家实测

说起宠物空气净化器&#xff0c;几年前我可能会一脸鄙夷&#xff1a;为啥要花这种智商税冤枉钱&#xff1f; 直到之前养了一只猫&#xff0c;被家中乱飞的浮毛和滂臭的异味搞到头晕&#xff0c;于是作为i一个养宠的家电测评博主&#xff0c;索性对宠物空气净化器这玩意做了超级…

前端继承:原理、实现方式与应用场景

目录 一、定义 二、语法和实现方式 1.原型链继承 2.构造函数继承 3.组合继承 4.ES6类继承 三、使用方式 四、优点 五、缺点 六、适用场景 一、定义 前端继承是指在面向对象编程中&#xff0c;一个对象可以继承另一个对象的属性和方法。在前端领域&#xff0c;通常是指…

OpenCV高级图形用户界面(1)创建滑动条函数createTrackbar()的使用

操作系统&#xff1a;ubuntu22.04 OpenCV版本&#xff1a;OpenCV4.9 IDE:Visual Studio Code 编程语言&#xff1a;C11 算法描述 创建一个滑动条并将其附加到指定的窗口。 该函数 createTrackbar 创建一个具有指定名称和范围的滑动条&#xff08;滑块或范围控制&#xff09;…

C语言之扫雷小游戏(完整代码版)

说起扫雷游戏&#xff0c;这应该是很多人童年的回忆吧&#xff0c;中小学电脑课最常玩的必有扫雷游戏&#xff0c;那么大家知道它是如何开发出来的吗&#xff0c;扫雷游戏背后的原理是什么呢&#xff1f;今天就让我们一探究竟&#xff01; 扫雷游戏介绍 如下图&#xff0c;简…

使用3080ti配置安装blip2

使用3080ti运行blip2的案例 本机环境&#xff08;大家主要看GPU&#xff0c;ubuntu版本和cuda版本即可&#xff09;&#xff1a;安装流程我最后安装的所有包的信息&#xff08;python 3.9 &#xff09;以供参考&#xff08;environment.yml&#xff09;&#xff1a; 本机环境&a…

【python实操】python小程序之计算对象个数、游戏更新分数

引言 python小程序之计算对象个数、游戏更新分数 文章目录 引言一、计算对象个数1.1 题目1.2 代码1.3 代码解释1.3.1 代码结构1.3.2 模块解释1.3.3 解释输出 二、游戏更新分数2.1 题目2.2 代码2.3 代码解释2.3.1 定义 Game 类2.3.2 创建 Game 实例并调用方法 三、思考3.1 计算对…

安卓13禁止锁屏 关闭锁屏 android13禁止锁屏 关闭锁屏

总纲 android13 rom 开发总纲说明 文章目录 1.前言2.问题分析3.代码分析4.代码修改5.彩蛋1.前言 设置 =》安全 =》屏幕锁定 =》 无。 我们通过修改系统屏幕锁定配置,来达到设置屏幕不锁屏的配置。像网上好多文章都只写了在哪里改,改什么东西,但是实际上并未写明为什么要改那…

RabbitMQ 高级特性——死信队列

文章目录 前言死信队列什么是死信常见面试题死信队列的概念&#xff1a;死信的来源&#xff08;造成死信的原因有哪些&#xff09;死信队列的应用场景 前言 前面我们学习了为消息和队列设置 TTL 过期时间&#xff0c;这样可以保证消息的积压&#xff0c;那么对于这些过期了的消…

数据结构-4.6.KMP算法(旧版下)-朴素模式匹配算法的优化

一.绪论&#xff1a; 当主串字符和模式串字符不匹配时会执行jnext[j]来改变模式串的指针&#xff0c;但主串的指针不变。 二.求模式串的next数组&#xff1a; 1.例一&#xff1a; 如模式串abcabd&#xff0c;当第六个字符d匹配失败时&#xff0c;此时主串中前五个字符abcab都…

连锁店线下线上一体化收银系统源码

近年来线下线上一体化已经成为很多连锁门店追求的方向。其中&#xff0c;线下门店能够赋予品牌发展的价值依然不可小觑。在线下门店中&#xff0c;收银系统可以说是运营管理的关键工具&#xff0c;好的收银系统能够为品牌门店赋能。对于连锁品牌而言&#xff0c;对收银系统的要…

软媒市场新蓝海:软文媒体自助发布与自助发稿的崛起

在信息时代的浪潮中,软媒市场以其独特的魅力和无限的潜力,成为了企业营销的新宠。随着互联网的飞速发展,软文媒体自助发布平台应运而生,为企业提供了更加高效、便捷的营销方式。而自助发稿功能的加入,更是让软媒市场的蓝海变得更加广阔。 软媒市场的独特价值 软媒市场之所以能…

Android Studio Koala中Kotlin引入序列化Parcelable

找了一堆资料没有新构建序列化的方法&#xff0c;踩坑经历如下&#xff1a; 前提是使用Kotlin创建的项目 之前的build.gradle版本写法如下&#xff1a; 但是新版Android Studio Koala使用序列化模式发生了改变&#xff0c;如下&#xff1a; 测试成功如下&#xff1a; 发出来…

【万字长文】Word2Vec计算详解(三)分层Softmax与负采样

【万字长文】Word2Vec计算详解&#xff08;三&#xff09;分层Softmax与负采样 写在前面 第三部分介绍Word2Vec模型的两种优化方案。 【万字长文】Word2Vec计算详解&#xff08;一&#xff09;CBOW模型 markdown行 9000 【万字长文】Word2Vec计算详解&#xff08;二&#xff0…
最新文章