大模型训练 Learning rate warmup, cosine decay and gradient clipping

1. 学习率Warm up

在训练复杂的模型时,使用学习率热身可以帮助训练稳定。在学习率热身中,我们逐渐增加学习率,从一个非常低的值inital_lr逐渐到用户定义的最大学习率peak_lr

python 复制代码
n_epochs = 15
initial_lr = 0.0001
peak_lr = 0.01

total_steps = len(train_loader) * n_epochs
warmup_steps = int(0.2 * total_steps) # 20% warmup
print(warmup_steps)

2. 余弦退火

在达到最高学习率后,不断降低到min_lr,这是通过余弦函数来实现的,最开始的余弦函数是cos0=1,到最后是cospi = -1,随着迭代次数增加,学习率慢慢递减。

python3 复制代码
import math

min_lr = 0.1 * initial_lr
track_lrs = []

lr_increment = (peak_lr - initial_lr) / warmup_steps
global_step = -1

for epoch in range(n_epochs):
    for input_batch, target_batch in train_loader:
        optimizer.zero_grad()
        global_step += 1
    
        # Adjust the learning rate based on the current phase (warmup or cosine annealing)
        if global_step < warmup_steps:
            # Linear warmup
            lr = initial_lr + global_step * lr_increment  
        else:
            # Cosine annealing after warmup
            progress = ((global_step - warmup_steps) / 
                        (total_training_steps - warmup_steps))
            lr = min_lr + (peak_lr - min_lr) * 0.5 * (
                1 + math.cos(math.pi * progress))
        
        # Apply the calculated learning rate to the optimizer
        for param_group in optimizer.param_groups:
            param_group["lr"] = lr
        track_lrs.append(optimizer.param_groups[0]["lr"])
    
        # Calculate loss and update weights

3. 梯度裁剪

python3 复制代码
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

使用clip_grad_norm可以根据L2函数,将梯度的L2范数裁剪到max_norm,方法是直接除。

相关推荐
boonya7 分钟前
从阿里云大模型服务平台百炼看AI应用集成与实践
人工智能·阿里云·云计算
amhjdx10 分钟前
三维技术 + AI 动画,焕活古镇科技人文新表达,天南文化助力 2025 年世界互联网大会乌镇峰会
人工智能·科技
鹿子沐19 分钟前
LLamaFactory模型导出量化
人工智能·语言模型
skywalk816322 分钟前
尝试Auto-coder.chat使用星河社区AIStudio部署的几个大模型:文心4.5-21b、Deepseek r1 70b、llama 3.1 8b
linux·服务器·人工智能·大模型·aistudio
鹿子沐26 分钟前
LlamaFactory微调效果与vllm部署效果不一致
人工智能·llama
Akamai中国1 小时前
AI 边缘计算:决胜未来
人工智能·云计算·边缘计算·云服务
陈增林1 小时前
基于PyQt5的AI文档处理工具
人工智能
BeingACoder1 小时前
【SAA】SpringAI Alibaba学习笔记(二):提示词Prompt
java·人工智能·spring boot·笔记·prompt·saa·springai
Acrelhuang1 小时前
覆盖全场景需求:Acrel-1000 变电站综合自动化系统的技术亮点与应用
大数据·网络·人工智能·笔记·物联网
LHZSMASH!1 小时前
神经流形:大脑功能几何基础的革命性视角
人工智能·深度学习·神经网络·机器学习