07.RNN模型简介(传统RNN、LSTM、GRU)

2026-01-03 15:25:04 拉莫斯世界杯

LSTM 模型

学习目标

了解 LSTM 内部结构及计算公式

掌握 Pytorch 中 LSTM 工具的使用

了解 LSTM 的优势与缺点

LSTM 基本概念

LSTM(Long Short-Term Memory,长短时记忆结构)是传统 RNN 的变体,核心优势是能有效捕捉长序列的语义关联,缓解梯度消失或爆炸现象。其结构更复杂,核心可分为四部分:遗忘门、输入门、细胞状态、输出门。

LSTM 内部结构及计算公式

1. 遗忘门

遗忘门结构分析:与传统RNN的内部结构计算非常相似,首先将当前时间步输入x(t)与上一个时间步隐含状态h(t-1)拼接,得到[x(t),h(t-1)],然后通过一个全连接层做变换,最后通过sigmoid函数进行激活得到f(t),我们可以将f(t)看作是门值,好比一扇门开合的大小程度,门值都将作用在通过该扇门的张量,遗忘门门值将作用的上一层的细胞状态上,代表遗忘过去的多少信息,又因为遗忘门门值是由x(t),h(t-1)计算得来的,因此整个公式意味着根据当前时间步输入和上一个时间步隐含状态h(t-1)来决定遗忘多少上一层的细胞状态所携带的过往信息。

类比解释(先建立直觉)

想象你在决定要忘记多少童年记忆:

​​σ函数​​:相当于你的"遗忘开关",但​​它自己不能做决定​​,只是把决定压缩到0-1之间(1=完全记住,0=完全忘记)

​​W_f·[h_{t-1}, x_t] + b_f​​:这才是真正的"决策委员会":

x_t(当前输入):相当于​​你现在看到的新信息​​(如看到老照片)

h_{t-1}(上一隐藏状态):相当于​​你当前的记忆状态​​(如最近常回忆过去)

权重W_f和偏置b_f:相当于​​你的性格倾向​​(天生健忘还是念旧)

​​σ只是执行者​​,而[h_{t-1}, x_t]才是真正的决策依据!

结构分析​​:

输入:前一时刻隐含状态ht−1​ + 当前输入xt​(拼接)

通过全连接层 + sigmoid激活

输出门值ft​(范围[0,1])

​​功能​​:

决定遗忘多少上一细胞状态Ct−1​的信息

门值大小反映信息保留程度

激活函数 sigmoid 作用:将值压缩在 0~1 之间,调节信息通过量

2. 输入门

输入门结构分析:我们看到输入门的计算公式有两个,第一个就是产生输入门门值的公式,它和遗忘门公式几乎相同,区别只是在于它们之后要作用的目标上。这个公式意味着输入信息有多少需要进行过滤。输入门的第二个公式是与传统RNN的内部结构计算相同。对于LSTM来讲,它得到的是当前的细胞状态,而不是像经典RNN一样得到的是隐含状态。

计算公式:

门值:

候选细胞状态:

结构分析​​:

门值计算与遗忘门结构相同(不同参数)

候选状态计算类似传统RNN(使用tanh激活)

​​功能​​:

筛选当前输入中有价值的信息

与遗忘门协同更新细胞状态

3. 细胞状态更新

细胞状态更新分析:细胞更新的结构与计算公式非常容易理解,这里没有全连接层,只是将刚刚得到的遗忘门门值与上一个时间步得到的C(t-1)相乘,再加上输入门门值与当前时间步得到的未更新C(t)相乘的结果.最终得到更新后的C(t)作为下一个时间步输入的一部分.整个细胞状态更新过程就是对遗忘门和输入门的应用.

更新逻辑​​:

遗忘部分旧状态(ft​⊙Ct−1​)

添加部分新信息

​​特点​​:

线性操作(无激活函数)

实现长期记忆的保留与更新

更新分析:

无全连接层,直接通过遗忘门和输入门作用于历史状态和新信息

本质是 "遗忘部分历史信息 + 保留部分新信息" 的过程,更新后的Ct作为下一时间步输入

4. 输出门

输出门结构分析:输出门部分的公式也是两个,第一个即是计算输出门的门值,它和遗忘门,输入门计算方式相同。第二个即是使用这个门值产生隐含状态h(t),他将作用在更新后的细胞状态C(t)上,并做tanh激活,最终得到h(t)作为下一时间步输入的一部分,整个输出门的过程,就是为了产生隐含状态h(t)。

计算公式:

门值:

隐状态:

结构分析:

门值ot(范围 0~1)控制细胞状态的输出比例

隐状态ht由门值与tanh(Ct)(范围 - 1~1)相乘得到,作为下一时间步输入

核心作用是生成当前时间步的隐状态ht

​​功能​​:

控制当前细胞状态的暴露程度

生成最终隐含状态供下一时间步使用

Bi-LSTM(双向 LSTM)

原理​​:

不改变 LSTM 内部结构

同时运行正向和反向LSTM

拼接两次计算的结果作为最终输出

Bi-LSTM结构分析:我们看到图中对“我爱中国”这句话或者叫这个输入序列,进行了从左到右和从右到左两次LSTM处理,将得到的结果张量进行了拼接作为最终输出。这种结构能够捕捉语言语法中一些特定的前置或后置特征,增强语义关联,但是模型参数和计算复杂度也随之增加了一倍,一般需要对语料和计算资源进行评估后决定是否使用该结构。

优势​​:

捕捉前后文语境特征

增强语义关联理解

​​代价​​:

参数量和计算量翻倍

需根据语料和算力评估计算资源是否充足

Pytorch 中 LSTM 的使用

LSTM初始化参数(创建模型时配置)​​

lstm = nn.LSTM( input_size=5, hidden_size=6, num_layers=2, bidirectional=False )

参数名

作用说明

示例值

注意事项

input_size

输入数据的特征维度(如词向量维度)

5

必须与输入数据最后一维一致

hidden_size

隐藏状态的维度(记忆容量)

6

决定模型记忆能力大小

num_layers

堆叠的LSTM层数

2

层数越多模型越复杂

bidirectional

是否为双向LSTM

False

双向时会加倍输出维度

​​前向传播参数(调用模型时输入)​​

output, (hn, cn) = lstm(input, (h0, c0))

输入参数:

参数

形状说明

示例形状

作用

input

(seq_len, batch_size, input_size)

(10, 3, 5)

输入序列(10步,batch=3)

h0

(num_layers*方向数, batch, hidden_size)

(2, 3, 6)

初始隐藏状态(如全零初始化)

c0

同h0形状

(2, 3, 6)

初始细胞状态

输出参数:

输出项

形状说明

示例输出形状

含意

output

(seq_len, batch, hidden_size*方向数)

(10, 3, 6)

所有时间步的隐藏状态

hn

同h0形状

(2, 3, 6)

最后一个时间步的隐藏状态

cn

同c0形状

(2, 3, 6)

最后一个时间步的细胞状态

​​关键概念图解​

代码示例

import torch.nn as nn

import torch

# 定义LSTM:input_size=5,hidden_size=6,num_layers=2

rnn = nn.LSTM(5, 6, 2)

# 输入张量:sequence_length=1,batch_size=3,input_size=5

input = torch.randn(1, 3, 5)

# 初始化隐状态和细胞状态:num_layers×num_directions=2,batch_size=3,hidden_size=6

h0 = torch.randn(2, 3, 6)

c0 = torch.randn(2, 3, 6)

# 前向传播

output, (hn, cn) = rnn(input, (h0, c0))

# 输出结果

print("output:", output)

print("最后隐状态hn:", hn)

print("最后细胞状态cn:", cn)

查看打印结果

output:

tensor([[[-0.5387, 0.3468, 0.4146, 0.2536, -0.2615, 0.0667],

[-0.0072, -0.1814, 0.0936, 0.1356, -0.4559, 0.1720],

[-0.2224, 0.2977, 0.4252, -0.2934, 0.0819, 0.1203]]],

grad_fn=)

最后隐状态hn:

tensor([[[-0.1427, -0.2920, -0.1963, -0.1874, -0.0126, -0.5496],

[ 0.0570, 0.1978, -0.1338, 0.1485, 0.1445, -0.4599],

[ 0.3210, -0.0794, -0.1845, 0.0285, 0.3966, 0.1899]],

[[-0.5387, 0.3468, 0.4146, 0.2536, -0.2615, 0.0667],

[-0.0072, -0.1814, 0.0936, 0.1356, -0.4559, 0.1720],

[-0.2224, 0.2977, 0.4252, -0.2934, 0.0819, 0.1203]]],

grad_fn=)

最后细胞状态cn:

tensor([[[-0.2624, -0.5144, -0.5639, -0.5489, -0.0518, -0.8765],

[ 0.6077, 0.2968, -0.5591, 0.2557, 0.2518, -0.9615],

[ 1.1608, -0.5608, -0.4439, 0.0528, 0.6316, 0.4187]],

[[-0.8019, 0.8717, 0.7672, 0.5314, -0.5260, 0.0815],

[-0.0245, -0.7748, 0.3983, 0.3183, -0.9727, 0.3557],

[-0.4340, 0.4418, 0.7524, -0.8154, 0.2270, 0.1626]]],

grad_fn=)

关键细节补充

​​激活函数选择​​:

门控使用sigmoid(输出[0,1])

状态计算使用tanh(输出[-1,1])

​​细胞状态特点​​:

贯穿整个时间序列

通过线性变换实现信息传递

​​现代变体​​:

Peephole LSTM(增加细胞状态到门控的连接)

GRU(简化版LSTM,合并门控)

LSTM 的优缺点

优势:通过门结构有效减缓长序列中的梯度消失 / 爆炸,在长序列任务上表现优于传统 RNN

缺点:内部结构复杂,同等算力下训练效率低于传统 RNN