```python
import math
from typing import Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F

from LMConfig import LMConfig
class RMSNorm(nn.Module):
    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(dim))
        self.eps = eps

    def _norm(self, x):
        # scale each vector by the reciprocal of its root-mean-square
        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)

    def forward(self, x):
        # normalize in float32 for stability, then cast back to the input dtype
        return (self.weight * self._norm(x.float())).type_as(x)
def precompute_pos_cis(dim: int, end: int = int(32 * 1024), freq_theta=1e6):
    # frequency for each pair of channels, derived from the (head) dimension
    freq = 1.0 / (freq_theta ** (torch.arange(0, dim, 2)[: dim // 2].float() / dim))
    seqlen = torch.arange(end, device=freq.device)  # token positions 0 .. end-1
    # outer product gives (end, dim//2): every position gets its own angle vector
    freq = torch.outer(seqlen, freq).float()
    # complex tensor (end, dim//2); each element is cos(freq) + i*sin(freq), i.e. unit magnitude
    pos_cis = torch.polar(torch.ones_like(freq), freq)
    print(f"complex frequency tensor - shape: {pos_cis.shape}\n")
    print(f"complex frequency tensor - position 0: {pos_cis[0]}\n")
    print(f"complex frequency tensor - position 1: {pos_cis[1]}\n")
    return pos_cis
def apply_RoPE(xq: torch.Tensor, xk: torch.Tensor, pos_cis: torch.Tensor):
    """Apply rotary position embedding (RoPE) to queries and keys.

    Args:
        xq: (batch, seqlen, q_heads, head_dim)
        xk: (batch, seqlen, kv_heads, head_dim)
        pos_cis: (seqlen, head_dim // 2) complex tensor from precompute_pos_cis
    """
    def unite_shape(pos_cis, x):
        # reshape pos_cis to (1, seqlen, 1, head_dim//2) so it broadcasts over batch and heads
        x_dim = x.ndim
        assert 1 < x_dim
        assert pos_cis.shape == (x.shape[1], x.shape[-1])
        shape = [d if i == 1 or i == x.ndim - 1 else 1 for i, d in enumerate(x.shape)]
        return pos_cis.view(*shape)

    # reshape (batch, seqlen, heads, head_dim) -> (batch, seqlen, heads, head_dim//2, 2),
    # then view_as_complex folds each trailing pair into one complex number,
    # giving (batch, seqlen, heads, head_dim//2)
    print(f"xq shape: {xq.shape}\n")
    xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))  # (batch, seqlen, q_heads, head_dim//2)
    print(f"xq_ shape: {xq_.shape}\n")
    xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
    pos_cis = unite_shape(pos_cis, xq_)
    print(f"aligned frequency tensor shape: {pos_cis.shape}\n")  # (1, seqlen, 1, head_dim//2)
    # complex multiplication rotates each (even, odd) channel pair; view_as_real expands
    # (..., head_dim//2) -> (..., head_dim//2, 2), and flatten(3) merges dim 3 and beyond
    # back into (..., head_dim)
    xq_pos = torch.view_as_real(xq_ * pos_cis).flatten(3)
    xk_pos = torch.view_as_real(xk_ * pos_cis).flatten(3)
    print(f"xq_pos shape: {xq_pos.shape}")
    return xq_pos.type_as(xq), xk_pos.type_as(xk)
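# Example of apply_RoPE shapes (a sketch; 2 query heads, 1 kv head, head_dim=8 are toy values):
#   xq = torch.randn(1, 4, 2, 8); xk = torch.randn(1, 4, 1, 8)
#   pos_cis = precompute_pos_cis(dim=8, end=4)      # (4, 4) matches (seqlen, head_dim // 2)
#   xq_r, xk_r = apply_RoPE(xq, xk, pos_cis)        # shapes unchanged: (1, 4, 2, 8) and (1, 4, 1, 8)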
def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor:
    """Equivalent to torch.repeat_interleave(x, dim=2, repeats=n_rep)."""
    bs, slen, n_kv_heads, head_dim = x.shape
    if n_rep == 1:
        return x  # (batch, seqlen, kv_heads, head_dim)
    return (
        x[:, :, :, None, :]                               # (batch, seqlen, kv_heads, 1, head_dim)
        .expand(bs, slen, n_kv_heads, n_rep, head_dim)    # (batch, seqlen, kv_heads, n_rep, head_dim)
        .reshape(bs, slen, n_kv_heads * n_rep, head_dim)  # (batch, seqlen, kv_heads * n_rep, head_dim)
    )
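# repeat_kv duplicates each kv head so grouped-query attention can share it across query heads, e.g.:
#   x = torch.randn(1, 4, 2, 8)    # 2 kv heads
#   repeat_kv(x, 4).shape          # torch.Size([1, 4, 8, 8]): each kv head repeated 4 times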
class Attention(nn.Module):
    def __init__(self, args: LMConfig):
        super().__init__()
        # 1. number of query / key-value heads and the per-head dimension
        self.n_kv_heads = args.n_kv_heads if args.n_kv_heads else args.n_heads
        self.n_q_heads = args.n_heads
        assert self.n_q_heads % self.n_kv_heads == 0
        self.n_rep = self.n_q_heads // self.n_kv_heads  # how many query heads share one kv head
        self.n_dim = args.dim // args.n_heads            # per-head dimension
        # 2. projection matrices for q, k, v and the output projection o
        self.wq = nn.Linear(args.dim, self.n_dim * self.n_q_heads)
        self.wk = nn.Linear(args.dim, self.n_dim * self.n_kv_heads)
        self.wv = nn.Linear(args.dim, self.n_dim * self.n_kv_heads)
        self.wo = nn.Linear(self.n_dim * self.n_q_heads, args.dim)
        # 3. dropout on the attention weights and on the residual output
        self.atten_dropout = nn.Dropout(args.dropout)
        self.resid_dropout = nn.Dropout(args.dropout)
        # 4. causal mask: the upper triangle (future positions) is -inf
        mask = torch.full((1, 1, args.max_seq_len, args.max_seq_len), float('-inf'))
        mask = torch.triu(mask, diagonal=1)
        print(f"causal mask - first row: {mask[0, 0, 0]}\n")
        self.register_buffer("mask", mask, persistent=False)
    def forward(self, x: torch.Tensor,
                pos_cis: torch.Tensor,
                past_kv: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
                use_cache=False
                ):
        # 1. project x to q, k, v; seqlen is the length of the current input chunk
        batch, seqlen, _ = x.shape
        xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)
        n_q_shape_tuple = (batch, seqlen, self.n_q_heads, self.n_dim)
        n_kv_shape_tuple = (batch, seqlen, self.n_kv_heads, self.n_dim)
        xq = xq.view(*n_q_shape_tuple)
        xk = xk.view(*n_kv_shape_tuple)
        xv = xv.view(*n_kv_shape_tuple)
        # 2. apply rotary position embedding to q and k
        xq, xk = apply_RoPE(xq=xq, xk=xk, pos_cis=pos_cis)
        # 3. KV cache: prepend keys/values from earlier steps; return the updated cache if requested
        if past_kv is not None:
            xk = torch.cat([past_kv[0], xk], dim=1)
            xv = torch.cat([past_kv[1], xv], dim=1)
        past_kv = (xk, xv) if use_cache else None
        # 4. repeat kv heads to match the query heads, then move heads before seqlen for the matmul
        xq, xk, xv = (
            xq.transpose(1, 2),
            repeat_kv(xk, self.n_rep).transpose(1, 2),
            repeat_kv(xv, self.n_rep).transpose(1, 2)
        )
        # 5. scaled dot product QK^T, causal mask, softmax, then the weighted sum of v
        scores = (xq @ xk.transpose(-2, -1)) / math.sqrt(self.n_dim)  # (batch, q_heads, seqlen, seqlen)
        # add -inf above the diagonal; valid for prefill without a cache, and for
        # single-token decoding where the (1, 1, 1, 1) slice simply broadcasts
        scores = scores + self.mask[:, :, :seqlen, :seqlen]
        scores = F.softmax(scores.float(), dim=-1).type_as(xq)
        scores = self.atten_dropout(scores)
        output = scores @ xv
        output = output.transpose(1, 2).reshape(batch, seqlen, -1)
        output = self.resid_dropout(self.wo(output))
        return output, past_kv
class FeedForward(nn.Module):
    def __init__(self, args: LMConfig):
        super().__init__()
        if args.hidden_dim is None:
            # SwiGLU convention: start from 4*dim, take 2/3 of it,
            # then round up to a multiple of args.multiple_of
            hidden_dim = args.dim * 4
            hidden_dim = int(2 * hidden_dim / 3)
            args.hidden_dim = args.multiple_of * ((hidden_dim + args.multiple_of - 1) // args.multiple_of)
        self.w1 = nn.Linear(args.dim, args.hidden_dim, bias=False)
        self.w2 = nn.Linear(args.hidden_dim, args.dim, bias=False)
        self.w3 = nn.Linear(args.dim, args.hidden_dim, bias=False)
        self.dropout = nn.Dropout(args.dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # SwiGLU: gate silu(w1(x)) with w3(x) in the hidden space, then project back with w2
        return self.dropout(self.w2(F.silu(self.w1(x)) * self.w3(x)))
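# Example of the hidden_dim rule (a sketch; dim=512 and multiple_of=64 are assumed toy values):
#   4 * 512 = 2048 -> int(2 * 2048 / 3) = 1365 -> rounded up to a multiple of 64 -> 1408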
class MiniModelBlock(nn.Module):
    def __init__(self, layer_id: int, args: LMConfig):
        super().__init__()
        self.n_heads = args.n_heads
        self.dim = args.dim
        self.head_dim = args.dim // args.n_heads
        self.attention = Attention(args)
        self.layer_id = layer_id
        self.attention_norm = RMSNorm(args.dim, eps=args.norm_eps)  # pre-norm before attention
        self.ffn_norm = RMSNorm(args.dim, eps=args.norm_eps)        # pre-norm before the FFN
        self.feed_forward = FeedForward(args)

    def forward(self, x, pos_cis, past_key_value=None, use_cache=False):
        h_attn, _ = self.attention(
            self.attention_norm(x),  # normalize the input before attention (pre-norm)
            pos_cis,
            past_kv=past_key_value,
            use_cache=use_cache
        )
        h = x + h_attn  # residual connection around attention
        out = h + self.feed_forward(self.ffn_norm(h))  # pre-norm FFN plus residual
        return out
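```

As a minimal end-to-end check, the sketch below pushes a toy batch through one `MiniModelBlock`. It assumes the real `LMConfig` in `LMConfig.py` exposes the fields used above (`dim`, `n_heads`, `n_kv_heads`, `hidden_dim`, `multiple_of`, `norm_eps`, `dropout`, `max_seq_len`); here a small stand-in dataclass with arbitrary toy values is used instead, and the classes are assumed to be run in (or imported from) the same module as the listing above.

```python
from dataclasses import dataclass
from typing import Optional

import torch


@dataclass
class ToyConfig:  # hypothetical stand-in for LMConfig
    dim: int = 64
    n_heads: int = 4
    n_kv_heads: int = 2
    hidden_dim: Optional[int] = None
    multiple_of: int = 64
    norm_eps: float = 1e-6
    dropout: float = 0.0
    max_seq_len: int = 128


args = ToyConfig()
block = MiniModelBlock(layer_id=0, args=args)

x = torch.randn(2, 16, args.dim)  # (batch=2, seqlen=16, dim=64)
# one frequency table per head dimension; slice it to the current sequence length
pos_cis = precompute_pos_cis(dim=args.dim // args.n_heads, end=args.max_seq_len)[:16]

out = block(x, pos_cis)
print(out.shape)  # torch.Size([2, 16, 64]) -- the block preserves the input shape
```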