【自然语言处理(NLP)】Bahdanau 注意力(Bahdanau Attention)原理及代码实现

文章目录

个人主页: 道友老李
欢迎加入社区: 道友老李的学习社区

介绍

**自然语言处理(Natural Language Processing,NLP)**是计算机科学领域与人工智能领域中的一个重要方向。它研究的是人类(自然)语言与计算机之间的交互。NLP的目标是让计算机能够理解、解析、生成人类语言,并且能够以有意义的方式回应和操作这些信息。

NLP的任务可以分为多个层次,包括但不限于:

  1. 词法分析:将文本分解成单词或标记(token),并识别它们的词性(如名词、动词等)。
  2. 句法分析:分析句子结构,理解句子中词语的关系,比如主语、谓语、宾语等。
  3. 语义分析:试图理解句子的实际含义,超越字面意义,捕捉隐含的信息。
  4. 语用分析:考虑上下文和对话背景,理解话语在特定情境下的使用目的。
  5. 情感分析:检测文本中表达的情感倾向,例如正面、负面或中立。
  6. 机器翻译:将一种自然语言转换为另一种自然语言。
  7. 问答系统:构建可以回答用户问题的系统。
  8. 文本摘要:从大量文本中提取关键信息,生成简短的摘要。
  9. 命名实体识别(NER):识别文本中提到的特定实体,如人名、地名、组织名等。
  10. 语音识别:将人类的语音转换为计算机可读的文字格式。

NLP技术的发展依赖于算法的进步、计算能力的提升以及大规模标注数据集的可用性。近年来,深度学习方法,特别是基于神经网络的语言模型,如BERT、GPT系列等,在许多NLP任务上取得了显著的成功。随着技术的进步,NLP正在被应用到越来越多的领域,包括客户服务、智能搜索、内容推荐、医疗健康等。

Bahdanau 注意力(Bahdanau Attention)

Bahdanau注意力(Bahdanau Attention)是自然语言处理中一种经典的注意力机制。

在传统的编码器 - 解码器架构(如基于RNN的架构)中,编码器将整个输入序列编码为一个固定长度的向量,解码器依赖该向量生成输出。当输入序列较长时,这种固定长度向量难以存储所有重要信息,导致性能下降。Bahdanau注意力机制通过让解码器在生成每个输出时,动态关注输入序列不同部分,解决此问题。

原理

允许解码器在生成输出时,根据当前状态,从编码器的隐藏状态序列中选择性聚焦,获取与当前生成任务最相关信息,而非仅依赖单一固定向量。

Bahdanau 注意力机制中计算上下文向量的公式:

公式含义

  • c t c_t ct 表示在解码器的时间步 t t t 时得到的上下文向量,它综合了编码器隐藏状态序列中的信息,用于辅助解码器在该时间步生成输出。
  • T T T 是编码器的时间步总数,意味着要考虑编码器所有时间步的隐藏状态。
  • α ( s t − 1 , h i ) \alpha(s_{t - 1}, h_i) α(st−1,hi) 是注意力权重,它表示在解码器时间步 t − 1 t - 1 t−1 的隐藏状态 s t − 1 s_{t - 1} st−1 条件下,对编码器第 i i i 个时间步隐藏状态 h i h_i hi 的关注程度。这个权重是通过一个特定的计算(通常涉及一个小型神经网络来计算相似度等)得到,并经过softmax函数归一化,取值范围在 0 0 0 到 1 1 1 之间,且 ∑ i = 1 T α ( s t − 1 , h i ) = 1 \sum_{i = 1}^{T}\alpha(s_{t - 1}, h_i)=1 ∑i=1Tα(st−1,hi)=1。
  • h i h_i hi 是编码器在第 i i i 个时间步的隐藏状态,它包含了输入序列在该时间步及之前的信息。

计算过程

  • 首先,根据解码器上一个时间步的隐藏状态 s t − 1 s_{t - 1} st−1 和编码器所有时间步的隐藏状态 h i h_i hi( i i i 从 1 1 1 到 T T T),计算出每个 h i h_i hi 对应的注意力权重 α ( s t − 1 , h i ) \alpha(s_{t - 1}, h_i) α(st−1,hi)。
  • 然后,将这些注意力权重分别与对应的编码器隐藏状态 h i h_i hi 相乘,并对所有时间步的乘积结果进行求和,就得到了当前解码器时间步 t t t 的上下文向量 c t c_t ct。

这个上下文向量 c t c_t ct 后续会与解码器当前时间步 t t t 的隐藏状态等信息结合,用于生成当前时间步的输出,比如在机器翻译任务中预测目标语言的下一个单词。

一个带有Bahdanau注意力的循环神经网络编码器-解码器模型:

编码器部分

  • 嵌入层:将源序列(如源语言句子中的单词)从离散的符号转换为低维、连续的向量表示,即词嵌入,便于模型后续处理,同时捕捉单词语义关系。
  • 循环层:一般由RNN、LSTM或GRU等单元构成。按顺序处理嵌入层输出的向量序列,每个时间步结合当前输入和上一时刻隐藏状态更新隐藏状态,逐步将源序列信息编码到隐藏状态中,最终输出包含源序列语义信息的隐藏状态序列。

注意力机制部分

位于编码器和解码器之间,允许解码器在生成输出时,根据当前状态从编码器的隐藏状态序列中动态选择相关信息。它计算解码器当前隐藏状态与编码器各时间步隐藏状态的相关性,得到注意力权重,对编码器隐藏状态加权求和生成上下文向量,为解码器提供与当前生成任务相关的信息。

解码器部分

  • 嵌入层:与编码器的嵌入层类似,将目标序列(如目标语言句子中的单词)的离散符号转换为向量表示,不过针对目标语言。
  • 循环层:接收编码器输出的隐藏状态序列以及注意力机制生成的上下文向量,结合目标序列嵌入向量,按顺序处理并更新隐藏状态,生成目标序列下一个元素的预测。
  • 全连接层:对循环层输出进行处理,将隐藏状态映射到目标词汇表维度,经softmax函数计算词汇表中每个单词的概率分布,预测当前时间步最可能的输出单词。

该架构在机器翻译、文本摘要等序列到序列任务中应用广泛,注意力机制可有效解决长序列信息处理难题,提升模型性能。

计算过程

  1. 计算注意力分数 :解码器在时间步 t t t的隐藏状态 h t d e c h_t^{dec} htdec作为查询(query),与编码器所有时间步的隐藏状态 h i e n c h_i^{enc} hienc( i = 1 , ⋯   , T i = 1, \cdots, T i=1,⋯,T, T T T为编码器时间步数)计算注意力分数 e t , i e_{t,i} et,i,一般通过一个小型神经网络计算,如 e t , i = a ( h t d e c , h i e n c ) e_{t,i}=a(h_t^{dec}, h_i^{enc}) et,i=a(htdec,hienc), a a a是一个非线性函数。
  2. 归一化注意力分数 :将注意力分数 e t , i e_{t,i} et,i通过softmax函数归一化,得到注意力权重 α t , i \alpha_{t,i} αt,i,即 α t , i = exp ⁡ ( e t , i ) ∑ j = 1 T exp ⁡ ( e t , j ) \alpha_{t,i}=\frac{\exp(e_{t,i})}{\sum_{j = 1}^{T}\exp(e_{t,j})} αt,i=∑j=1Texp(et,j)exp(et,i),表示编码器第 i i i个时间步对解码器当前时间步 t t t的重要程度。
  3. 计算上下文向量 :根据注意力权重对编码器隐藏状态加权求和,得到上下文向量 c t c_t ct, c t = ∑ i = 1 T α t , i h i e n c c_t=\sum_{i = 1}^{T}\alpha_{t,i}h_i^{enc} ct=∑i=1Tαt,ihienc,它包含了与当前生成任务相关的输入信息。
  4. 生成输出 :上下文向量 c t c_t ct与解码器当前隐藏状态 h t d e c h_t^{dec} htdec结合,如拼接后输入到后续网络层,生成当前时间步的输出。

代码实现

导包

python 复制代码
import torch
from torch import nn
import dltools

定义注意力解码器

python 复制代码
class AttentionDecoder(dltools.Decoder):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        
    @property
    def attention_weights(self):
        raise NotImplementedError

添加Bahdanau的decoder

python 复制代码
class Seq2SeqAttentionDecoder(AttentionDecoder):
    def __init__(self, vocab_size, embed_size, num_hiddens, num_layers, dropout=0, **kwargs):
        super().__init__(**kwargs)
        self.attention = dltools.AdditiveAttention(num_hiddens, num_hiddens, num_hiddens, dropout)
        self.embedding = nn.Embedding(vocab_size, embed_size)
        self.rnn = nn.GRU(embed_size + num_hiddens, num_hiddens, num_layers, dropout=dropout)
        self.dense = nn.Linear(num_hiddens, vocab_size)
        
    def init_state(self, enc_outputs, enc_valid_lens, *args):
        # outputs : (batch_size, num_steps, num_hiddens)
        # hidden_state: (num_layers, batch_size, num_hiddens)
        outputs, hidden_state = enc_outputs
        return (outputs.permute(1, 0, 2), hidden_state, enc_valid_lens)
    
    def forward(self, X, state):
        # enc_outputs (batch_size, num_steps, num_hiddens)
        # hidden_state: (num_layers, batch_size, num_hiddens)
        enc_outputs, hidden_state, enc_valid_lens = state
        # X : (batch_size, num_steps, vocab_size)
        X = self.embedding(X) # X : (batch_size, num_steps, embed_size)
        X = X.permute(1, 0, 2)
        outputs, self._attention_weights = [], []
        
        for x in X:
            query = torch.unsqueeze(hidden_state[-1], dim=1) # batch_size, 1, num_hiddens
            # print('query:', query.shape) # 4, 1, 16
            context = self.attention(query, enc_outputs, enc_outputs, enc_valid_lens)
            # print('context: ', context.shape)
            x = torch.cat((context, torch.unsqueeze(x, dim=1)), dim=-1)
            # print('x: ', x.shape)
            out, hidden_state = self.rnn(x.permute(1, 0, 2), hidden_state)
            # print('out:', out.shape)
            # print('hidden_state:', hidden_state.shape)
            outputs.append(out)
            self._attention_weights.append(self.attention_weights)
            
        # print('---------------------------------')
        outputs = self.dense(torch.cat(outputs, dim=0))
        # print('解码器最终输出形状: ', outputs.shape)
        return outputs.permute(1, 0, 2), [enc_outputs, hidden_state, enc_valid_lens]
    
    @property
    def attention_weights(self):
        return self._attention_weights

encoder = dltools.Seq2SeqEncoder(vocab_size=10, embed_size=8, num_hiddens=16, num_layers=2)
encoder.eval()
decoder = Seq2SeqAttentionDecoder(vocab_size=10, embed_size=8, num_hiddens=16, num_layers=2)
decoder.eval()

# batch_size 4, num_steps 7
X = torch.zeros((4, 7), dtype=torch.long)
state = decoder.init_state(encoder(X), None)
output, state = decoder(X, state)
output.shape, len(state), state[0].shape, len(state[1]), state[1][0].shape
query: torch.Size([4, 1, 16])
context:  torch.Size([4, 1, 16])
x:  torch.Size([4, 1, 24])
out: torch.Size([1, 4, 16])
hidden_state: torch.Size([2, 4, 16])
query: torch.Size([4, 1, 16])
context:  torch.Size([4, 1, 16])
x:  torch.Size([4, 1, 24])
out: torch.Size([1, 4, 16])
hidden_state: torch.Size([2, 4, 16])
query: torch.Size([4, 1, 16])
context:  torch.Size([4, 1, 16])
x:  torch.Size([4, 1, 24])
out: torch.Size([1, 4, 16])
hidden_state: torch.Size([2, 4, 16])
query: torch.Size([4, 1, 16])
context:  torch.Size([4, 1, 16])
x:  torch.Size([4, 1, 24])
out: torch.Size([1, 4, 16])
hidden_state: torch.Size([2, 4, 16])
query: torch.Size([4, 1, 16])
context:  torch.Size([4, 1, 16])
x:  torch.Size([4, 1, 24])
out: torch.Size([1, 4, 16])
hidden_state: torch.Size([2, 4, 16])
query: torch.Size([4, 1, 16])
context:  torch.Size([4, 1, 16])
x:  torch.Size([4, 1, 24])
out: torch.Size([1, 4, 16])
hidden_state: torch.Size([2, 4, 16])
query: torch.Size([4, 1, 16])
context:  torch.Size([4, 1, 16])
x:  torch.Size([4, 1, 24])
out: torch.Size([1, 4, 16])
hidden_state: torch.Size([2, 4, 16])
---------------------------------
解码器最终输出形状:  torch.Size([7, 4, 10])
(torch.Size([4, 7, 10]), 3, torch.Size([4, 7, 16]), 2, torch.Size([4, 16]))

训练

执行训练前,将decoder中的print屏蔽掉!!

python 复制代码
embed_size, num_hiddens, num_layers, dropout = 32, 32, 2, 0.1
batch_size, num_steps = 64, 10
lr, num_epochs, device = 0.005, 200, dltools.try_gpu()

train_iter, src_vocab, tgt_vocab = dltools.load_data_nmt(batch_size, num_steps)
encoder = dltools.Seq2SeqEncoder(len(src_vocab), embed_size, num_hiddens, num_layers, dropout)
decoder = Seq2SeqAttentionDecoder(len(tgt_vocab), embed_size, num_hiddens, num_layers, dropout)
net = dltools.EncoderDecoder(encoder, decoder)
dltools.train_seq2seq(net, train_iter, lr, num_epochs, tgt_vocab, device)

评估指标 bleu

python 复制代码
def bleu(pred_seq, label_seq, k):
    print('pred_seq', pred_seq)
    print('label_seq:', label_seq)
    pred_tokens, label_tokens = pred_seq.split(' '), label_seq.split(' ')
    len_pred, len_label = len(pred_tokens), len(label_tokens)
    score = math.exp(min(0, 1 - (len_label / len_pred)))
    for n in range(1, k + 1):
        num_matches, label_subs = 0, collections.defaultdict(int)
        for i in range(len_label - n + 1):
            label_subs[' '.join(label_tokens[i: i + n])] += 1
            
        for i in range(len_pred - n + 1):
            if label_subs[' '.join(pred_tokens[i: i + n])] > 0:
                num_matches += 1
                label_subs[' '.join(pred_tokens[i: i + n])] -= 1
        score *=  math.pow(num_matches / (len_pred - n + 1), math.pow(0.5, n))   
    return score

开始预测

python 复制代码
engs = ['go .', 'i lost .', 'he\'s calm .', 'i\'m home .']
fras = ['va !', 'j\'ai perdu .', 'il est calme .', 'je suis chez moi .']
for eng, fra in zip(engs, fras):
    translation = dltools.predict_seq2seq(net, eng, src_vocab, tgt_vocab, num_steps, device)
    print(f'{eng} => {translation}, bleu {dltools.bleu(translation[0], fra, k=2):.3f}')
go . => ('va !', []), bleu 1.000
i lost . => ("j'ai perdu .", []), bleu 1.000
he's calm . => ('il est malade .', []), bleu 0.658
i'm home . => ('je suis chez moi .', []), bleu 1.000
相关推荐
soonlyai5 分钟前
解决DeepSeek服务器繁忙问题:本地部署与优化方案
服务器·人工智能·经验分享·笔记·微信公众平台·媒体
☆cwlulu17 分钟前
深度学习练手小例子——cifar10数据集分类问题
人工智能·深度学习·分类
逛逛GitHub33 分钟前
斩获 66K 星!这 2 个开源项目绝了。
人工智能·后端·github
王大队长37 分钟前
error: package directory ‘torch/cuda‘ does not exist
人工智能·pytorch·深度学习
追光天使1 小时前
Mac M1 ComfyUI 中 AnyText插件安装问题汇总?
人工智能·pytorch·macos
清同趣科研1 小时前
R分析|稀有or丰富,群落物种六级分类鉴别稀有和丰富物种:Excel中简单实现
人工智能·分类·r语言
说私域1 小时前
开源2 + 1链动模式AI智能名片S2B2C商城小程序视角下从产品经营到会员经营的转型探究
人工智能·小程序·开源·流量运营
一个处女座的程序猿O(∩_∩)O1 小时前
React+AI 技术栈(2025 版)
前端·人工智能·react.js
正在走向自律2 小时前
AI绘画:解锁商业设计新宇宙(6/10)
人工智能·ai作画·ai绘画
htuhxf2 小时前
TfidfVectorizer
python·自然语言处理·nlp·tf-idf·文本特征