VLM-单头自注意力机制核心逻辑

该代码展现**「文本编码→词嵌入→自注意力计算」的全流程,展示工作流程**

后面附带个人理解及解析

python 复制代码
​
import torch.nn as nn
import torch

class SelfAttention(nn.Module):

    def __init__(self, d_in, d_out_kq, d_out_v):
        super().__init__()
        self.d_out_kq = d_out_kq
        self.W_query = nn.Parameter(torch.rand(d_in, d_out_kq))
        self.W_key = nn.Parameter(torch.rand(d_in, d_out_kq))
        self.W_value = nn.Parameter(torch.rand(d_in, d_out_v))

    def forward(self, x):
        keys = x @ self.W_key
        queries = x @ self.W_query
        values = x @ self.W_value

        attn_scores = queries @ keys.T  # unnormalized attention weights
        attn_weights = torch.softmax(
            attn_scores / self.d_out_kq ** 0.5, dim=-1
        )
        print("attn_weights:", attn_weights.shape, attn_weights)

        context_vec = attn_weights @ values
        return context_vec

sentence = 'Life is short, eat dessert first'
dc = {s:i for i,s
      in enumerate(sorted(sentence.replace(',', '').split()))}
print(dc)

sentence_int = torch.tensor(
    [dc[s] for s in sentence.replace(',', '').split()]
)
print(sentence_int)
vocab_size = 50_000
torch.manual_seed(123)
embed = torch.nn.Embedding(vocab_size, 3)
embedded_sentence = embed(sentence_int).detach()
print(embedded_sentence)
print(embedded_sentence.shape)

torch.manual_seed(123)
# reduce d_out_v from 4 to 1, because we have 4 heads
d_in, d_out_kq, d_out_v = 3, 2, 4
sa = SelfAttention(d_in, d_out_kq, d_out_v)
print(sa(embedded_sentence))

​

一、整体流程概览

代码的核心目标是:让句子中的每个词通过「自注意力」结合所有词的信息,生成更具语义关联性的上下文向量。整体流程如下:

查询(Q)就像是在问 "我要找什么信息",键(K)是提供信息的 "标签",值(V)就是实际的信息内容。

  1. 定义自注意力模型SelfAttention类):实现 Q/K/V 计算、注意力权重归一化、上下文向量生成。
  2. 文本预处理:句子分词、构建词到索引的映射、将句子转换为整数编码。
  3. 词嵌入 :通过Embedding层将整数编码的词转换为低维稠密向量。
  4. 自注意力计算:将嵌入后的句子输入自注意力模型,输出每个词的上下文向量。

二、核心模块解析:SelfAttention类(自注意力核心)

自注意力的本质是「通过词与词之间的相似度(Q-K 匹配),给每个词分配不同权重,再对价值向量(V)加权求和」,最终让每个词都融合全局信息。

关键参数说明:
  • d_in:输入向量的维度(这里是词嵌入维度,后续设为 3)。
  • d_out_kq:Q(查询)和 K(键)的输出维度(必须相同,因为要计算 Q 与 K 的相似度)。
  • d_out_v:V(价值)的输出维度(可自由设定,后续设为 4)。
  • nn.Parameter:将张量标记为「模型可训练参数」,反向传播时会自动更新梯度。
python 复制代码
​
def __init__(self, d_in, d_out_kq, d_out_v):
    super().__init__()
    self.d_out_kq = d_out_kq  # Q和K的输出维度(必须一致,否则无法矩阵相乘)
    # 三个可学习的权重矩阵(nn.Parameter表示是模型需要训练的参数)
    self.W_query = nn.Parameter(torch.rand(d_in, d_out_kq))  # 查询(Q)的权重:d_in→d_out_kq
    self.W_key = nn.Parameter(torch.rand(d_in, d_out_kq))    # 键(K)的权重:d_in→d_out_kq
    self.W_value = nn.Parameter(torch.rand(d_in, d_out_v))   # 价值(V)的权重:d_in→d_out_v

​

2. forward方法:自注意力计算逻辑(核心)

forward是模型的前向传播函数,输入为x(嵌入后的句子),输出为每个词的上下文向量。步骤拆解如下:

步骤 1:计算 Q、K、V → 「特征投射」

目的是把原始嵌入向量(3 维)分别投射到 3 个不同的语义空间(Q/K 是 2 维,V 是 4 维),而不是保持原维度。

  • Q:"我(当前词)想找什么?"
  • K:"我(其他词)有什么?"
  • V:"我(其他词)的核心信息是什么?"

这里 Q 和 K 必须同维度(d_out_kq=2),因为后续要计算 "查询" 和 "键" 的匹配度(矩阵乘法要求 "前矩阵列数 = 后矩阵行数");但 V 的维度(d_out_v=4)可以和 Q/K 不同 ------ 这是你理解的第一个误区:Q/K 和 V 不需要同维度

python 复制代码
keys = x @ self.W_key      # K: [seq_len, d_in] @ [d_in, d_out_kq] → [seq_len, d_out_kq]
queries = x @ self.W_query # Q: [seq_len, d_in] @ [d_in, d_out_kq] → [seq_len, d_out_kq]
values = x @ self.W_value  # V: [seq_len, d_in] @ [d_in, d_out_v] → [seq_len, d_out_v]
步骤 2:计算 Q@K.T → 「匹配度打分」

Q 是 [6,2],K.T 是 [2,6](K 转置后,列数 = 原 K 的行数 = 6),矩阵乘法结果是 [6,6] 的attn_scores

  • 核心是:为每个词(行i)计算它与所有其他词(列j)的 "相关性分数"------ 分数越高,说明第j个词的信息对第i个词越重要。
cpp 复制代码
attn_scores = queries @ keys.T  # unnormalized attention weights

attn_scores每行 做 Softmax(dim=-1),得到attn_weights(仍为 [6,6])。

  • 目的:把 "相关性分数" 转换成 "权重"(每行和为 1),避免分数过大导致 Softmax 饱和(所以除以√d_out_kq做缩放)。
  • 比如attn_weights[0][4] = 0.3,代表 "第 0 个词(Life)对第 4 个词(is)的注意力权重是 30%"
cpp 复制代码
    attn_weights = torch.softmax(
            #attn_scores / self.d_out_kq ** 0.5,  # 缩放0
            #dim=-1  # 对最后一维归一化
            attn_scores / self.d_out_kq ** 0.5, dim=-1
        )
步骤 4:attn_weights @ V → 「加权求和」(你的核心疑问)

这一步的正确理解是:用注意力权重对所有词的 V 向量做 "加权求和",而不是 "将 qk 与 v 同维度"。

维度匹配的关键是:attn_weights列数 (6)= V 的行数 (6)------ 因为attn_weights的列对应 "所有词",V 的行也对应 "所有词",矩阵乘法的本质是:对于输出的第i个向量(上下文向量),它是 V 的所有行向量(每个词的价值向量),按attn_weights[i]的权重加权相加得到的。

cpp 复制代码
   context_vec = attn_weights @ values
        return context_vec
相关推荐
Lun3866buzha2 小时前
轮胎胎面花纹识别与分类:基于solo_r50_fpn模型的实现与优化
人工智能·分类·数据挖掘
zhangdawei8382 小时前
英伟达GB200,GB300和普通服务器如dell R740xd有什么区别?
运维·服务器·人工智能
Mintopia2 小时前
意图OS是未来软件形态,它到底解决了什么问题?
人工智能·react native·前端工程化
Mintopia2 小时前
🤖 AI 决策 + 意图OS:未来软件形态的灵魂共舞
前端·人工智能·react native
万行2 小时前
机器学习&第一章
人工智能·python·机器学习·flask·计算机组成原理
实战项目2 小时前
基于PyTorch的卷积神经网络花卉识别系统
人工智能·pytorch·cnn
shangjian0072 小时前
AI大模型-机器学习-算法-线性回归
人工智能·算法·机器学习
zuozewei2 小时前
零基础 | 一文速通 AI 大模型常见术语
人工智能
山海青风2 小时前
图像识别零基础实战入门 3 第一次训练图像分类模型
图像处理·人工智能·分类