pycharm注意力残差示例

文章目录

注意力残差这个概念前几天火了,还不太懂,先跟跟风。

示例

1、安装依赖

复制代码
pip install torch
pip install numpy # numpy是多维数组操作相关的库

2、代码:

py 复制代码
import torch
import torch.nn as nn
import torch.nn.functional as F
import math


class MultiHeadAttention(nn.Module):
    def __init__(self, embed_dim, num_heads, dropout=0.1):
        super(MultiHeadAttention, self).__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads

        assert self.head_dim * num_heads == embed_dim, "embed_dim must be divisible by num_heads"

        # 定义 Q, K, V 的线性变换
        self.q_linear = nn.Linear(embed_dim, embed_dim)
        self.k_linear = nn.Linear(embed_dim, embed_dim)
        self.v_linear = nn.Linear(embed_dim, embed_dim)

        self.dropout = nn.Dropout(dropout)
        self.out_linear = nn.Linear(embed_dim, embed_dim)  # 输出投影

    def forward(self, query, key, value, mask=None):
        batch_size = query.size(0)

        # 1. 线性映射并拆分多头
        # q: (batch, seq_len, embed_dim) -> (batch, num_heads, seq_len, head_dim)
        q = self.q_linear(query).view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
        k = self.k_linear(key).view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
        v = self.v_linear(value).view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)

        # 2. 计算注意力分数 (Scaled Dot-Product Attention)
        scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.head_dim)

        if mask is not None:
            scores = scores.masked_fill(mask == 0, -1e9)

        attn_weights = F.softmax(scores, dim=-1)
        attn_weights = self.dropout(attn_weights)

        # 3. 加权求和
        context = torch.matmul(attn_weights, v)  # (batch, num_heads, seq_len, head_dim)

        # 4. 合并多头并线性投影
        context = context.transpose(1, 2).contiguous().view(batch_size, -1, self.embed_dim)
        attention_output = self.out_linear(context)

        return attention_output


class AttentionBlockWithResidual(nn.Module):
    def __init__(self, embed_dim, num_heads, dropout=0.1):
        super(AttentionBlockWithResidual, self).__init__()
        self.attention = MultiHeadAttention(embed_dim, num_heads, dropout)
        self.norm = nn.LayerNorm(embed_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, mask=None):
        # --- 关键部分:注意力残差实现 ---

        # 方案 A: Post-LN (原始 Transformer 做法)
        # 1. 先计算注意力
        attn_out = self.attention(x, x, x, mask)
        # 2. 应用 Dropout (可选,通常在残差前或后)
        attn_out = self.dropout(attn_out)
        # 3. 残差连接: Input + Attention_Output
        residual_out = x + attn_out
        # 4. 层归一化
        output = self.norm(residual_out)

        return output

        # 方案 B: Pre-LN (现代 Transformer 常用,训练更稳定)
        # norm_x = self.norm(x)
        # attn_out = self.attention(norm_x, norm_x, norm_x, mask)
        # output = x + self.dropout(attn_out)
        # return output


# --- 测试代码 ---
if __name__ == "__main__":
    # 假设参数
    batch_size = 2
    seq_length = 10
    embed_dim = 512
    num_heads = 8

    # 随机输入数据
    x = torch.randn(batch_size, seq_length, embed_dim)

    # 初始化模块
    model = AttentionBlockWithResidual(embed_dim, num_heads)

    # 前向传播
    output = model(x)

    print(f"输入形状: {x.shape}")
    print(f"输出形状: {output.shape}")

    # 验证残差是否生效 (简单检查:输出不应等于纯注意力输出,也不应等于纯输入)
    # 理论上 output ≈ LayerNorm(x + Attention(x))
    print("注意力残差模块运行成功!")

输出内容:

复制代码
输入形状: torch.Size([2, 10, 512])
输出形状: torch.Size([2, 10, 512])
注意力残差模块运行成功!

结果不变是对的吗?

这是对的,因为核心指标就是维度守恒。

维度守恒:

输入:batch_size=2, seq_len=10, embed_dim=512

输出:batch_size=2, seq_len=10, embed_dim=512

结论:残差连接要求 X 和 F(X)维度必须一致,你的代码完美满足了这一点。这是构建深层 Transformer 堆叠层的基础。

相关推荐
2301_793804692 小时前
用Python和Twilio构建短信通知系统
jvm·数据库·python
B站_计算机毕业设计之家2 小时前
计算机毕业设计:Python当当网图书数据全链路处理平台 Django框架 爬虫 Pandas 可视化 大数据 大模型 书籍(建议收藏)✅
爬虫·python·机器学习·django·flask·pandas·课程设计
不要秃头的小孩2 小时前
力扣刷题——111.二叉树的最小深度
数据结构·python·算法·leetcode
我是鶸2 小时前
secml-malware python library 源码分析及实践
开发语言·python
进击的小头2 小时前
第15篇:MPC的发展方向及展望
python·算法
SugarFreeOixi3 小时前
MATLAB绘图风格记录NP类型
python·matlab·numpy
冥王丁B3 小时前
第31章 Prompt 与聊天模型笔记
笔记·python·prompt
左左右右左右摇晃3 小时前
Java笔记——包装类(自动拆装箱)
java·笔记·python
青瓷程序设计3 小时前
【果蔬识别系统】Python+深度学习+人工智能+算法模型+图像识别+2026原创
人工智能·python·深度学习