【自然语言处理(NLP)】多头注意力(Multi - Head Attention)原理及代码实现

文章目录

个人主页: 道友老李
欢迎加入社区: 道友老李的学习社区

介绍

**自然语言处理(Natural Language Processing,NLP)**是计算机科学领域与人工智能领域中的一个重要方向。它研究的是人类(自然)语言与计算机之间的交互。NLP的目标是让计算机能够理解、解析、生成人类语言,并且能够以有意义的方式回应和操作这些信息。

NLP的任务可以分为多个层次,包括但不限于:

  1. 词法分析:将文本分解成单词或标记(token),并识别它们的词性(如名词、动词等)。
  2. 句法分析:分析句子结构,理解句子中词语的关系,比如主语、谓语、宾语等。
  3. 语义分析:试图理解句子的实际含义,超越字面意义,捕捉隐含的信息。
  4. 语用分析:考虑上下文和对话背景,理解话语在特定情境下的使用目的。
  5. 情感分析:检测文本中表达的情感倾向,例如正面、负面或中立。
  6. 机器翻译:将一种自然语言转换为另一种自然语言。
  7. 问答系统:构建可以回答用户问题的系统。
  8. 文本摘要:从大量文本中提取关键信息,生成简短的摘要。
  9. 命名实体识别(NER):识别文本中提到的特定实体,如人名、地名、组织名等。
  10. 语音识别:将人类的语音转换为计算机可读的文字格式。

NLP技术的发展依赖于算法的进步、计算能力的提升以及大规模标注数据集的可用性。近年来,深度学习方法,特别是基于神经网络的语言模型,如BERT、GPT系列等,在许多NLP任务上取得了显著的成功。随着技术的进步,NLP正在被应用到越来越多的领域,包括客户服务、智能搜索、内容推荐、医疗健康等。

多头注意力

原理

多头注意力机制(Multi - Head Attention)的结构示意图:

多头注意力机制首先将查询(Query)、键(Key)、值(Value)分别通过多个全连接层进行线性变换,得到多个不同的表示。然后,对这些不同的表示分别进行注意力计算。最后,将各个注意力的结果进行连结(Concatenate),再通过一个全连接层得到最终输出。

这种机制允许模型在不同的表示子空间中并行地关注输入序列的不同部分,能够捕捉到更丰富的语义信息,广泛应用于Transformer等模型架构中,在自然语言处理、计算机视觉等领域有重要应用。

模型计算方式:

在该表达式中, h i h_i hi 是注意力机制计算得到的输出, f f f 一般表示注意力计算函数(如缩放点积注意力等), W q i W_q^i Wqi、 W k i W_k^i Wki、 W v i W_v^i Wvi 分别是针对查询(query)、键(key)、值(value)的可学习权重矩阵, q q q、 k k k、 v v v 分别为查询向量、键向量、值向量 , R n \mathbb{R}^n Rn 表示输出 h i h_i hi 处于 n n n 维实数空间。它表达了在注意力计算中,通过对查询、键、值进行线性变换后再经过注意力计算函数得到输出的过程。

矩阵运算表达式:

表达式中 [ h 1 ⋮ h n ] \begin{bmatrix}h_1\\ \vdots \\h_n\end{bmatrix} h1⋮hn 是一个由 h 1 h_1 h1 到 h n h_n hn 构成的列向量,这些 h i h_i hi 通常可以是注意力机制等模块的输出。 W o W_o Wo 是一个可学习的权重矩阵,其维度为 R p × n \mathbb{R}^{p\times n} Rp×n ,这里 p p p 是输出维度相关参数, n n n 是输入向量的长度(即 h i h_i hi 的数量)。该表达式表示对由 h i h_i hi 组成的向量进行线性变换,常用于深度学习模型(如Transformer等)的后处理阶段,对前面模块输出进行进一步的特征变换或整合。

代码实现

导包

python 复制代码
import math
import torch
from torch import nn
import dltools

多头注意力结构

python 复制代码
class MultiHeadAttention(nn.Module):
    def __init__(self, key_size, query_size, value_size, num_hiddens, num_heads, dropout, bias=False, **kwargs):
        super().__init__(**kwargs)
        self.num_heads = num_heads
        self.attention = dltools.DotProductAttention(dropout)
        self.W_q = nn.Linear(query_size, num_hiddens, bias=bias)
        self.W_k = nn.Linear(key_size, num_hiddens, bias=bias)
        self.W_v = nn.Linear(value_size, num_hiddens, bias=bias)
        self.W_o = nn.Linear(num_hiddens, num_hiddens, bias=bias)
        
    def forward(self, queries, keys, values, valid_lens):
        # queries, keys, values 传入的形状: (batch_size, 查询熟练或者键值对数量, num_hiddens)
        queries = transpose_qkv(self.W_q(queries), self.num_heads)
        keys = transpose_qkv(self.W_k(keys), self.num_heads)
        values = transpose_qkv(self.W_v(values), self.num_heads)
#         print('queries:', queries.shape)
#         print('keys:', keys.shape)
#         print('values:', values.shape)
        if valid_lens is not None:
            valid_lens = torch.repeat_interleave(valid_lens, repeats=self.num_heads, dim=0)
        
        # output shape: (batch_size * num_heads, 查询的个数, num_hiddens/num_heads)
        output = self.attention(queries, keys, values, valid_lens)
#         print('output:', output.shape)
        output_concat = transpose_output(output, self.num_heads)
#         print('output_concat:', output_concat.shape)
        return self.W_o(output_concat)

qkv转换

python 复制代码
def transpose_qkv(X, num_heads):
    # 输入X的shape: (batch_size, 查询数/键值对数, num_hiddens)
    X = X.reshape(X.shape[0], X.shape[1], num_heads, -1)
    X = X.permute(0, 2, 1, 3) # batch_size, num_heads, 查询数/ 键值对数, num_hiddens/num_heads
    # 这里是把batch_size和num_heads合并在一起了. 
    return X.reshape(-1, X.shape[2], X.shape[3]) # batch_size * num_heads, 查询/键值对数, num_hiddens/ num_heads

output转换

python 复制代码
def transpose_output(X, num_heads):
    # 逆转transpose_qkv的操作
    X = X.reshape(-1, num_heads, X.shape[1], X.shape[2])
    X = X.permute(0, 2, 1, 3)
    return X.reshape(X.shape[0], X.shape[1], -1)

构建注意力模块

python 复制代码
num_hiddens, num_heads = 100, 5
attention = MultiHeadAttention(num_hiddens, num_hiddens, num_hiddens, num_hiddens, num_heads, 0.2)
attention.eval()

添加Bahdanau的decoder

python 复制代码
class Seq2SeqMultiHeadAttentionDecoder(dltools.AttentionDecoder):
    def __init__(self, vocab_size, embed_size, num_hiddens, num_heads, num_layers, dropout=0, **kwargs):
        super().__init__(**kwargs)
        self.attention = MultiHeadAttention(num_hiddens, num_hiddens, num_hiddens, num_hiddens, num_heads, dropout)
        self.embedding = nn.Embedding(vocab_size, embed_size)
        self.rnn = nn.GRU(embed_size + num_hiddens, num_hiddens, num_layers, dropout=dropout)
        self.dense = nn.Linear(num_hiddens, vocab_size)
        
    def init_state(self, enc_outputs, enc_valid_lens, *args):
        # outputs : (batch_size, num_steps, num_hiddens)
        # hidden_state: (num_layers, batch_size, num_hiddens)
        outputs, hidden_state = enc_outputs
        return (outputs.permute(1, 0, 2), hidden_state, enc_valid_lens)
    
    def forward(self, X, state):
        # enc_outputs (batch_size, num_steps, num_hiddens)
        # hidden_state: (num_layers, batch_size, num_hiddens)
        enc_outputs, hidden_state, enc_valid_lens = state
        # X : (batch_size, num_steps, vocab_size)
        X = self.embedding(X) # X : (batch_size, num_steps, embed_size)
        X = X.permute(1, 0, 2)
        outputs, self._attention_weights = [], []
        
        for x in X:
            query = torch.unsqueeze(hidden_state[-1], dim=1) # batch_size, 1, num_hiddens
#             print('query:', query.shape) # 4, 1, 16
            context = self.attention(query, enc_outputs, enc_outputs, enc_valid_lens)
#             print('context: ', context.shape)
            x = torch.cat((context, torch.unsqueeze(x, dim=1)), dim=-1)
#             print('x: ', x.shape)
            out, hidden_state = self.rnn(x.permute(1, 0, 2), hidden_state)
#             print('out:', out.shape)
#             print('hidden_state:', hidden_state.shape)
            outputs.append(out)
            self._attention_weights.append(self.attention_weights)
            
#         print('---------------------------------')
        outputs = self.dense(torch.cat(outputs, dim=0))
#         print('解码器最终输出形状: ', outputs.shape)
        return outputs.permute(1, 0, 2), [enc_outputs, hidden_state, enc_valid_lens]
    
    @property
    def attention_weights(self):
        return self._attention_weights

训练

python 复制代码
embed_size, num_hiddens, num_layers, dropout = 32, 100, 2, 0.1
batch_size, num_steps, num_heads = 64, 10, 5
lr, num_epochs, device = 0.005, 200, dltools.try_gpu()

train_iter, src_vocab, tgt_vocab = dltools.load_data_nmt(batch_size, num_steps)
encoder = dltools.Seq2SeqEncoder(len(src_vocab), embed_size, num_hiddens, num_layers, dropout)
decoder = Seq2SeqMultiHeadAttentionDecoder(len(tgt_vocab), embed_size, num_hiddens, num_heads, num_layers, dropout)
net = dltools.EncoderDecoder(encoder, decoder)
dltools.train_seq2seq(net, train_iter, lr, num_epochs, tgt_vocab, device)



预测

python 复制代码
engs = ['go .', 'i lost .', 'he\'s calm .', 'i\'m home .']
fras = ['va !', 'j\'ai perdu .', 'il est calme .', 'je suis chez moi .']
for eng, fra in zip(engs, fras):
    translation = dltools.predict_seq2seq(net, eng, src_vocab, tgt_vocab, num_steps, device)
    print(f'{eng} => {translation}, bleu {dltools.bleu(translation[0], fra, k=2):.3f}')
go . => ('va !', []), bleu 1.000
i lost . => ("j'ai perdu .", []), bleu 1.000
he's calm . => ('il est paresseux .', []), bleu 0.658
i'm home . => ('je suis chez moi .', []), bleu 1.000
相关推荐
佛州小李哥8 分钟前
深度评测DeepSeek、ChatGPT O1和谷歌Gemini AI应用开发场景 - DeepSeek性能完胜!
人工智能·科技·ai·chatgpt·gemini·ai开发·deepseek
云边有个稻草人8 分钟前
AI重塑视觉艺术:DeepSeek与蓝耘通义万相2.1的图生视频奇迹
人工智能·音视频·deepseek·蓝耘智算·蓝耘通义万相2.1图生视频·deepseek的关键技术
埃菲尔铁塔_CV算法13 分钟前
C# WPF 基础知识学习(一)
图像处理·人工智能·学习·计算机视觉·c#·wpf
my烂笔头16 分钟前
深度学习 常见优化器
人工智能·深度学习
LDG_AGI36 分钟前
【深度学习】多元物料融合算法(一):量纲对齐常见方法
人工智能·深度学习·算法·机器学习·均值算法·哈希算法·启发式算法
KangkangLoveNLP41 分钟前
从Swish到SwiGLU:激活函数的进化与革命,qwen2.5应用的激活函数
人工智能·深度学习·神经网络·算法·机器学习·自然语言处理·cnn
猫头虎1 小时前
阿里云操作系统控制台评测:国产AI+运维 一站式运维管理平台
运维·服务器·人工智能·阿里云·aigc·ai编程·ai写作
子洋1 小时前
AI 开发者必备:Vercel AI SDK 轻松搞定多厂商 AI 调用
前端·人工智能·后端
池央2 小时前
展望 AIGC 前景:通义万相 2.1 与蓝耘智算平台共筑 AI 生产力高地
人工智能·aigc
baikaishui3074 小时前
物联网时代的车队管理系统阐述
大数据·人工智能