07.RNN模型简介(传统RNN、LSTM、GRU)
2026-01-03 15:25:04 拉莫斯世界杯
LSTM 模型
学习目标
了解 LSTM 内部结构及计算公式
掌握 Pytorch 中 LSTM 工具的使用
了解 LSTM 的优势与缺点
LSTM 基本概念
LSTM(Long Short-Term Memory,长短时记忆结构)是传统 RNN 的变体,核心优势是能有效捕捉长序列的语义关联,缓解梯度消失或爆炸现象。其结构更复杂,核心可分为四部分:遗忘门、输入门、细胞状态、输出门。
LSTM 内部结构及计算公式
1. 遗忘门
遗忘门结构分析:与传统RNN的内部结构计算非常相似,首先将当前时间步输入x(t)与上一个时间步隐含状态h(t-1)拼接,得到[x(t),h(t-1)],然后通过一个全连接层做变换,最后通过sigmoid函数进行激活得到f(t),我们可以将f(t)看作是门值,好比一扇门开合的大小程度,门值都将作用在通过该扇门的张量,遗忘门门值将作用的上一层的细胞状态上,代表遗忘过去的多少信息,又因为遗忘门门值是由x(t),h(t-1)计算得来的,因此整个公式意味着根据当前时间步输入和上一个时间步隐含状态h(t-1)来决定遗忘多少上一层的细胞状态所携带的过往信息。
类比解释(先建立直觉)
想象你在决定要忘记多少童年记忆:
σ函数:相当于你的"遗忘开关",但它自己不能做决定,只是把决定压缩到0-1之间(1=完全记住,0=完全忘记)
W_f·[h_{t-1}, x_t] + b_f:这才是真正的"决策委员会":
x_t(当前输入):相当于你现在看到的新信息(如看到老照片)
h_{t-1}(上一隐藏状态):相当于你当前的记忆状态(如最近常回忆过去)
权重W_f和偏置b_f:相当于你的性格倾向(天生健忘还是念旧)
σ只是执行者,而[h_{t-1}, x_t]才是真正的决策依据!
结构分析:
输入:前一时刻隐含状态ht−1 + 当前输入xt(拼接)
通过全连接层 + sigmoid激活
输出门值ft(范围[0,1])
功能:
决定遗忘多少上一细胞状态Ct−1的信息
门值大小反映信息保留程度
激活函数 sigmoid 作用:将值压缩在 0~1 之间,调节信息通过量
2. 输入门
输入门结构分析:我们看到输入门的计算公式有两个,第一个就是产生输入门门值的公式,它和遗忘门公式几乎相同,区别只是在于它们之后要作用的目标上。这个公式意味着输入信息有多少需要进行过滤。输入门的第二个公式是与传统RNN的内部结构计算相同。对于LSTM来讲,它得到的是当前的细胞状态,而不是像经典RNN一样得到的是隐含状态。
计算公式:
门值:
候选细胞状态:
结构分析:
门值计算与遗忘门结构相同(不同参数)
候选状态计算类似传统RNN(使用tanh激活)
功能:
筛选当前输入中有价值的信息
与遗忘门协同更新细胞状态
3. 细胞状态更新
细胞状态更新分析:细胞更新的结构与计算公式非常容易理解,这里没有全连接层,只是将刚刚得到的遗忘门门值与上一个时间步得到的C(t-1)相乘,再加上输入门门值与当前时间步得到的未更新C(t)相乘的结果.最终得到更新后的C(t)作为下一个时间步输入的一部分.整个细胞状态更新过程就是对遗忘门和输入门的应用.
更新逻辑:
遗忘部分旧状态(ft⊙Ct−1)
添加部分新信息
特点:
线性操作(无激活函数)
实现长期记忆的保留与更新
更新分析:
无全连接层,直接通过遗忘门和输入门作用于历史状态和新信息
本质是 "遗忘部分历史信息 + 保留部分新信息" 的过程,更新后的Ct作为下一时间步输入
4. 输出门
输出门结构分析:输出门部分的公式也是两个,第一个即是计算输出门的门值,它和遗忘门,输入门计算方式相同。第二个即是使用这个门值产生隐含状态h(t),他将作用在更新后的细胞状态C(t)上,并做tanh激活,最终得到h(t)作为下一时间步输入的一部分,整个输出门的过程,就是为了产生隐含状态h(t)。
计算公式:
门值:
隐状态:
结构分析:
门值ot(范围 0~1)控制细胞状态的输出比例
隐状态ht由门值与tanh(Ct)(范围 - 1~1)相乘得到,作为下一时间步输入
核心作用是生成当前时间步的隐状态ht
功能:
控制当前细胞状态的暴露程度
生成最终隐含状态供下一时间步使用
Bi-LSTM(双向 LSTM)
原理:
不改变 LSTM 内部结构
同时运行正向和反向LSTM
拼接两次计算的结果作为最终输出
Bi-LSTM结构分析:我们看到图中对“我爱中国”这句话或者叫这个输入序列,进行了从左到右和从右到左两次LSTM处理,将得到的结果张量进行了拼接作为最终输出。这种结构能够捕捉语言语法中一些特定的前置或后置特征,增强语义关联,但是模型参数和计算复杂度也随之增加了一倍,一般需要对语料和计算资源进行评估后决定是否使用该结构。
优势:
捕捉前后文语境特征
增强语义关联理解
代价:
参数量和计算量翻倍
需根据语料和算力评估计算资源是否充足
Pytorch 中 LSTM 的使用
LSTM初始化参数(创建模型时配置)
lstm = nn.LSTM( input_size=5, hidden_size=6, num_layers=2, bidirectional=False )
参数名
作用说明
示例值
注意事项
input_size
输入数据的特征维度(如词向量维度)
5
必须与输入数据最后一维一致
hidden_size
隐藏状态的维度(记忆容量)
6
决定模型记忆能力大小
num_layers
堆叠的LSTM层数
2
层数越多模型越复杂
bidirectional
是否为双向LSTM
False
双向时会加倍输出维度
前向传播参数(调用模型时输入)
output, (hn, cn) = lstm(input, (h0, c0))
输入参数:
参数
形状说明
示例形状
作用
input
(seq_len, batch_size, input_size)
(10, 3, 5)
输入序列(10步,batch=3)
h0
(num_layers*方向数, batch, hidden_size)
(2, 3, 6)
初始隐藏状态(如全零初始化)
c0
同h0形状
(2, 3, 6)
初始细胞状态
输出参数:
输出项
形状说明
示例输出形状
含意
output
(seq_len, batch, hidden_size*方向数)
(10, 3, 6)
所有时间步的隐藏状态
hn
同h0形状
(2, 3, 6)
最后一个时间步的隐藏状态
cn
同c0形状
(2, 3, 6)
最后一个时间步的细胞状态
关键概念图解
代码示例
import torch.nn as nn
import torch
# 定义LSTM:input_size=5,hidden_size=6,num_layers=2
rnn = nn.LSTM(5, 6, 2)
# 输入张量:sequence_length=1,batch_size=3,input_size=5
input = torch.randn(1, 3, 5)
# 初始化隐状态和细胞状态:num_layers×num_directions=2,batch_size=3,hidden_size=6
h0 = torch.randn(2, 3, 6)
c0 = torch.randn(2, 3, 6)
# 前向传播
output, (hn, cn) = rnn(input, (h0, c0))
# 输出结果
print("output:", output)
print("最后隐状态hn:", hn)
print("最后细胞状态cn:", cn)
查看打印结果
output:
tensor([[[-0.5387, 0.3468, 0.4146, 0.2536, -0.2615, 0.0667],
[-0.0072, -0.1814, 0.0936, 0.1356, -0.4559, 0.1720],
[-0.2224, 0.2977, 0.4252, -0.2934, 0.0819, 0.1203]]],
grad_fn=
最后隐状态hn:
tensor([[[-0.1427, -0.2920, -0.1963, -0.1874, -0.0126, -0.5496],
[ 0.0570, 0.1978, -0.1338, 0.1485, 0.1445, -0.4599],
[ 0.3210, -0.0794, -0.1845, 0.0285, 0.3966, 0.1899]],
[[-0.5387, 0.3468, 0.4146, 0.2536, -0.2615, 0.0667],
[-0.0072, -0.1814, 0.0936, 0.1356, -0.4559, 0.1720],
[-0.2224, 0.2977, 0.4252, -0.2934, 0.0819, 0.1203]]],
grad_fn=
最后细胞状态cn:
tensor([[[-0.2624, -0.5144, -0.5639, -0.5489, -0.0518, -0.8765],
[ 0.6077, 0.2968, -0.5591, 0.2557, 0.2518, -0.9615],
[ 1.1608, -0.5608, -0.4439, 0.0528, 0.6316, 0.4187]],
[[-0.8019, 0.8717, 0.7672, 0.5314, -0.5260, 0.0815],
[-0.0245, -0.7748, 0.3983, 0.3183, -0.9727, 0.3557],
[-0.4340, 0.4418, 0.7524, -0.8154, 0.2270, 0.1626]]],
grad_fn=
关键细节补充
激活函数选择:
门控使用sigmoid(输出[0,1])
状态计算使用tanh(输出[-1,1])
细胞状态特点:
贯穿整个时间序列
通过线性变换实现信息传递
现代变体:
Peephole LSTM(增加细胞状态到门控的连接)
GRU(简化版LSTM,合并门控)
LSTM 的优缺点
优势:通过门结构有效减缓长序列中的梯度消失 / 爆炸,在长序列任务上表现优于传统 RNN
缺点:内部结构复杂,同等算力下训练效率低于传统 RNN