该代码展现**「文本编码→词嵌入→自注意力计算」的全流程,展示工作流程**
后面附带个人理解及解析
python
import torch.nn as nn
import torch
class SelfAttention(nn.Module):
def __init__(self, d_in, d_out_kq, d_out_v):
super().__init__()
self.d_out_kq = d_out_kq
self.W_query = nn.Parameter(torch.rand(d_in, d_out_kq))
self.W_key = nn.Parameter(torch.rand(d_in, d_out_kq))
self.W_value = nn.Parameter(torch.rand(d_in, d_out_v))
def forward(self, x):
keys = x @ self.W_key
queries = x @ self.W_query
values = x @ self.W_value
attn_scores = queries @ keys.T # unnormalized attention weights
attn_weights = torch.softmax(
attn_scores / self.d_out_kq ** 0.5, dim=-1
)
print("attn_weights:", attn_weights.shape, attn_weights)
context_vec = attn_weights @ values
return context_vec
sentence = 'Life is short, eat dessert first'
dc = {s:i for i,s
in enumerate(sorted(sentence.replace(',', '').split()))}
print(dc)
sentence_int = torch.tensor(
[dc[s] for s in sentence.replace(',', '').split()]
)
print(sentence_int)
vocab_size = 50_000
torch.manual_seed(123)
embed = torch.nn.Embedding(vocab_size, 3)
embedded_sentence = embed(sentence_int).detach()
print(embedded_sentence)
print(embedded_sentence.shape)
torch.manual_seed(123)
# reduce d_out_v from 4 to 1, because we have 4 heads
d_in, d_out_kq, d_out_v = 3, 2, 4
sa = SelfAttention(d_in, d_out_kq, d_out_v)
print(sa(embedded_sentence))
一、整体流程概览
代码的核心目标是:让句子中的每个词通过「自注意力」结合所有词的信息,生成更具语义关联性的上下文向量。整体流程如下:
查询(Q)就像是在问 "我要找什么信息",键(K)是提供信息的 "标签",值(V)就是实际的信息内容。
- 定义自注意力模型 (
SelfAttention类):实现 Q/K/V 计算、注意力权重归一化、上下文向量生成。 - 文本预处理:句子分词、构建词到索引的映射、将句子转换为整数编码。
- 词嵌入 :通过
Embedding层将整数编码的词转换为低维稠密向量。 - 自注意力计算:将嵌入后的句子输入自注意力模型,输出每个词的上下文向量。
二、核心模块解析:SelfAttention类(自注意力核心)
自注意力的本质是「通过词与词之间的相似度(Q-K 匹配),给每个词分配不同权重,再对价值向量(V)加权求和」,最终让每个词都融合全局信息。
关键参数说明:
d_in:输入向量的维度(这里是词嵌入维度,后续设为 3)。d_out_kq:Q(查询)和 K(键)的输出维度(必须相同,因为要计算 Q 与 K 的相似度)。d_out_v:V(价值)的输出维度(可自由设定,后续设为 4)。nn.Parameter:将张量标记为「模型可训练参数」,反向传播时会自动更新梯度。
python
def __init__(self, d_in, d_out_kq, d_out_v):
super().__init__()
self.d_out_kq = d_out_kq # Q和K的输出维度(必须一致,否则无法矩阵相乘)
# 三个可学习的权重矩阵(nn.Parameter表示是模型需要训练的参数)
self.W_query = nn.Parameter(torch.rand(d_in, d_out_kq)) # 查询(Q)的权重:d_in→d_out_kq
self.W_key = nn.Parameter(torch.rand(d_in, d_out_kq)) # 键(K)的权重:d_in→d_out_kq
self.W_value = nn.Parameter(torch.rand(d_in, d_out_v)) # 价值(V)的权重:d_in→d_out_v
2. forward方法:自注意力计算逻辑(核心)
forward是模型的前向传播函数,输入为x(嵌入后的句子),输出为每个词的上下文向量。步骤拆解如下:
步骤 1:计算 Q、K、V → 「特征投射」
目的是把原始嵌入向量(3 维)分别投射到 3 个不同的语义空间(Q/K 是 2 维,V 是 4 维),而不是保持原维度。
- Q:"我(当前词)想找什么?"
- K:"我(其他词)有什么?"
- V:"我(其他词)的核心信息是什么?"
这里 Q 和 K 必须同维度(d_out_kq=2),因为后续要计算 "查询" 和 "键" 的匹配度(矩阵乘法要求 "前矩阵列数 = 后矩阵行数");但 V 的维度(d_out_v=4)可以和 Q/K 不同 ------ 这是你理解的第一个误区:Q/K 和 V 不需要同维度。
python
keys = x @ self.W_key # K: [seq_len, d_in] @ [d_in, d_out_kq] → [seq_len, d_out_kq]
queries = x @ self.W_query # Q: [seq_len, d_in] @ [d_in, d_out_kq] → [seq_len, d_out_kq]
values = x @ self.W_value # V: [seq_len, d_in] @ [d_in, d_out_v] → [seq_len, d_out_v]
步骤 2:计算 Q@K.T → 「匹配度打分」
Q 是 [6,2],K.T 是 [2,6](K 转置后,列数 = 原 K 的行数 = 6),矩阵乘法结果是 [6,6] 的attn_scores。
- 核心是:为每个词(行
i)计算它与所有其他词(列j)的 "相关性分数"------ 分数越高,说明第j个词的信息对第i个词越重要。
cpp
attn_scores = queries @ keys.T # unnormalized attention weights
对attn_scores的每行 做 Softmax(dim=-1),得到attn_weights(仍为 [6,6])。
- 目的:把 "相关性分数" 转换成 "权重"(每行和为 1),避免分数过大导致 Softmax 饱和(所以除以
√d_out_kq做缩放)。 - 比如
attn_weights[0][4] = 0.3,代表 "第 0 个词(Life)对第 4 个词(is)的注意力权重是 30%"
cpp
attn_weights = torch.softmax(
#attn_scores / self.d_out_kq ** 0.5, # 缩放0
#dim=-1 # 对最后一维归一化
attn_scores / self.d_out_kq ** 0.5, dim=-1
)
步骤 4:attn_weights @ V → 「加权求和」(你的核心疑问)
这一步的正确理解是:用注意力权重对所有词的 V 向量做 "加权求和",而不是 "将 qk 与 v 同维度"。
维度匹配的关键是:attn_weights的列数 (6)= V 的行数 (6)------ 因为attn_weights的列对应 "所有词",V 的行也对应 "所有词",矩阵乘法的本质是:对于输出的第i个向量(上下文向量),它是 V 的所有行向量(每个词的价值向量),按attn_weights[i]的权重加权相加得到的。
cpp
context_vec = attn_weights @ values
return context_vec