Implementing the Block of a Large Language Model

Block Code

python
import math
from typing import Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F

from LMConfig import LMConfig


class RMSNorm(nn.Module):
    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(dim))
        self.eps = eps
    
    def _norm(self, x):
        # x / sqrt(mean(x^2) + eps): scale by the reciprocal root-mean-square of the last dimension
        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
    
    def forward(self, x):
        return (self.weight * self._norm(x.float())).type_as(x)
        
def precompute_pos_cis(dim: int, end: int = int(32*1024), freq_theta: float = 1e6):
    freq = 1.0 / (freq_theta ** (torch.arange(0, dim, 2)[: dim // 2].float() / dim)) # per-dimension frequencies derived from the head dim
    seqlen = torch.arange(end, device=freq.device) # position indices 0..end-1
    freq = torch.outer(seqlen, freq).float() # outer product -> (end, dim//2): each position gets its own angle vector
    pos_cis = torch.polar(torch.ones_like(freq), freq) # complex tensor (end, dim//2), each element is cos(freq) + i*sin(freq)
    print(f"pos_cis (complex frequencies) shape: {pos_cis.shape}\n")
    print(f"pos_cis at position 0: {pos_cis[0]}\n")
    print(f"pos_cis at position 1: {pos_cis[1]}\n")
    return pos_cis

def apply_RoPE(xq: torch.Tensor, xk: torch.Tensor, pos_cis):
    """RoPE

    Args:
        xq (_type_): (batch, seqlen, heads, dim)
        xk (_type_): (batch, seqlen, k_heads, dim)
        pos_cis (_type_): (seqlen, dim//2)
    """
    def unite_shape(pos_cis, x):
        x_dim = x.ndim
        assert 1 < x_dim
        assert pos_cis.shape == (x.shape[1], x.shape[-1])
        shape = [d if i == 1 or i == x.ndim-1 else 1 for i,d in enumerate(x.shape)]
        return pos_cis.view(*shape)
    # (batch, seqlen, heads, dim) is reshaped to (batch, seqlen, heads, dim//2, 2),
    # then view_as_complex folds each trailing pair into one complex number,
    # giving (batch, seqlen, heads, dim//2) where every element is already complex
    print(f"xq shape: {xq.shape}\n")
    xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2)) # (batch, seqlen, heads, dim//2)
    print(f"xq_ shape: {xq_.shape}\n")
    xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
    pos_cis = unite_shape(pos_cis, xq_)
    print(f"pos_cis shape after alignment: {pos_cis.shape}\n") # (1, seqlen, 1, dim//2)
    # After broadcasting, torch.view_as_real(xq_ * pos_cis) expands each complex value back into its real/imaginary pair: (,dim//2) -> (,dim//2,2)
    # flatten(3) merges dimension 3 and everything after it: (,dim//2,2) -> (,dim//2*2) -> (,dim)
    xq_pos = torch.view_as_real(xq_ * pos_cis).flatten(3).float()
    xk_pos = torch.view_as_real(xk_ * pos_cis).flatten(3).float()
    print(f"xq_pos shape: {xq_pos.shape}")
    
    return xq_pos.type_as(xq), xk_pos.type_as(xk)
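
# A quick shape sanity check for the two RoPE helpers above (illustrative only,
# not part of the original post); uncomment to run:
# _pos = precompute_pos_cis(dim=8, end=4)   # (4, 4) complex angles
# _q = torch.randn(1, 4, 2, 8)              # (batch, seqlen, heads, head_dim)
# _k = torch.randn(1, 4, 2, 8)
# _q_rot, _k_rot = apply_RoPE(_q, _k, _pos)
# assert _q_rot.shape == _q.shape and _k_rot.shape == _k.shape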
    
def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor:
    """torch.repeat_interleave(x, dim=2, repeats=n_rep)"""
    bs, slen, n_kv_heads, head_dim = x.shape
    if n_rep == 1:
        return x # (batch, seqlen, kv_heads, dim)
    return (
        x[:, :, :, None, :] # (batch, seqlen, kv_heads, 1, head_dim)
        .expand(bs, slen, n_kv_heads, n_rep, head_dim) # (batch, seqlen, kv_heads, n_rep, head_dim)
        .reshape(bs, slen, n_kv_heads * n_rep, head_dim) # (batch, seqlen, kv_heads * n_rep, head_dim)
    )

    
class Attention(nn.Module):
    def __init__(self, args: LMConfig):
        super().__init__()
        # 1. Configure the number of q/k/v heads and the per-head dimension
        self.n_kv_heads = args.n_kv_heads if args.n_kv_heads else args.n_heads
        self.n_q_heads = args.n_heads
        assert self.n_q_heads % self.n_kv_heads == 0
        self.n_rep = self.n_q_heads // self.n_kv_heads
        self.n_dim = args.dim // args.n_heads
        
        # 2. Define the q/k/v projection matrices and the output projection wo
        self.wq = nn.Linear(args.dim, self.n_dim * self.n_q_heads)
        self.wk = nn.Linear(args.dim, self.n_dim * self.n_kv_heads)
        self.wv = nn.Linear(args.dim, self.n_dim * self.n_kv_heads)
        self.wo = nn.Linear(self.n_dim * self.n_q_heads, args.dim)
        
        # 3.Dropout
        self.atten_dropout = nn.Dropout(args.dropout)
        self.resid_dropout = nn.Dropout(args.dropout)
        
        # 4. Build the causal mask: positions above the diagonal (the future) are set to -inf
        mask = torch.full((1, 1, args.max_seq_len, args.max_seq_len), float('-inf'))
        mask = torch.triu(mask, diagonal=1)
        print(f"Causal mask: {mask[0]}\n")
        self.register_buffer("mask", mask, persistent = False)

    def forward(self, x: torch.Tensor, 
                pos_cis: torch.Tensor, 
                past_kv: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
                use_cache = False
                ):     
        # 1. Project x through the q/k/v weight matrices, then reshape to split the heads;
        #    note seqlen is the length of the current input x
        batch, seqlen, _ = x.shape
        xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)
        n_q_shape_tuple = (batch, seqlen, self.n_q_heads, self.n_dim)
        n_kv_shape_tuple = (batch, seqlen, self.n_kv_heads, self.n_dim)
        xq = xq.view(*n_q_shape_tuple)
        xk = xk.view(*n_kv_shape_tuple)
        xv = xv.view(*n_kv_shape_tuple)
        
        # 2. Apply rotary position embedding
        xq, xk = apply_RoPE(xq=xq,xk=xk,pos_cis=pos_cis)
        
        # 3. Transpose q/k/v to (batch, heads, seqlen, head_dim) so the dot product runs over the sequence
        xq, xk, xv = (
            xq.transpose(1, 2),
            repeat_kv(xk, self.n_rep).transpose(1, 2),
            repeat_kv(xv, self.n_rep).transpose(1, 2)
        )
        
        # 4. Compute Q·K^T, apply the causal mask, softmax, then take the weighted sum over V
        scores = (xq @ xk.transpose(-2, -1)) / math.sqrt(self.n_dim) # (batch, q_heads, seqlen, seqlen)
        scores = scores + self.mask[:, :, :seqlen, :seqlen] # mask out future positions
        scores = F.softmax(scores.float(), dim=-1).type_as(xq)
        scores = self.atten_dropout(scores)
        output = scores @ xv
        
        output = output.transpose(1, 2).reshape(batch, seqlen, -1)
        output = self.resid_dropout(self.wo(output))
        return output, past_kv
            

class FeedForward(nn.Module):
    def __init__(self, args: LMConfig):
        super().__init__()
        if args.hidden_dim is None:
            hidden_dim = args.dim * 4
            hidden_dim = int(2 * hidden_dim / 3)
            args.hidden_dim = args.multiple_of * ((hidden_dim + args.multiple_of - 1) // args.multiple_of) 
        self.w1 = nn.Linear(args.dim, args.hidden_dim, bias=False)
        self.w2 = nn.Linear(args.hidden_dim, args.dim, bias=False)
        self.w3 = nn.Linear(args.dim, args.hidden_dim, bias=False)
        self.dropout = nn.Dropout(args.dropout)
        
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # SwiGLU: the gate silu(w1(x)) multiplies w3(x) before the down-projection w2
        return self.dropout(self.w2(F.silu(self.w1(x)) * self.w3(x)))
    
class MiniModelBlock(nn.Module):
    def __init__(self, layer_id: int, args: LMConfig):
        super().__init__()
        self.n_heads = args.n_heads
        self.dim = args.dim
        self.head_dim = args.dim // args.n_heads
        self.attention = Attention(args)

        self.layer_id = layer_id
        self.attention_norm = RMSNorm(args.dim, eps=args.norm_eps) # pre-norm on the attention input
        self.ffn_norm = RMSNorm(args.dim, eps=args.norm_eps) # pre-norm on the FFN input
        self.feed_forward = FeedForward(args)

    def forward(self, x, pos_cis, past_key_value=None, use_cache=False):
        h_attn, _ = self.attention(
            self.attention_norm(x), # normalize before feeding attention
            pos_cis,
            past_kv=past_key_value,
            use_cache=use_cache
        )
        h = x + h_attn # residual connection around attention
        out = h + self.feed_forward(self.ffn_norm(h)) # normalize the FFN input, then add a residual around its output
        return out
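
Below is a minimal usage sketch, not part of the original code: it wires the pieces above together with a hypothetical stand-in config object in place of the real LMConfig (only the fields the modules above actually read are filled in), so adapt it to your own LMConfig.

python
import torch
from types import SimpleNamespace

# Hypothetical stand-in for LMConfig; the field names follow what the modules above read.
cfg = SimpleNamespace(
    dim=64,           # model width
    n_heads=8,        # number of query heads
    n_kv_heads=4,     # number of key/value heads (grouped-query attention)
    max_seq_len=128,  # longest sequence the causal mask covers
    dropout=0.0,
    norm_eps=1e-6,
    hidden_dim=None,  # let FeedForward derive it from dim and multiple_of
    multiple_of=64,
)

block = MiniModelBlock(layer_id=0, args=cfg)
x = torch.randn(2, 16, cfg.dim)  # (batch, seqlen, dim)
# RoPE frequencies are precomputed for the per-head dimension, then sliced to the current seqlen
pos_cis = precompute_pos_cis(cfg.dim // cfg.n_heads, end=cfg.max_seq_len)[:16]
out = block(x, pos_cis)
print(out.shape)  # torch.Size([2, 16, 64])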
        