PyTorch 实战:从 0 开始搭建 Transformer

  1. 导入必要的库

python

复制代码
import math
import torch
import torch.nn as nn
from LabmL_helpers.module import Module
from labml_n.utils import clone_module_List
from typing import Optional, List
from torch.utils.data import DataLoader, TensorDataset
from torch import optim
import torch.nn.functional as F
  1. Transformer 模型概述
    Transformer 是一种序列到序列的模型,通过自注意力机制并行处理整个序列,能同时考虑序列中的所有元素,并学习上下文之间的关系。其架构包括编码器和解码器部分,每部分都由多个相同的层组成,这些层包含自注意力机制、前馈神经网络,以及归一化和 Dropout 步骤。
  2. 核心公式
    • 自注意力计算:Attention(Q,K,V)=softmax(dkQKT)V,其中,Q、K、V分别是查询(Query)、键(Key)和值(Value)矩阵,dk是键的维度。
    • 多头注意力:将输入分割为多个头,分别计算注意力,然后将结果拼接起来。
    • 位置编码:由于 Transformer 不使用循环结构,因此引入位置编码来保留序列中的位置信息。
  3. 自注意力机制
    • 核心原理:计算句子在编码过程中每个位置上的注意力权重,然后以权重和的方式来计算整个句子的隐含向量表示。公式中,首先将 query 与 key 的转置做点积,然后将结果除以dk ,再进行 softmax 计算,最后将结果与 value 做矩阵乘法得到 output。除以dk是为了防止QKT过大导致 softmax 计算溢出,且可使QKT结果满足均值为 0,方差 1 的分布。QKT计算本质上是余弦相似度,可表示两个向量在方向上的相似度。
    • 实现

python

复制代码
import numpy as np
from math import sqrt
import torch
from torch import nn


class Self_Attention(nn.Module):
    # input : batch_size * seq_len * input_dim
    # q : batch_size * input_dim * dim_k
    # k : batch_size * input_dim * dim_k
    # v : batch_size * input_dim * dim_v
    def __init__(self, input_dim, dim_k, dim_v):
        super(Self_Attention, self).__init__()
        self.q = nn.Linear(input_dim, dim_k)
        self.k = nn.Linear(input_dim, dim_k)
        self.v = nn.Linear(input_dim, dim_v)
        self._norm_fact = 1 / sqrt(dim_k)

    def forward(self, x):
        Q = self.q(x)  # Q: batch_size * seq_len * dim_k
        K = self.k(x)  # K: batch_size * seq_len * dim_k
        V = self.v(x)  # V: batch_size * seq_len * dim_v
        # Q * K.T() # batch_size * seq_len * seq_len
        atten = nn.Softmax(
            dim=-1)(torch.bmm(Q, K.permute(0, 2, 1))) * self._norm_fact
        # Q * K.T() * V # batch_size * seq_len * dim_v
        output = torch.bmm(atten, V)
        return output


X = torch.randn(4, 3, 2)
print(X)
self_atten = Self_Attention(2, 4, 5)  # input_dim:2, k_dim:4, v_dim:5
res = self_atten(X)
print(res.shape)  # [4,3,5]
  1. 多头注意力机制
    不同于只使用一个注意力池化,将输入x拆分为h份,独立计算h组不同的线性投影来得到各自的 QKV,然后并行计算注意力,最后将h个注意力池化拼接起来并通过另一个可学习的线性投影进行变换以产生输出。每个头可能关注输入的不同部分,可表示更复杂的函数。

python

复制代码
from math import sqrt
import torch
import torch.nn as nn


class Self_Attention_Muti_Head(nn.Module):
    # input : batch_size * seq_len * input_dim
    # q : batch_size * input_dim * dim_k
    # k : batch_size * input_dim * dim_k
    # v : batch_size * input_dim * dim_v
    def __init__(self, input_dim, dim_k, dim_v, nums_head):
        super(Self_Attention_Muti_Head, self).__init__()
        assert dim_k % nums_head == 0
        assert dim_v % nums_head == 0
        self.q = nn.Linear(input_dim, dim_k)
        self.k = nn.Linear(input_dim, dim_k)
        self.v = nn.Linear(input_dim, dim_v)
        self.nums_head = nums_head
        self.dim_k = dim_k
        self.dim_v = dim_v
        self._norm_fact = 1 / sqrt(dim_k)

    def forward(self, x):
        Q = self.q(x).reshape(-1, x.shape[0], x.shape[1], self.dim_k //
                              self.nums_head)
        K = self.k(x).reshape(-1, x.shape[0], x.shape[1], self.dim_k //
                              self.nums_head)
        V = self.v(x).reshape(-1, x.shape[0], x.shape[1], self.dim_v //
                              self.nums_head)
        print(x.shape)
        print(Q.size())
        atten = nn.Softmax(dim=-1)(torch.matmul(Q, K.permute(0, 1, 3, 2)))  # Q * K.T() # batch_size * seq_len * seq_len
        output = torch.matmul(atten, V).reshape(x.shape[0], x.shape[1], -1)  # Q * K.T() * V # batch_size * seq_len * dim_v
        return output


x = torch.rand(1, 3, 4)
print(x)
atten = Self_Attention_Muti_Head(4, 4, 4, 2)
y = atten(x)
print(y.shape)
  1. 视觉注意力机制
    attention 机制本质是利用相关特征图学习权重分布,再用学出来的权重施加在原特征图上最后进行加权求和。计算机视觉上的注意力机制主要分为三种:空间域、通道域、混合域。
    • 空间域:将图片中的空间域信息做对应的空间变换,提取关键信息,对空间进行掩码的生成并打分,代表是 Spatial attention module。
    • 通道域:给每个通道上的信号增加一个权重,代表该通道与关键信息的相关度,权重越大相关度越高。对通道生成掩码 mask 进行打分,代表是 senet、channel attention module。
    • 混合域:空间域的注意力忽略了通道域中的信息,将每个通道的图片特征同等处理,这种做法会将空间域变换方法局限在原始特征提取阶段。
  2. 通道域注意力(SENet)
    通过全局池化提取通道权重,然后对特征图进行改变,得到加强后的特征图。

python

复制代码
class SELayer(nn.Module):
    def __init__(self, channel, reduction=16):
        super(SELayer, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Sequential(
            nn.Linear(channel, channel // reduction, bias=False),
            nn.ReLU(inplace=True),
            nn.Linear(channel // reduction, channel, bias=False),
            nn.Sigmoid()
        )

    def forward(self, x):
        b, c, _, _ = x.size()
        y = self.avg_pool(x).view(b, c)  # 对应Squeeze操作
        y = self.fc(y).view(b, c, 1, 1)  # 对应Excitation操作
        return x * y.expand_as(x)
  1. 门控注意力机制(GCT,Gated Channel Transformation)
    GCT 是一种简单有效的通道间建模关系体系结构,能显著提高卷积网络在视觉任务的泛化能力。论文发现将门控机制放在 Conv 层前面训练效果最好。GCT 包含三个部分:
    • Global Context Embedding:设计了一种全局上下文嵌入模块,用于每个通道的全局上下文信息汇聚,公式为sc=αc∥xc∥2=αc{[∑i=1H∑j=1W(xci,j)2]+ϵ}21。
    • Channel Normalization:对第一步计算的 L2 进行规范化来构建神经元竞争关系,使用跨通道的特征规范化,公式为s^c=∥s∥2Csc=[(∑c=1Csc2)+ϵ]21Csc。
    • Gating Adaptation:加入门限机制,公式为x^c=xc[1+tanh(γcs^c+βc)] 。

python

复制代码
class GCT(nn.Module):
    def __init__(self, num_channels, epsilon=1e-5, mode='l2', after_relu=False):
        super(GCT, self).__init__()
        self.alpha = nn.Parameter(torch.ones(1, num_channels, 1, 1))
        self.gamma = nn.Parameter(torch.zeros(1, num_channels, 1, 1))
        self.beta = nn.Parameter(torch.zeros(1, num_channels, 1, 1))
        self.epsilon = epsilon
        self.mode = mode
        self.after_relu = after_relu

    def forward(self, x):
        if self.mode == 'l2':
            embedding = (x.pow(2).sum((2, 3), keepdim=True) +
                         self.epsilon).pow(0.5) * self.alpha
            norm = self.gamma / \
                   (embedding.pow(2).mean(dim=1, keepdim=True) +
                    self.epsilon).pow(0.5)
        elif self.mode == 'l1':
            if not self.after_relu:
                _x = torch.abs(x)
            else:
                _x = x
            embedding = _x.sum((2, 3), keepdim=True) * self.alpha
            norm = self.gamma / \
                   (torch.abs(embedding).mean(dim=1, keepdim=True) + self.epsilon)
        gate = 1. + torch.tanh(embedding * norm + self.beta)
        return x * gate

GCT 建议添加在 Conv 层前,一般可以先冻结原来的模型,来训练 GCT,然后解冻再进行微调。

相关推荐
风象南6 分钟前
普通人用AI加持赚到的第一个100块
人工智能·后端
牛奶37 分钟前
2026年大模型怎么选?前端人实用对比
前端·人工智能·ai编程
牛奶39 分钟前
前端人为什么要学AI?
前端·人工智能·ai编程
哥布林学者2 小时前
高光谱成像(一)高光谱图像
机器学习·高光谱成像
地平线开发者2 小时前
SparseDrive 模型导出与性能优化实战
算法·自动驾驶
董董灿是个攻城狮3 小时前
大模型连载2:初步认识 tokenizer 的过程
算法
地平线开发者3 小时前
地平线 VP 接口工程实践(一):hbVPRoiResize 接口功能、使用约束与典型问题总结
算法·自动驾驶
罗西的思考3 小时前
AI Agent框架探秘:拆解 OpenHands(10)--- Runtime
人工智能·算法·机器学习
冬奇Lab4 小时前
OpenClaw 源码精读(2):Channel & Routing——一条消息如何找到它的 Agent?
人工智能·开源·源码阅读
冬奇Lab4 小时前
一天一个开源项目(第38篇):Claude Code Telegram - 用 Telegram 远程用 Claude Code,随时随地聊项目
人工智能·开源·资讯