VLM-单头自注意力机制核心逻辑

该代码展现**「文本编码→词嵌入→自注意力计算」的全流程,展示工作流程**

后面附带个人理解及解析

python 复制代码
​
import torch.nn as nn
import torch

class SelfAttention(nn.Module):

    def __init__(self, d_in, d_out_kq, d_out_v):
        super().__init__()
        self.d_out_kq = d_out_kq
        self.W_query = nn.Parameter(torch.rand(d_in, d_out_kq))
        self.W_key = nn.Parameter(torch.rand(d_in, d_out_kq))
        self.W_value = nn.Parameter(torch.rand(d_in, d_out_v))

    def forward(self, x):
        keys = x @ self.W_key
        queries = x @ self.W_query
        values = x @ self.W_value

        attn_scores = queries @ keys.T  # unnormalized attention weights
        attn_weights = torch.softmax(
            attn_scores / self.d_out_kq ** 0.5, dim=-1
        )
        print("attn_weights:", attn_weights.shape, attn_weights)

        context_vec = attn_weights @ values
        return context_vec

sentence = 'Life is short, eat dessert first'
dc = {s:i for i,s
      in enumerate(sorted(sentence.replace(',', '').split()))}
print(dc)

sentence_int = torch.tensor(
    [dc[s] for s in sentence.replace(',', '').split()]
)
print(sentence_int)
vocab_size = 50_000
torch.manual_seed(123)
embed = torch.nn.Embedding(vocab_size, 3)
embedded_sentence = embed(sentence_int).detach()
print(embedded_sentence)
print(embedded_sentence.shape)

torch.manual_seed(123)
# reduce d_out_v from 4 to 1, because we have 4 heads
d_in, d_out_kq, d_out_v = 3, 2, 4
sa = SelfAttention(d_in, d_out_kq, d_out_v)
print(sa(embedded_sentence))

​

一、整体流程概览

代码的核心目标是:让句子中的每个词通过「自注意力」结合所有词的信息,生成更具语义关联性的上下文向量。整体流程如下:

查询(Q)就像是在问 "我要找什么信息",键(K)是提供信息的 "标签",值(V)就是实际的信息内容。

  1. 定义自注意力模型SelfAttention类):实现 Q/K/V 计算、注意力权重归一化、上下文向量生成。
  2. 文本预处理:句子分词、构建词到索引的映射、将句子转换为整数编码。
  3. 词嵌入 :通过Embedding层将整数编码的词转换为低维稠密向量。
  4. 自注意力计算:将嵌入后的句子输入自注意力模型,输出每个词的上下文向量。

二、核心模块解析:SelfAttention类(自注意力核心)

自注意力的本质是「通过词与词之间的相似度(Q-K 匹配),给每个词分配不同权重,再对价值向量(V)加权求和」,最终让每个词都融合全局信息。

关键参数说明:
  • d_in:输入向量的维度(这里是词嵌入维度,后续设为 3)。
  • d_out_kq:Q(查询)和 K(键)的输出维度(必须相同,因为要计算 Q 与 K 的相似度)。
  • d_out_v:V(价值)的输出维度(可自由设定,后续设为 4)。
  • nn.Parameter:将张量标记为「模型可训练参数」,反向传播时会自动更新梯度。
python 复制代码
​
def __init__(self, d_in, d_out_kq, d_out_v):
    super().__init__()
    self.d_out_kq = d_out_kq  # Q和K的输出维度(必须一致,否则无法矩阵相乘)
    # 三个可学习的权重矩阵(nn.Parameter表示是模型需要训练的参数)
    self.W_query = nn.Parameter(torch.rand(d_in, d_out_kq))  # 查询(Q)的权重:d_in→d_out_kq
    self.W_key = nn.Parameter(torch.rand(d_in, d_out_kq))    # 键(K)的权重:d_in→d_out_kq
    self.W_value = nn.Parameter(torch.rand(d_in, d_out_v))   # 价值(V)的权重:d_in→d_out_v

​

2. forward方法:自注意力计算逻辑(核心)

forward是模型的前向传播函数,输入为x(嵌入后的句子),输出为每个词的上下文向量。步骤拆解如下:

步骤 1:计算 Q、K、V → 「特征投射」

目的是把原始嵌入向量(3 维)分别投射到 3 个不同的语义空间(Q/K 是 2 维,V 是 4 维),而不是保持原维度。

  • Q:"我(当前词)想找什么?"
  • K:"我(其他词)有什么?"
  • V:"我(其他词)的核心信息是什么?"

这里 Q 和 K 必须同维度(d_out_kq=2),因为后续要计算 "查询" 和 "键" 的匹配度(矩阵乘法要求 "前矩阵列数 = 后矩阵行数");但 V 的维度(d_out_v=4)可以和 Q/K 不同 ------ 这是你理解的第一个误区:Q/K 和 V 不需要同维度

python 复制代码
keys = x @ self.W_key      # K: [seq_len, d_in] @ [d_in, d_out_kq] → [seq_len, d_out_kq]
queries = x @ self.W_query # Q: [seq_len, d_in] @ [d_in, d_out_kq] → [seq_len, d_out_kq]
values = x @ self.W_value  # V: [seq_len, d_in] @ [d_in, d_out_v] → [seq_len, d_out_v]
步骤 2:计算 Q@K.T → 「匹配度打分」

Q 是 [6,2],K.T 是 [2,6](K 转置后,列数 = 原 K 的行数 = 6),矩阵乘法结果是 [6,6] 的attn_scores

  • 核心是:为每个词(行i)计算它与所有其他词(列j)的 "相关性分数"------ 分数越高,说明第j个词的信息对第i个词越重要。
cpp 复制代码
attn_scores = queries @ keys.T  # unnormalized attention weights

attn_scores每行 做 Softmax(dim=-1),得到attn_weights(仍为 [6,6])。

  • 目的:把 "相关性分数" 转换成 "权重"(每行和为 1),避免分数过大导致 Softmax 饱和(所以除以√d_out_kq做缩放)。
  • 比如attn_weights[0][4] = 0.3,代表 "第 0 个词(Life)对第 4 个词(is)的注意力权重是 30%"
cpp 复制代码
    attn_weights = torch.softmax(
            #attn_scores / self.d_out_kq ** 0.5,  # 缩放0
            #dim=-1  # 对最后一维归一化
            attn_scores / self.d_out_kq ** 0.5, dim=-1
        )
步骤 4:attn_weights @ V → 「加权求和」(你的核心疑问)

这一步的正确理解是:用注意力权重对所有词的 V 向量做 "加权求和",而不是 "将 qk 与 v 同维度"。

维度匹配的关键是:attn_weights列数 (6)= V 的行数 (6)------ 因为attn_weights的列对应 "所有词",V 的行也对应 "所有词",矩阵乘法的本质是:对于输出的第i个向量(上下文向量),它是 V 的所有行向量(每个词的价值向量),按attn_weights[i]的权重加权相加得到的。

cpp 复制代码
   context_vec = attn_weights @ values
        return context_vec
相关推荐
Yao.Li1 小时前
PVN3D ORT CUDA Custom Ops 实现与联调记录
人工智能·3d·具身智能
诺伦1 小时前
LocalClaw 在智能制造的新机会:6部门AI+电商政策下的工厂AI升级方案
人工智能·制造
小陈工3 小时前
Python Web开发入门(十七):Vue.js与Python后端集成——让前后端真正“握手言和“
开发语言·前端·javascript·数据库·vue.js·人工智能·python
墨染天姬7 小时前
【AI】端侧AIBOX可以部署哪些智能体
人工智能
AI成长日志7 小时前
【Agentic RL】1.1 什么是Agentic RL:从传统RL到智能体学习
人工智能·学习·算法
2501_948114247 小时前
2026年大模型API聚合平台技术评测:企业级接入层的治理演进与星链4SAPI架构观察
大数据·人工智能·gpt·架构·claude
小小工匠7 小时前
LLM - awesome-design-md 从 DESIGN.md 到“可对话的设计系统”:用纯文本驱动 AI 生成一致 UI 的新范式
人工智能·ui
黎阳之光8 小时前
黎阳之光:视频孪生领跑者,铸就中国数字科技全球竞争力
大数据·人工智能·算法·安全·数字孪生
小超同学你好8 小时前
面向 LLM 的程序设计 6:Tool Calling 的完整生命周期——从定义、决策、执行到观测回注
人工智能·语言模型
智星云算力8 小时前
本地GPU与租用GPU混合部署:混合算力架构搭建指南
人工智能·架构·gpu算力·智星云·gpu租用