PyCharm Attention Residual Example


The concept of attention residuals has been trending lately. I don't fully understand it yet, so I'm following along with a small experiment.

Example

1. Install dependencies

```shell
pip install torch
pip install numpy  # numpy is a library for multi-dimensional array operations
```

2. Code:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
import math


class MultiHeadAttention(nn.Module):
    def __init__(self, embed_dim, num_heads, dropout=0.1):
        super(MultiHeadAttention, self).__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads

        assert self.head_dim * num_heads == embed_dim, "embed_dim must be divisible by num_heads"

        # Linear projections for Q, K, V
        self.q_linear = nn.Linear(embed_dim, embed_dim)
        self.k_linear = nn.Linear(embed_dim, embed_dim)
        self.v_linear = nn.Linear(embed_dim, embed_dim)

        self.dropout = nn.Dropout(dropout)
        self.out_linear = nn.Linear(embed_dim, embed_dim)  # output projection

    def forward(self, query, key, value, mask=None):
        batch_size = query.size(0)

        # 1. Project and split into heads
        # q: (batch, seq_len, embed_dim) -> (batch, num_heads, seq_len, head_dim)
        q = self.q_linear(query).view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
        k = self.k_linear(key).view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
        v = self.v_linear(value).view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)

        # 2. Compute attention scores (scaled dot-product attention)
        scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.head_dim)

        if mask is not None:
            scores = scores.masked_fill(mask == 0, -1e9)

        attn_weights = F.softmax(scores, dim=-1)
        attn_weights = self.dropout(attn_weights)

        # 3. Weighted sum of the values
        context = torch.matmul(attn_weights, v)  # (batch, num_heads, seq_len, head_dim)

        # 4. Concatenate heads and apply the output projection
        context = context.transpose(1, 2).contiguous().view(batch_size, -1, self.embed_dim)
        attention_output = self.out_linear(context)

        return attention_output


class AttentionBlockWithResidual(nn.Module):
    def __init__(self, embed_dim, num_heads, dropout=0.1):
        super(AttentionBlockWithResidual, self).__init__()
        self.attention = MultiHeadAttention(embed_dim, num_heads, dropout)
        self.norm = nn.LayerNorm(embed_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, mask=None):
        # --- Key part: the attention residual ---

        # Scheme A: Post-LN (the original Transformer recipe)
        # 1. Compute the attention output first
        attn_out = self.attention(x, x, x, mask)
        # 2. Apply dropout (optional; usually before or after the residual add)
        attn_out = self.dropout(attn_out)
        # 3. Residual connection: input + attention output
        residual_out = x + attn_out
        # 4. Layer normalization
        output = self.norm(residual_out)

        return output

        # Scheme B: Pre-LN (common in modern Transformers; trains more stably)
        # norm_x = self.norm(x)
        # attn_out = self.attention(norm_x, norm_x, norm_x, mask)
        # output = x + self.dropout(attn_out)
        # return output


# --- Test code ---
if __name__ == "__main__":
    # Hyperparameters
    batch_size = 2
    seq_length = 10
    embed_dim = 512
    num_heads = 8

    # Random input
    x = torch.randn(batch_size, seq_length, embed_dim)

    # Instantiate the module
    model = AttentionBlockWithResidual(embed_dim, num_heads)

    # Forward pass
    output = model(x)

    print(f"Input shape: {x.shape}")
    print(f"Output shape: {output.shape}")

    # Sanity check: with the residual, the output should equal neither the pure
    # attention output nor the raw input.
    # In theory, output ≈ LayerNorm(x + Attention(x))
    print("Attention residual block ran successfully!")
```
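For comparison, the same Post-LN residual pattern can be sketched with PyTorch's built-in `nn.MultiheadAttention` instead of the hand-written class above (this assumes a PyTorch version that supports `batch_first=True`, i.e. 1.9 or later; the class name `BuiltinPostLNBlock` is just an illustrative choice):

```python
import torch
import torch.nn as nn

class BuiltinPostLNBlock(nn.Module):
    """Post-LN residual block: LayerNorm(x + Dropout(SelfAttention(x))),
    built on PyTorch's nn.MultiheadAttention."""
    def __init__(self, embed_dim, num_heads, dropout=0.1):
        super().__init__()
        self.attn = nn.MultiheadAttention(embed_dim, num_heads,
                                          dropout=dropout, batch_first=True)
        self.norm = nn.LayerNorm(embed_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        # Self-attention: query, key, and value are all x.
        attn_out, _ = self.attn(x, x, x)
        # Residual add, then layer normalization (Post-LN order).
        return self.norm(x + self.dropout(attn_out))

x = torch.randn(2, 10, 512)
block = BuiltinPostLNBlock(512, 8)
block.eval()  # disable dropout for a deterministic check
out = block(x)
print(out.shape)  # torch.Size([2, 10, 512])
```

The shapes match the hand-written version, and the output differs from the raw input, which is a quick way to confirm the attention branch is actually contributing.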

Output:

```
Input shape: torch.Size([2, 10, 512])
Output shape: torch.Size([2, 10, 512])
Attention residual block ran successfully!
```

Is it correct that the output shape is unchanged?

Yes. The key property to check here is shape preservation.

Shape preservation:

Input: batch_size=2, seq_len=10, embed_dim=512

Output: batch_size=2, seq_len=10, embed_dim=512

Conclusion: a residual connection requires X and F(X) to have identical shapes, and this code satisfies that requirement. Shape preservation is what makes it possible to stack these blocks into a deep Transformer.
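To make the stacking point concrete, here is a minimal sketch of the Pre-LN variant (scheme B from the comments above) composed into a deep stack; it uses PyTorch's built-in `nn.MultiheadAttention`, and the names `PreLNBlock` and the depth of 4 are illustrative choices, not from the original code:

```python
import torch
import torch.nn as nn

class PreLNBlock(nn.Module):
    """Pre-LN residual block: x + Dropout(SelfAttention(LayerNorm(x)))."""
    def __init__(self, embed_dim, num_heads, dropout=0.1):
        super().__init__()
        self.attn = nn.MultiheadAttention(embed_dim, num_heads,
                                          dropout=dropout, batch_first=True)
        self.norm = nn.LayerNorm(embed_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        h = self.norm(x)               # normalize first (Pre-LN order)
        attn_out, _ = self.attn(h, h, h)
        return x + self.dropout(attn_out)  # residual add last

# Because every block maps (batch, seq_len, embed_dim) to the same shape,
# blocks compose freely into an arbitrarily deep stack.
stack = nn.Sequential(*[PreLNBlock(512, 8) for _ in range(4)])
x = torch.randn(2, 10, 512)
y = stack(x)
print(y.shape)  # torch.Size([2, 10, 512])
```

Since the residual path is an identity add, gradients can flow straight through every block in the stack, which is the usual argument for why Pre-LN stacks train more stably at depth.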
