计算机基础·cs336·损失函数,优化器,调度器,数据处理和模型加载保存

Entropy:

H ( x ) = x ⋅ l o g x H(x)=x\cdot logx H(x)=x⋅logx

BCE:二分类损失

L = − y ⋅ l o g y p r e d + ( 1 − y ) ⋅ l o g ( 1 − y p r e d ) L=- y\cdot logy_{pred}+(1-y) \cdot log(1-y_{pred}) L=−y⋅logypred+(1−y)⋅log(1−ypred)

  • 一句话概括:y为正,看预测正类的概率;y为负类,预测负类的概率
  • 注意有一个负号

Cross Entropy:

朴素实现

  • y p r e d y_{pred} ypred:(b,...,n)
  • y y y:(b,...)
  • 将y零一向量化至(b,...,n)维度,然后对于每一个维度运行BCE二分类损失算法

缺点:无用计算,大部分维度都是0,只有维度为1的情况才有运算

Log-sum-exp实现

  • CE等价于对于标签所在维度进行熵的运算
  • 经过(2)的等价操作,我们只需要计算(3)中的log-sum-exp项 ,然后减去标签维度的输出 (注意这里不是概率)即可。
    ℓ = − log ⁡ ( S o f t m a x ( o ) y ) = − log ⁡ ( exp ⁡ ( o y ) ∑ j exp ⁡ ( o j ) ) ℓ = − ( log ⁡ ( exp ⁡ ( o y ) ) − log ⁡ ∑ j exp ⁡ ( o j ) ) ℓ = log ⁡ ( ∑ j exp ⁡ ( o j ) ) ⏟ LogSumExp 项 − o y \begin{align} &\ell=-\log\left(\mathrm{Softmax}(o)y\right)=-\log\left(\frac{\exp(o_y)}{\sum_j\exp(o_j)}\right)\\ & \ell=-\left(\log(\exp(o{y}))-\log\sum_{j}\exp(o_{j})\right) \\ & \ell=\underbrace{\log\left(\sum_{j}\exp(o_{j})\right)}{\text{LogSumExp 项}}-o{y} \end{align} ℓ=−log(Softmax(o)y)=−log(∑jexp(oj)exp(oy))ℓ=−(log(exp(oy))−logj∑exp(oj))ℓ=LogSumExp 项 log(j∑exp(oj))−oy
  • 但是直接计算log-sum-exp项的数值不稳定,例如某些输出可能很大>89,造成inf移除,有的情况下,输出都很小,导致分母接近于0。
  • 我们对每一个输出的结果减去最大值 m m m得到稳定计算log-sum-exp项的公式:
    L o g S u m E x p ( o ) = log ⁡ ( ∑ exp ⁡ ( o j − M + M ) ) L o g S u m E x p ( o ) = log ⁡ ( exp ⁡ ( M ) ⋅ ∑ exp ⁡ ( o j − M ) ) L o g S u m E x p ( o ) = M + log ⁡ ∑ exp ⁡ ( o j − M ) \begin{align}\begin{gathered} &\mathrm{LogSumExp}(o)=\log\left(\sum\exp(o_{j}-M+M)\right) \\ &\mathrm{LogSumExp}(o)=\log\left(\exp(M)\cdot\sum\exp(o_{j}-M)\right) \\ &\mathrm{LogSumExp}(o)=M+\log\sum\exp(o_{j}-M) \end{gathered}\end{align} LogSumExp(o)=log(∑exp(oj−M+M))LogSumExp(o)=log(exp(M)⋅∑exp(oj−M))LogSumExp(o)=M+log∑exp(oj−M)
python 复制代码
def cross_entropy_loss(logits,targets):
    """
    logits : b ... n
    targets : b ... 
    """
    m = torch.max(logits,dim=-1,keepdim=True).values# b ... 1
    log_sum_exp = torch.log(torch.sum(torch.exp(logits-m),dim=-1,keepdim=True))# b ... 1
    log_sum_exp = log_sum_exp + m # b ... 1
    target_logits = torch.gather(logits,dim=-1,index=targets.unsqueeze(-1))# b ... 1
    loss = log_sum_exp - target_logits # b ... 1
    return loss.mean()# 对batch和seq维度求平均

Adam和AdamW优化器原理

  • m t m_t mt代表第t次优化的动量,就是历史梯度的加权一阶矩估计
  • v t v_t vt代表第t次优化的缩放系数,就是历史梯度的加权二阶矩估计
    m t = β 1 m t − 1 + ( 1 − β 1 ) g t v t = β 2 v t − 1 + ( 1 − β 2 ) g t 2 \begin{align} m_t&=\beta_1m_{t-1}+(1-\beta_1)g_t \\ v_{t}&=\beta_{2}v_{t-1}+(1-\beta_{2})g_{t}^{2} \end{align} mtvt=β1mt−1+(1−β1)gt=β2vt−1+(1−β2)gt2
  • 冷启动:当优化次数t区域无穷时,没有影响;当t较小时,原来的动力和缩放系数会被适当增大
    m ^ t = m t 1 − β 1 t , v ^ t = v t 1 − β 2 t \hat{m}_t=\frac{m_t}{1-\beta_1^t},\quad\hat{v}_t=\frac{v_t}{1-\beta_2^t} m^t=1−β1tmt,v^t=1−β2tvt

AdamW优化器

  • 在原有基础上解决了权重衰减的一些问题,在更新时单独减去权重衰减 η λ θ t \eta\lambda\theta_t ηλθt系数。
  • 参数更新公式:先更新Adam,再更新权重衰减
    Δ θ t = η m ^ t v ^ t + ϵ θ t + 1 = θ t − Δ θ t − η λ θ t ⏟ 解耦的衰减项 \Delta\theta_t=\eta\frac{\hat{m}t}{\sqrt{\hat{v}t}+\epsilon}\\ \theta{t+1}=\theta_t-\Delta\theta_t-\underbrace{\eta\lambda\theta_t}{\text{解耦的衰减项}} Δθt=ηv^t +ϵm^tθt+1=θt−Δθt−解耦的衰减项 ηλθt

Optimizer的原理

  • self.param_groups:一般只有一个,就对应你传入的parameters()和学习率参数
  • super().init(params,defaults)。
python 复制代码
from torch.optim import Optimizer
"""Adam的实现,无性能优化版本"""
class AdamW(Optimizer):
    def __init__(self,params,lr=1e-3,betas=(0.9,0.999),eps=1e-8,weight_decay=0.01):
        defaults = dict(lr=lr,betas=betas,eps=eps,weight_decay=weight_decay)
        super().__init__(params,defaults)# defaults 会被复制到每个 param_group
        # self.state(dict),self.param_groups(dict) 
    def step(self,closure=None):
        # 标准接口
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()
        for group in self.param_groups:
            params = group['params']
            lr = group['lr']
            beta1, beta2 = group['betas']
            eps = group['eps']
            weight_decay = group['weight_decay']
            for p in params:
                if p.grad is None:
                    continue
                
                # g_t
                grad = p.grad.data

                state = self.state[p] 
                if len(state)==0:
                    state['t'] = 0
                    # memory_format=torch.preserve_format用于内存优化
                    state['m'] = torch.zeros_like(p,memory_format=torch.preserve_format)
                    state['v'] = torch.zeros_like(p,memory_format=torch.preserve_format)
                
                # Adam 更新公式:分别更新一阶矩和二阶矩
                state['t'] += 1
                state['m'] = beta1*state['m'] + (1-beta1)*grad
                state['v'] = beta2*state['v'] + (1-beta2)*grad*grad

                # Cold Start:不能保存
                m_hat = state['m']/(1-beta1**state['t'])
                v_hat = state['v']/(1-beta2**state['t'])

                # 先使用Adam公式,再使用权重衰减
                adam_term = lr*m_hat/(torch.sqrt(v_hat)+eps)

                # 权重衰减
                if weight_decay is not None:
                    weight_decay_term = lr*weight_decay*p.data
                else: 
                    weight_decay_term = 0
                
                p.data -= adam_term + weight_decay_term
        return loss

可变学习率

在训练函数中使用下列函数获得新的学习率,然后替代优化器中的学习率即可

α ( t ) = { α max ⁡ ⋅ t T w , 0 ≤ t < T w α min ⁡ + 1 2 ( 1 + cos ⁡ ( π ⋅ t − T w T c − T w ) ) ( α max ⁡ − α min ⁡ ) , T w ≤ t ≤ T c α min ⁡ , t > T c \alpha(t)= \begin{cases} \alpha_{\max}\cdot\frac{t}{T_w}, & 0\leq t<T_w \\ \alpha_{\min}+\frac{1}{2}\left(1+\cos\left(\pi\cdot\frac{t-T_w}{T_c-T_w}\right)\right)(\alpha_{\max}-\alpha_{\min}), & T_w\leq t\leq T_c \\ \alpha_{\min}, & t>T_c & \end{cases} α(t)=⎩ ⎨ ⎧αmax⋅Twt,αmin+21(1+cos(π⋅Tc−Twt−Tw))(αmax−αmin),αmin,0≤t<TwTw≤t≤Tct>Tc

Warmup

  • 一开始模型会进行热身,学习率会比较低,缓慢提高,防止随机初始化的参数在遇到高学习率时表现非常不稳定的情况

Cosine退火

  • 在稳定更新参数的时候,学习率缓慢减小

训练尾部

  • 保持较低学习率
python 复制代码
import math
def get_lr_cosine_schedule_with_warmup(
    it,
    max_lr,
    min_lr,
    warmup_iters,
    cosine_schedule_iters,
):
    # 1. warmup 阶段
    if it < warmup_iters:
        return max_lr * it / warmup_iters

    # 2. 退火结束后
    if it > cosine_schedule_iters:
        return min_lr

    # 3. cosine decay 阶段
    decay_ratio = (it - warmup_iters) / (cosine_schedule_iters - warmup_iters)
    coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio))
    lr = min_lr + coeff * (max_lr - min_lr)
    return lr

梯度裁剪

  • 为了避免模型的梯度爆炸,有时需要对模型参数的梯度进行一定裁剪,保持安全值范围内容
  • 对所有参数的梯度进行norm-2计算,然后进行裁剪
    g n e w = g o l d × M ∥ g ∥ 2 + ϵ g_{new}=g_{old}\times\frac{M}{\|g\|_2+\epsilon} gnew=gold×∥g∥2+ϵM
python 复制代码
def clip_grad_norm(parameters,max_norm,eps=1e-5):
    parameters = [p for p in parameters if p.grad is not None ]
    if parameters is None:
        return 
    
    total_norm = 0
    for p in parameters:
        # 在step之前调用
        param_norm = torch.norm(p.grad.detach(),2)
        total_norm += param_norm.item() ** 2 # sum{p^2}

    total_norm = total_norm ** 0.5 
    if total_norm > max_norm:
        for p in parameters:
            p.grad.detach().mul_(max_norm/(total_norm+eps))

数据处理

  • 训练任务:预测下一个词
  • 训练集和标签定义: 训练集 [ x i , . . . x n ] [x_i,...x_n] [xi,...xn],标签: [ x i + 1 , . . . x n ] [x_{i+1},...x_n] [xi+1,...xn]
  • 不要直接使用list读取完整文本,使用np.memmap建立内存映射,需要时候读取!
python 复制代码
import numpy as np
import numpy.typing as npt
import torch
def get_batch(dataset:npt.NDArray,batch_size,max_seq_length,device):
    # dataset:ndarray(list)
    n = len(dataset)
    # max_idx max_idx + max_seq_length-1 + 1 <= len(dataset)-1 
    max_idx = n - max_seq_length - 1
    start_indices = np.random.randint(0,max_idx+1,size=(batch_size,))

    # torch.stack:从某一个维度堆叠tensor
    x_batch = torch.stack(
        [torch.from_numpy(dataset[st_idx:st_idx+max_seq_length])
        for st_idx in start_indices]
    )
    y_batch = torch.stack(
        [torch.from_numpy(dataset[st_idx+1:st_idx+max_seq_length+1])
        for st_idx in start_indices]
    )

    return x_batch.to(device),y_batch.to(device)

模型保存和存储

  • 分别保存model,optimizer的state_dict(),还有保存当前的iteration(主要用于更新学习率调度)
python 复制代码
def save_checkpoint(model,optimizer,iteration,save_path):
    checkpoint = {
        'model': model.state_dict(),
        'optimizer': optimizer.state_dict(),
        'iteration': iteration
    }
    torch.save(checkpoint,save_path)    
python 复制代码
def load_checkpoint(model,optimizer,filepath,device='cpu'):
    checkpoint = torch.load(filepath,map_location=device)
    model.load_state_dict(checkpoint['model'])
    optimizer.load_state_dict(checkpoint['optimizer'])
    iteration = checkpoint['iteration']
    return iteration
相关推荐
asheuojj3 小时前
2026年GEO优化获客效果评估指南:如何精准衡量TOP5关
大数据·人工智能·python
多恩Stone3 小时前
【RoPE】Flux 中的 Image Tokenization
开发语言·人工智能·python
callJJ3 小时前
Spring AI ImageModel 完全指南:用 OpenAI DALL-E 生成图像
大数据·人工智能·spring·openai·springai·图像模型
铁蛋AI编程实战3 小时前
2026 大模型推理框架测评:vLLM 0.5/TGI 2.0/TensorRT-LLM 1.8/DeepSpeed-MII 0.9 性能与成本防线对比
人工智能·机器学习·vllm
23遇见3 小时前
CANN ops-nn 仓库高效开发指南:从入门到精通
人工智能
SAP工博科技3 小时前
SAP 公有云 ERP 多工厂多生产线数据统一管理技术实现解析
大数据·运维·人工智能
芷栀夏3 小时前
CANN ops-math:异构计算场景下基础数学算子的深度优化与硬件亲和设计解析
人工智能·cann
爱吃泡芙的小白白3 小时前
深入解析CNN中的BN层:从稳定训练到前沿演进
人工智能·神经网络·cnn·梯度爆炸·bn·稳定模型
聆风吟º3 小时前
CANN runtime 性能优化:异构计算下运行时组件的效率提升与资源利用策略
人工智能·深度学习·神经网络·cann