Adam优化器深度解析:从数学原理到PyTorch源码实

本文深入剖析深度学习中最流行的优化器Adam,从梯度下降的演进历程、Adam的数学原理、偏差修正的必要性,到完整的代码实现和调参技巧,帮你彻底理解这个"万金油"优化器。


一、优化器的演进之路

1.1 为什么需要优化器

深度学习的核心是最小化损失函数

复制代码
目标:找到参数 θ* 使得 L(θ) 最小

θ* = argmin L(θ)
        θ

方法:沿着损失函数下降最快的方向迭代更新参数

1.2 优化器家族图谱

复制代码
优化器演进路线:

                    SGD (随机梯度下降)
                         │
         ┌───────────────┼───────────────┐
         │               │               │
         ▼               ▼               ▼
    Momentum        Adagrad         NAG
    (动量加速)     (自适应学习率)   (Nesterov加速)
         │               │               │
         │               ▼               │
         │           RMSprop            │
         │          (改进Adagrad)       │
         │               │               │
         └───────────────┼───────────────┘
                         │
                         ▼
                       Adam
                (Momentum + RMSprop)
                         │
         ┌───────────────┼───────────────┐
         │               │               │
         ▼               ▼               ▼
      AdamW          NAdam          AMSGrad
   (权重衰减修正)  (Nesterov+Adam)  (修复收敛问题)

1.3 各优化器核心思想

复制代码
┌─────────────────────────────────────────────────────────────────┐
│                      优化器核心思想对比                          │
├─────────────────────────────────────────────────────────────────┤
│                                                                 │
│  SGD:      θ = θ - lr × g                                       │
│            最基础,直接沿梯度方向走                              │
│                                                                 │
│  Momentum: v = β×v + g                                          │
│            θ = θ - lr × v                                       │
│            加入"惯性",加速收敛,减少震荡                       │
│                                                                 │
│  Adagrad:  累积历史梯度平方,自动调整学习率                      │
│            频繁更新的参数 → 小学习率                            │
│            稀疏更新的参数 → 大学习率                            │
│                                                                 │
│  RMSprop:  Adagrad的改进,用指数移动平均代替累积                 │
│            解决学习率单调递减的问题                              │
│                                                                 │
│  Adam:     Momentum + RMSprop                                   │
│            同时利用一阶矩(动量)和二阶矩(自适应学习率)        │
│                                                                 │
└─────────────────────────────────────────────────────────────────┘

二、从SGD到Adam的数学推导

2.1 SGD:最朴素的梯度下降

python 复制代码
# 随机梯度下降
# θ_new = θ_old - learning_rate × gradient

def sgd_update(params, grads, lr):
    """
    SGD更新规则
    
    问题:
    1. 对所有参数使用相同的学习率
    2. 容易在鞍点附近震荡
    3. 收敛速度慢
    """
    for param, grad in zip(params, grads):
        param -= lr * grad
复制代码
SGD的问题可视化:

损失曲面(椭圆形):
                    
        ╭─────────────────╮
       ╱                   ╲
      │    ↘                │
      │      ↘   震荡      │
      │        ↘↗          │
      │         ↘↗         │
      │           ★ 最优点  │
       ╲                   ╱
        ╰─────────────────╯

问题:在陡峭方向震荡,在平缓方向前进慢

2.2 Momentum:加入惯性

python 复制代码
def momentum_update(params, grads, velocities, lr, beta=0.9):
    """
    Momentum更新规则
    
    v = β × v + g          # 速度 = 保留部分旧速度 + 新梯度
    θ = θ - lr × v         # 沿速度方向更新
    
    物理直觉:小球从山上滚下,会积累速度
    - 在一致的梯度方向上加速
    - 在震荡的方向上相互抵消
    """
    for param, grad, v in zip(params, grads, velocities):
        v[:] = beta * v + grad  # 更新速度(一阶矩估计)
        param -= lr * v
复制代码
Momentum的效果:

没有Momentum:           有Momentum:
    ↘                      ↘
      ↗                      ↘
        ↘                      ↘
          ↗                      ↘
            ↘                      → ★
              ★

左边:来回震荡
右边:惯性帮助直奔目标

2.3 Adagrad:自适应学习率

python 复制代码
def adagrad_update(params, grads, cache, lr, eps=1e-8):
    """
    Adagrad更新规则
    
    cache = cache + g²            # 累积历史梯度平方
    θ = θ - lr × g / √(cache+ε)   # 学习率被历史梯度调制
    
    效果:
    - 频繁更新的参数:cache大 → 学习率小 → 稳定
    - 稀疏更新的参数:cache小 → 学习率大 → 快速学习
    
    问题:cache只增不减,学习率会趋近于0
    """
    for param, grad, c in zip(params, grads, cache):
        c[:] = c + grad ** 2
        param -= lr * grad / (np.sqrt(c) + eps)

2.4 RMSprop:改进Adagrad

python 复制代码
def rmsprop_update(params, grads, cache, lr, beta=0.999, eps=1e-8):
    """
    RMSprop更新规则
    
    cache = β × cache + (1-β) × g²    # 指数移动平均,而非累积
    θ = θ - lr × g / √(cache+ε)
    
    改进:用指数衰减代替累积
    - 近期梯度权重大
    - 远期梯度逐渐遗忘
    - 学习率不会无限减小
    """
    for param, grad, c in zip(params, grads, cache):
        c[:] = beta * c + (1 - beta) * grad ** 2
        param -= lr * grad / (np.sqrt(c) + eps)

三、Adam:集大成者

3.1 Adam的核心思想

Adam = Adaptive Moment Estimation = 自适应矩估计

它结合了两个关键技术:

  1. Momentum:利用一阶矩(梯度的指数移动平均)→ 加速收敛

  2. RMSprop:利用二阶矩(梯度平方的指数移动平均)→ 自适应学习率

    Adam的直觉:

    一阶矩 m(动量):梯度的"平均方向"

    • 告诉我们"应该往哪走"
    • 平滑梯度,减少噪声

    二阶矩 v(自适应):梯度的"波动程度"

    • 告诉我们"应该走多大步"
    • 梯度稳定 → 大步走
    • 梯度震荡 → 小步走

    结合:按照平均方向,以自适应的步长前进

3.2 Adam算法公式

复制代码
Adam完整算法:

初始化:
    m₀ = 0      # 一阶矩估计
    v₀ = 0      # 二阶矩估计
    t = 0       # 时间步

每一步更新:
    t = t + 1
    
    # 1. 计算梯度
    g_t = ∇L(θ_{t-1})
    
    # 2. 更新有偏一阶矩估计(动量)
    m_t = β₁ × m_{t-1} + (1 - β₁) × g_t
    
    # 3. 更新有偏二阶矩估计(自适应学习率)
    v_t = β₂ × v_{t-1} + (1 - β₂) × g_t²
    
    # 4. 偏差修正(关键步骤!)
    m̂_t = m_t / (1 - β₁^t)
    v̂_t = v_t / (1 - β₂^t)
    
    # 5. 更新参数
    θ_t = θ_{t-1} - α × m̂_t / (√v̂_t + ε)

默认超参数:
    α = 0.001   (学习率)
    β₁ = 0.9    (一阶矩衰减率)
    β₂ = 0.999  (二阶矩衰减率)
    ε = 1e-8    (数值稳定项)

3.3 为什么需要偏差修正?

这是Adam的关键创新之一,很多人忽略了它的重要性。

复制代码
问题:初始化时 m₀ = 0, v₀ = 0

前几步的m和v会严重偏向0(有偏估计)

数学证明:
假设梯度g的期望是E[g],方差是Var[g]

第t步的m_t期望:
E[m_t] = E[(1-β₁) × Σᵢ β₁^{t-i} × gᵢ]
       = (1-β₁) × Σᵢ β₁^{t-i} × E[g]
       = (1 - β₁^t) × E[g]    # 注意:不是 E[g]!

当t很小时,(1-β₁^t) << 1
例如 t=1, β₁=0.9 时:1-0.9¹ = 0.1
m₁ 只有真实值的 10%!

偏差修正:
m̂_t = m_t / (1 - β₁^t)
E[m̂_t] = E[m_t] / (1 - β₁^t) = E[g] ✓

现在是无偏估计了!

偏差修正的效果可视化:

                未修正的m_t          修正后的m̂_t
                    │                    │
Step 1:  ▏        (很小)         ▉▉▉▉▉▉ (接近真实)
Step 2:  ▎▏                      ▉▉▉▉▉▉
Step 5:  ▍▎▏                     ▉▉▉▉▉▉
Step 10: ▌▍▎▏                    ▉▉▉▉▉▉
Step 50: ▉▉▉▉▉▉                  ▉▉▉▉▉▉
         
随着步数增加,差异减小
但前期差异巨大,不修正会导致训练初期很不稳定

四、完整代码实现

4.1 从零实现Adam

python 复制代码
import numpy as np


class Adam:
    """
    Adam优化器的完整实现
    
    Adam: A Method for Stochastic Optimization
    https://arxiv.org/abs/1412.6980
    """
    
    def __init__(self, lr=0.001, beta1=0.9, beta2=0.999, eps=1e-8):
        """
        Args:
            lr: 学习率 α
            beta1: 一阶矩衰减率 β₁(动量)
            beta2: 二阶矩衰减率 β₂(自适应学习率)
            eps: 数值稳定项 ε
        """
        self.lr = lr
        self.beta1 = beta1
        self.beta2 = beta2
        self.eps = eps
        
        # 状态变量
        self.m = {}  # 一阶矩
        self.v = {}  # 二阶矩
        self.t = 0   # 时间步
    
    def step(self, params, grads):
        """
        执行一步优化
        
        Args:
            params: 参数字典 {name: param_array}
            grads: 梯度字典 {name: grad_array}
        """
        self.t += 1
        
        for name in params:
            # 初始化矩估计
            if name not in self.m:
                self.m[name] = np.zeros_like(params[name])
                self.v[name] = np.zeros_like(params[name])
            
            g = grads[name]
            
            # 更新有偏矩估计
            self.m[name] = self.beta1 * self.m[name] + (1 - self.beta1) * g
            self.v[name] = self.beta2 * self.v[name] + (1 - self.beta2) * (g ** 2)
            
            # 偏差修正
            m_hat = self.m[name] / (1 - self.beta1 ** self.t)
            v_hat = self.v[name] / (1 - self.beta2 ** self.t)
            
            # 更新参数
            params[name] -= self.lr * m_hat / (np.sqrt(v_hat) + self.eps)
    
    def get_state(self):
        """获取优化器状态(用于保存检查点)"""
        return {
            'm': self.m.copy(),
            'v': self.v.copy(),
            't': self.t
        }
    
    def load_state(self, state):
        """加载优化器状态"""
        self.m = state['m']
        self.v = state['v']
        self.t = state['t']


class AdamWithWeightDecay:
    """
    AdamW: Adam with decoupled weight decay
    
    标准Adam中的L2正则化实际上不等价于权重衰减
    AdamW正确实现了权重衰减
    """
    
    def __init__(self, lr=0.001, beta1=0.9, beta2=0.999, eps=1e-8, weight_decay=0.01):
        self.lr = lr
        self.beta1 = beta1
        self.beta2 = beta2
        self.eps = eps
        self.weight_decay = weight_decay
        
        self.m = {}
        self.v = {}
        self.t = 0
    
    def step(self, params, grads):
        """
        AdamW更新规则
        
        关键区别:权重衰减直接作用于参数,而非通过梯度
        
        Adam with L2:  g' = g + λθ, then Adam update with g'
        AdamW:         Adam update with g, then θ = θ - lr×λ×θ
        """
        self.t += 1
        
        for name in params:
            if name not in self.m:
                self.m[name] = np.zeros_like(params[name])
                self.v[name] = np.zeros_like(params[name])
            
            g = grads[name]
            
            # Adam步骤(不包含权重衰减)
            self.m[name] = self.beta1 * self.m[name] + (1 - self.beta1) * g
            self.v[name] = self.beta2 * self.v[name] + (1 - self.beta2) * (g ** 2)
            
            m_hat = self.m[name] / (1 - self.beta1 ** self.t)
            v_hat = self.v[name] / (1 - self.beta2 ** self.t)
            
            # Adam更新
            params[name] -= self.lr * m_hat / (np.sqrt(v_hat) + self.eps)
            
            # 解耦的权重衰减(AdamW的关键)
            params[name] -= self.lr * self.weight_decay * params[name]

4.2 PyTorch风格实现

python 复制代码
import torch
from torch.optim import Optimizer


class CustomAdam(Optimizer):
    """
    PyTorch风格的Adam实现
    """
    
    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, 
                 weight_decay=0, amsgrad=False):
        """
        Args:
            params: 模型参数
            lr: 学习率
            betas: (β₁, β₂) 矩估计的衰减率
            eps: 数值稳定项
            weight_decay: 权重衰减(L2惩罚)
            amsgrad: 是否使用AMSGrad变体
        """
        if lr < 0.0:
            raise ValueError(f"Invalid learning rate: {lr}")
        if eps < 0.0:
            raise ValueError(f"Invalid epsilon value: {eps}")
        if not 0.0 <= betas[0] < 1.0:
            raise ValueError(f"Invalid beta parameter at index 0: {betas[0]}")
        if not 0.0 <= betas[1] < 1.0:
            raise ValueError(f"Invalid beta parameter at index 1: {betas[1]}")
        
        defaults = dict(lr=lr, betas=betas, eps=eps, 
                       weight_decay=weight_decay, amsgrad=amsgrad)
        super().__init__(params, defaults)
    
    @torch.no_grad()
    def step(self, closure=None):
        """
        执行单步优化
        
        Args:
            closure: 重新计算损失的闭包(可选)
        """
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()
        
        for group in self.param_groups:
            beta1, beta2 = group['betas']
            
            for p in group['params']:
                if p.grad is None:
                    continue
                
                grad = p.grad
                
                if grad.is_sparse:
                    raise RuntimeError('Adam does not support sparse gradients')
                
                amsgrad = group['amsgrad']
                
                # 获取状态
                state = self.state[p]
                
                # 初始化状态
                if len(state) == 0:
                    state['step'] = 0
                    state['exp_avg'] = torch.zeros_like(p)      # m
                    state['exp_avg_sq'] = torch.zeros_like(p)   # v
                    if amsgrad:
                        state['max_exp_avg_sq'] = torch.zeros_like(p)
                
                exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
                if amsgrad:
                    max_exp_avg_sq = state['max_exp_avg_sq']
                
                state['step'] += 1
                
                # 权重衰减(L2正则化方式)
                if group['weight_decay'] != 0:
                    grad = grad.add(p, alpha=group['weight_decay'])
                
                # 更新一阶矩估计
                exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
                
                # 更新二阶矩估计
                exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
                
                if amsgrad:
                    # AMSGrad: 使用历史最大的v
                    torch.maximum(max_exp_avg_sq, exp_avg_sq, out=max_exp_avg_sq)
                    denom = max_exp_avg_sq.sqrt().add_(group['eps'])
                else:
                    denom = exp_avg_sq.sqrt().add_(group['eps'])
                
                # 偏差修正
                bias_correction1 = 1 - beta1 ** state['step']
                bias_correction2 = 1 - beta2 ** state['step']
                
                step_size = group['lr'] * (bias_correction2 ** 0.5) / bias_correction1
                
                # 更新参数
                p.addcdiv_(exp_avg, denom, value=-step_size)
        
        return loss


class CustomAdamW(Optimizer):
    """
    AdamW: 解耦权重衰减的Adam
    
    与标准Adam的区别:
    - Adam: 权重衰减通过梯度实现 g' = g + λθ
    - AdamW: 权重衰减直接作用于参数 θ' = θ - lr×λ×θ
    
    这个区别在自适应学习率优化器中很重要!
    """
    
    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0.01):
        defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay)
        super().__init__(params, defaults)
    
    @torch.no_grad()
    def step(self, closure=None):
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()
        
        for group in self.param_groups:
            beta1, beta2 = group['betas']
            
            for p in group['params']:
                if p.grad is None:
                    continue
                
                grad = p.grad
                state = self.state[p]
                
                if len(state) == 0:
                    state['step'] = 0
                    state['exp_avg'] = torch.zeros_like(p)
                    state['exp_avg_sq'] = torch.zeros_like(p)
                
                exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
                state['step'] += 1
                
                # 更新矩估计(注意:不包含权重衰减!)
                exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
                exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
                
                # 偏差修正
                bias_correction1 = 1 - beta1 ** state['step']
                bias_correction2 = 1 - beta2 ** state['step']
                
                step_size = group['lr'] * (bias_correction2 ** 0.5) / bias_correction1
                
                # Adam更新
                denom = exp_avg_sq.sqrt().add_(group['eps'])
                p.addcdiv_(exp_avg, denom, value=-step_size)
                
                # 解耦的权重衰减(关键区别!)
                p.add_(p, alpha=-group['lr'] * group['weight_decay'])
        
        return loss

4.3 带学习率调度的Adam

python 复制代码
class AdamWithScheduler:
    """
    带学习率调度的Adam
    
    常用调度策略:
    1. Step Decay: 每N步衰减一次
    2. Cosine Annealing: 余弦退火
    3. Warmup: 预热阶段线性增加学习率
    4. Linear Decay: 线性衰减
    """
    
    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8,
                 warmup_steps=1000, total_steps=100000, min_lr=1e-6):
        self.base_lr = lr
        self.min_lr = min_lr
        self.warmup_steps = warmup_steps
        self.total_steps = total_steps
        
        self.optimizer = torch.optim.Adam(params, lr=lr, betas=betas, eps=eps)
        self.current_step = 0
    
    def get_lr(self):
        """
        计算当前学习率
        
        Warmup + Cosine Annealing
        """
        if self.current_step < self.warmup_steps:
            # 线性预热
            return self.base_lr * self.current_step / self.warmup_steps
        else:
            # 余弦退火
            progress = (self.current_step - self.warmup_steps) / \
                      (self.total_steps - self.warmup_steps)
            return self.min_lr + 0.5 * (self.base_lr - self.min_lr) * \
                   (1 + np.cos(np.pi * progress))
    
    def step(self):
        """执行一步优化"""
        # 更新学习率
        lr = self.get_lr()
        for param_group in self.optimizer.param_groups:
            param_group['lr'] = lr
        
        # 执行优化
        self.optimizer.step()
        self.current_step += 1
        
        return lr
    
    def zero_grad(self):
        self.optimizer.zero_grad()


def visualize_lr_schedule():
    """可视化学习率调度"""
    import matplotlib.pyplot as plt
    
    warmup_steps = 1000
    total_steps = 10000
    base_lr = 1e-3
    min_lr = 1e-6
    
    lrs = []
    for step in range(total_steps):
        if step < warmup_steps:
            lr = base_lr * step / warmup_steps
        else:
            progress = (step - warmup_steps) / (total_steps - warmup_steps)
            lr = min_lr + 0.5 * (base_lr - min_lr) * (1 + np.cos(np.pi * progress))
        lrs.append(lr)
    
    plt.figure(figsize=(10, 4))
    plt.plot(lrs)
    plt.xlabel('Step')
    plt.ylabel('Learning Rate')
    plt.title('Warmup + Cosine Annealing Schedule')
    plt.axvline(x=warmup_steps, color='r', linestyle='--', label='Warmup End')
    plt.legend()
    plt.savefig('lr_schedule.png', dpi=150)
    print("Learning rate schedule saved to lr_schedule.png")

五、Adam的变体与改进

5.1 AMSGrad:修复收敛问题

python 复制代码
"""
AMSGrad: 解决Adam的非收敛问题

问题:Adam在某些情况下不能保证收敛到最优解
原因:v_t可能会减小,导致学习率增大

解决:使用历史最大的v_t

v̂_t = max(v̂_{t-1}, v_t)

保证学习率单调不增,确保收敛
"""

class AMSGrad(Optimizer):
    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8):
        defaults = dict(lr=lr, betas=betas, eps=eps)
        super().__init__(params, defaults)
    
    @torch.no_grad()
    def step(self):
        for group in self.param_groups:
            beta1, beta2 = group['betas']
            
            for p in group['params']:
                if p.grad is None:
                    continue
                
                state = self.state[p]
                
                if len(state) == 0:
                    state['step'] = 0
                    state['exp_avg'] = torch.zeros_like(p)
                    state['exp_avg_sq'] = torch.zeros_like(p)
                    state['max_exp_avg_sq'] = torch.zeros_like(p)  # AMSGrad特有
                
                exp_avg = state['exp_avg']
                exp_avg_sq = state['exp_avg_sq']
                max_exp_avg_sq = state['max_exp_avg_sq']
                
                state['step'] += 1
                
                grad = p.grad
                
                exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
                exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
                
                # AMSGrad的关键:取历史最大值
                torch.maximum(max_exp_avg_sq, exp_avg_sq, out=max_exp_avg_sq)
                
                bias_correction1 = 1 - beta1 ** state['step']
                bias_correction2 = 1 - beta2 ** state['step']
                
                step_size = group['lr'] * (bias_correction2 ** 0.5) / bias_correction1
                
                # 使用max_exp_avg_sq而非exp_avg_sq
                denom = max_exp_avg_sq.sqrt().add_(group['eps'])
                p.addcdiv_(exp_avg, denom, value=-step_size)

5.2 NAdam:加入Nesterov动量

python 复制代码
"""
NAdam: Nesterov-accelerated Adam

将Nesterov动量整合到Adam中

Nesterov的核心思想:
先按动量方向走一步,再计算梯度
"向前看",预测未来的梯度

NAdam将这个思想应用于Adam的一阶矩
"""

class NAdam(Optimizer):
    def __init__(self, params, lr=2e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0):
        defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay)
        super().__init__(params, defaults)
    
    @torch.no_grad()
    def step(self):
        for group in self.param_groups:
            beta1, beta2 = group['betas']
            
            for p in group['params']:
                if p.grad is None:
                    continue
                
                state = self.state[p]
                
                if len(state) == 0:
                    state['step'] = 0
                    state['exp_avg'] = torch.zeros_like(p)
                    state['exp_avg_sq'] = torch.zeros_like(p)
                
                exp_avg = state['exp_avg']
                exp_avg_sq = state['exp_avg_sq']
                state['step'] += 1
                
                grad = p.grad
                
                if group['weight_decay'] != 0:
                    grad = grad.add(p, alpha=group['weight_decay'])
                
                # 更新矩估计
                exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
                exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
                
                # 偏差修正
                bias_correction1 = 1 - beta1 ** state['step']
                bias_correction2 = 1 - beta2 ** state['step']
                
                # NAdam的关键:Nesterov风格的一阶矩
                # 不使用 m̂_t,而是使用 β₁m̂_t + (1-β₁)g_t / bias_correction1
                m_hat = exp_avg / bias_correction1
                g_hat = grad / bias_correction1
                
                # Nesterov修正
                nesterov_m = beta1 * m_hat + (1 - beta1) * g_hat
                
                v_hat = exp_avg_sq / bias_correction2
                denom = v_hat.sqrt().add_(group['eps'])
                
                p.addcdiv_(nesterov_m, denom, value=-group['lr'])

5.3 RAdam:自适应学习率方差修正

python 复制代码
"""
RAdam: Rectified Adam

问题:Adam初期方差大,需要warmup

RAdam自动处理这个问题:
- 自动计算方差修正项
- 方差大时退化为SGD with Momentum
- 方差小时变为完整Adam
"""

class RAdam(Optimizer):
    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0):
        defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay)
        super().__init__(params, defaults)
    
    @torch.no_grad()
    def step(self):
        for group in self.param_groups:
            beta1, beta2 = group['betas']
            
            for p in group['params']:
                if p.grad is None:
                    continue
                
                state = self.state[p]
                
                if len(state) == 0:
                    state['step'] = 0
                    state['exp_avg'] = torch.zeros_like(p)
                    state['exp_avg_sq'] = torch.zeros_like(p)
                
                exp_avg = state['exp_avg']
                exp_avg_sq = state['exp_avg_sq']
                state['step'] += 1
                
                grad = p.grad
                
                if group['weight_decay'] != 0:
                    grad = grad.add(p, alpha=group['weight_decay'])
                
                exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
                exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
                
                bias_correction1 = 1 - beta1 ** state['step']
                
                # RAdam的关键:计算最大方差长度
                rho_inf = 2 / (1 - beta2) - 1
                rho_t = rho_inf - 2 * state['step'] * (beta2 ** state['step']) / bias_correction1
                
                if rho_t > 5:
                    # 方差足够小,使用完整Adam
                    bias_correction2 = 1 - beta2 ** state['step']
                    
                    # 方差修正项
                    rect = np.sqrt(
                        (rho_t - 4) * (rho_t - 2) * rho_inf / 
                        ((rho_inf - 4) * (rho_inf - 2) * rho_t)
                    )
                    
                    m_hat = exp_avg / bias_correction1
                    v_hat = exp_avg_sq / bias_correction2
                    denom = v_hat.sqrt().add_(group['eps'])
                    
                    p.addcdiv_(m_hat, denom, value=-group['lr'] * rect)
                else:
                    # 方差太大,退化为SGD with Momentum
                    m_hat = exp_avg / bias_correction1
                    p.add_(m_hat, alpha=-group['lr'])

5.4 AdaFactor:内存高效

python 复制代码
"""
AdaFactor: 内存高效的Adam替代品

问题:Adam需要存储m和v,每个参数需要3倍内存

AdaFactor的解决方案:
- 对于矩阵参数,用行和列的统计量近似完整的v
- 内存从O(n×m)降到O(n+m)
"""

class AdaFactor(Optimizer):
    """
    简化版AdaFactor实现
    
    实际中推荐使用fairseq或transformers库的实现
    """
    
    def __init__(self, params, lr=None, eps=(1e-30, 1e-3), 
                 clip_threshold=1.0, decay_rate=-0.8,
                 beta1=None, weight_decay=0.0, scale_parameter=True,
                 relative_step=True, warmup_init=False):
        
        defaults = dict(lr=lr, eps=eps, clip_threshold=clip_threshold,
                       decay_rate=decay_rate, beta1=beta1,
                       weight_decay=weight_decay, scale_parameter=scale_parameter,
                       relative_step=relative_step, warmup_init=warmup_init)
        super().__init__(params, defaults)
    
    def _get_lr(self, param_group, param_state):
        """自适应学习率"""
        if param_group['relative_step']:
            min_step = 1e-6 * param_state['step'] if param_group['warmup_init'] else 1e-2
            lr = min(min_step, 1.0 / np.sqrt(param_state['step']))
        else:
            lr = param_group['lr']
        
        if param_group['scale_parameter']:
            lr *= max(1e-3, param_state['RMS'])
        
        return lr
    
    def _rms(self, tensor):
        """计算RMS"""
        return tensor.norm(2) / (tensor.numel() ** 0.5)
    
    @torch.no_grad()
    def step(self):
        for group in self.param_groups:
            for p in group['params']:
                if p.grad is None:
                    continue
                
                grad = p.grad
                state = self.state[p]
                
                grad_shape = grad.shape
                factored = len(grad_shape) >= 2
                
                if len(state) == 0:
                    state['step'] = 0
                    state['RMS'] = 0
                    
                    if factored:
                        # 分解存储:行和列
                        state['exp_avg_sq_row'] = torch.zeros(grad_shape[:-1])
                        state['exp_avg_sq_col'] = torch.zeros(grad_shape[:-2] + grad_shape[-1:])
                    else:
                        state['exp_avg_sq'] = torch.zeros_like(grad)
                    
                    if group['beta1'] is not None:
                        state['exp_avg'] = torch.zeros_like(grad)
                
                state['step'] += 1
                state['RMS'] = self._rms(p)
                lr = self._get_lr(group, state)
                
                # 更新二阶矩估计
                decay_rate = group['decay_rate']
                rho = min(lr, 1 / state['step'])
                
                if factored:
                    # 分解更新
                    exp_avg_sq_row = state['exp_avg_sq_row']
                    exp_avg_sq_col = state['exp_avg_sq_col']
                    
                    exp_avg_sq_row.mul_(1 - rho).add_(
                        (grad ** 2).mean(dim=-1), alpha=rho
                    )
                    exp_avg_sq_col.mul_(1 - rho).add_(
                        (grad ** 2).mean(dim=-2), alpha=rho
                    )
                    
                    # 重构完整的v
                    row_col_mean = exp_avg_sq_row.mean(dim=-1, keepdim=True)
                    v = exp_avg_sq_row.unsqueeze(-1) * exp_avg_sq_col.unsqueeze(-2) / row_col_mean.unsqueeze(-1)
                else:
                    exp_avg_sq = state['exp_avg_sq']
                    exp_avg_sq.mul_(1 - rho).add_(grad ** 2, alpha=rho)
                    v = exp_avg_sq
                
                # 更新
                update = grad / (v.sqrt() + group['eps'][0])
                update.mul_(lr)
                
                if group['beta1'] is not None:
                    exp_avg = state['exp_avg']
                    exp_avg.mul_(group['beta1']).add_(update, alpha=1 - group['beta1'])
                    update = exp_avg
                
                p.add_(update, alpha=-1)

六、Adam调参指南

6.1 超参数选择

复制代码
┌─────────────────────────────────────────────────────────────────┐
│                    Adam超参数调参指南                            │
├─────────────────────────────────────────────────────────────────┤
│                                                                 │
│  学习率 (lr):                                                   │
│  ┌─────────────────────────────────────────────────────────┐   │
│  │ 默认值: 1e-3 到 3e-4                                     │   │
│  │ CV任务: 1e-4 到 1e-3                                     │   │
│  │ NLP任务: 1e-5 到 5e-5 (尤其是微调预训练模型)             │   │
│  │ GAN: 1e-4 到 2e-4                                        │   │
│  └─────────────────────────────────────────────────────────┘   │
│                                                                 │
│  β₁ (一阶矩衰减):                                               │
│  ┌─────────────────────────────────────────────────────────┐   │
│  │ 默认值: 0.9                                              │   │
│  │ 较小值 (0.5-0.8): 快速适应新梯度                        │   │
│  │ 较大值 (0.9-0.99): 更平滑的更新                         │   │
│  └─────────────────────────────────────────────────────────┘   │
│                                                                 │
│  β₂ (二阶矩衰减):                                               │
│  ┌─────────────────────────────────────────────────────────┐   │
│  │ 默认值: 0.999                                            │   │
│  │ 稀疏梯度: 0.99 或更小                                    │   │
│  │ 更稳定: 0.999 或 0.9999                                  │   │
│  └─────────────────────────────────────────────────────────┘   │
│                                                                 │
│  ε (数值稳定项):                                                │
│  ┌─────────────────────────────────────────────────────────┐   │
│  │ 默认值: 1e-8                                             │   │
│  │ 混合精度训练: 1e-4 到 1e-6                               │   │
│  │ 数值不稳定时: 增大到 1e-4                                │   │
│  └─────────────────────────────────────────────────────────┘   │
│                                                                 │
│  weight_decay:                                                  │
│  ┌─────────────────────────────────────────────────────────┐   │
│  │ 默认值: 0 到 0.01                                        │   │
│  │ 防止过拟合: 0.01 到 0.1                                  │   │
│  │ AdamW: 推荐0.01                                          │   │
│  └─────────────────────────────────────────────────────────┘   │
│                                                                 │
└─────────────────────────────────────────────────────────────────┘

6.2 不同任务的推荐配置

python 复制代码
# ==================== 图像分类 ====================
optimizer = torch.optim.AdamW(
    model.parameters(),
    lr=1e-3,
    betas=(0.9, 0.999),
    weight_decay=0.05
)

scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
    optimizer, T_max=100, eta_min=1e-6
)


# ==================== NLP/Transformer微调 ====================
optimizer = torch.optim.AdamW(
    model.parameters(),
    lr=2e-5,
    betas=(0.9, 0.999),
    eps=1e-6,
    weight_decay=0.01
)

# 带warmup的调度
from transformers import get_linear_schedule_with_warmup
scheduler = get_linear_schedule_with_warmup(
    optimizer,
    num_warmup_steps=1000,
    num_training_steps=100000
)


# ==================== GAN训练 ====================
# Generator
g_optimizer = torch.optim.Adam(
    generator.parameters(),
    lr=1e-4,
    betas=(0.5, 0.999)  # 注意β₁较小
)

# Discriminator
d_optimizer = torch.optim.Adam(
    discriminator.parameters(),
    lr=1e-4,
    betas=(0.5, 0.999)
)


# ==================== 强化学习 ====================
optimizer = torch.optim.Adam(
    policy.parameters(),
    lr=3e-4,
    betas=(0.9, 0.999),
    eps=1e-5  # RL中常用较大的eps
)


# ==================== 从头训练大模型 ====================
optimizer = torch.optim.AdamW(
    model.parameters(),
    lr=1e-4,
    betas=(0.9, 0.95),  # β₂略小
    weight_decay=0.1
)

# 带warmup的余弦调度
def lr_lambda(step):
    warmup_steps = 2000
    total_steps = 100000
    
    if step < warmup_steps:
        return step / warmup_steps
    else:
        progress = (step - warmup_steps) / (total_steps - warmup_steps)
        return 0.5 * (1 + np.cos(np.pi * progress))

scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)

6.3 调试技巧

python 复制代码
class AdamDebugger:
    """
    Adam调试工具
    
    帮助诊断训练问题
    """
    
    def __init__(self, optimizer):
        self.optimizer = optimizer
        self.history = {
            'lr': [],
            'grad_norm': [],
            'm_norm': [],
            'v_norm': [],
            'update_norm': []
        }
    
    def log_step(self):
        """记录每步的统计信息"""
        total_grad_norm = 0
        total_m_norm = 0
        total_v_norm = 0
        total_update_norm = 0
        
        for group in self.optimizer.param_groups:
            for p in group['params']:
                if p.grad is None:
                    continue
                
                state = self.optimizer.state[p]
                
                total_grad_norm += p.grad.norm().item() ** 2
                
                if 'exp_avg' in state:
                    total_m_norm += state['exp_avg'].norm().item() ** 2
                    total_v_norm += state['exp_avg_sq'].norm().item() ** 2
        
        self.history['lr'].append(group['lr'])
        self.history['grad_norm'].append(np.sqrt(total_grad_norm))
        self.history['m_norm'].append(np.sqrt(total_m_norm))
        self.history['v_norm'].append(np.sqrt(total_v_norm))
    
    def diagnose(self):
        """诊断训练问题"""
        issues = []
        
        # 检查梯度爆炸
        recent_grad_norm = np.mean(self.history['grad_norm'][-100:])
        if recent_grad_norm > 100:
            issues.append("⚠️ 梯度范数过大,可能发生梯度爆炸")
            issues.append("   建议:使用梯度裁剪 torch.nn.utils.clip_grad_norm_")
        
        # 检查梯度消失
        if recent_grad_norm < 1e-7:
            issues.append("⚠️ 梯度范数过小,可能发生梯度消失")
            issues.append("   建议:检查网络结构,使用残差连接")
        
        # 检查学习率
        if len(self.history['lr']) > 0:
            current_lr = self.history['lr'][-1]
            if current_lr < 1e-8:
                issues.append("⚠️ 学习率过小")
        
        # 检查动量
        if len(self.history['m_norm']) > 100:
            m_trend = np.polyfit(range(100), self.history['m_norm'][-100:], 1)[0]
            if m_trend > 0.1:
                issues.append("⚠️ 动量持续增大,可能训练不稳定")
        
        if not issues:
            issues.append("✓ 训练看起来正常")
        
        return issues
    
    def plot(self, save_path='adam_debug.png'):
        """可视化训练曲线"""
        import matplotlib.pyplot as plt
        
        fig, axes = plt.subplots(2, 2, figsize=(12, 8))
        
        axes[0, 0].plot(self.history['grad_norm'])
        axes[0, 0].set_title('Gradient Norm')
        axes[0, 0].set_yscale('log')
        
        axes[0, 1].plot(self.history['lr'])
        axes[0, 1].set_title('Learning Rate')
        
        axes[1, 0].plot(self.history['m_norm'], label='m (momentum)')
        axes[1, 0].set_title('Momentum Norm')
        axes[1, 0].legend()
        
        axes[1, 1].plot(self.history['v_norm'])
        axes[1, 1].set_title('v (adaptive lr) Norm')
        
        plt.tight_layout()
        plt.savefig(save_path, dpi=150)
        print(f"Saved debug plot to {save_path}")

七、Adam vs SGD:如何选择

7.1 对比分析

复制代码
┌─────────────────────────────────────────────────────────────────┐
│                    Adam vs SGD 对比                              │
├─────────────────────────────────────────────────────────────────┤
│                                                                 │
│  收敛速度:                                                     │
│  ┌─────────────────────────────────────────────────────────┐   │
│  │ Adam: ★★★★★ 快,尤其是训练初期                          │   │
│  │ SGD:  ★★★☆☆ 慢,需要仔细调学习率                       │   │
│  └─────────────────────────────────────────────────────────┘   │
│                                                                 │
│  泛化性能:                                                     │
│  ┌─────────────────────────────────────────────────────────┐   │
│  │ Adam: ★★★☆☆ 可能略差于调好的SGD                         │   │
│  │ SGD:  ★★★★☆ 通常泛化更好(有争议)                      │   │
│  └─────────────────────────────────────────────────────────┘   │
│                                                                 │
│  调参难度:                                                     │
│  ┌─────────────────────────────────────────────────────────┐   │
│  │ Adam: ★★☆☆☆ 简单,默认参数通常就能用                    │   │
│  │ SGD:  ★★★★☆ 困难,需要调学习率、动量、调度器            │   │
│  └─────────────────────────────────────────────────────────┘   │
│                                                                 │
│  内存占用:                                                     │
│  ┌─────────────────────────────────────────────────────────┐   │
│  │ Adam: 3× 参数量 (θ, m, v)                                │   │
│  │ SGD:  1× 参数量 (θ) 或 2× (带动量)                       │   │
│  └─────────────────────────────────────────────────────────┘   │
│                                                                 │
└─────────────────────────────────────────────────────────────────┘

7.2 选择建议

python 复制代码
"""
选择优化器的经验法则:

1. 默认选Adam/AdamW
   - 快速原型开发
   - 不确定用什么时
   - NLP任务(尤其是Transformer)
   - GAN、VAE等生成模型
   - 强化学习

2. 选择SGD with Momentum
   - 追求最佳泛化性能(如ImageNet竞赛)
   - 有足够时间调参
   - 训练ResNet等经典CNN
   - 内存受限的大模型

3. 混合策略
   - 先用Adam快速收敛
   - 再切换到SGD精调
   - 获得两者的优点
"""

def hybrid_training(model, train_loader, epochs_adam=50, epochs_sgd=50):
    """混合训练策略"""
    
    # 阶段1:Adam快速收敛
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
    
    for epoch in range(epochs_adam):
        train_epoch(model, train_loader, optimizer)
    
    print("Switching to SGD...")
    
    # 阶段2:SGD精调
    optimizer = torch.optim.SGD(
        model.parameters(), 
        lr=0.01,  # 通常比Adam大
        momentum=0.9,
        weight_decay=1e-4
    )
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer, T_max=epochs_sgd
    )
    
    for epoch in range(epochs_sgd):
        train_epoch(model, train_loader, optimizer)
        scheduler.step()

八、总结

8.1 Adam的核心要点

复制代码
Adam = Momentum + RMSprop + 偏差修正

┌─────────────────────────────────────────────────────────────┐
│                                                             │
│  m_t = β₁×m_{t-1} + (1-β₁)×g_t     ← 一阶矩(动量)        │
│  v_t = β₂×v_{t-1} + (1-β₂)×g_t²    ← 二阶矩(自适应学习率)│
│                                                             │
│  m̂_t = m_t / (1-β₁^t)              ← 偏差修正              │
│  v̂_t = v_t / (1-β₂^t)                                      │
│                                                             │
│  θ_t = θ_{t-1} - α×m̂_t/√(v̂_t+ε)   ← 更新                  │
│                                                             │
└─────────────────────────────────────────────────────────────┘

8.2 关键变体

变体 核心改进 适用场景
AdamW 解耦权重衰减 几乎所有场景(推荐默认使用)
AMSGrad 保证收敛 理论保证需求
NAdam Nesterov动量 需要更快收敛
RAdam 自动warmup 不想调warmup参数
AdaFactor 内存高效 超大模型

8.3 一句话总结

Adam是深度学习的"万金油"优化器:自适应学习率让调参变简单,动量让收敛变快,偏差修正让训练初期更稳定。

希望这篇文章帮助你深入理解了Adam优化器!如有问题,欢迎评论区交流。


参考文献

  1. Kingma D, Ba J. "Adam: A Method for Stochastic Optimization." ICLR 2015.
  2. Loshchilov I, Hutter F. "Decoupled Weight Decay Regularization." ICLR 2019.
  3. Reddi S J, et al. "On the Convergence of Adam and Beyond." ICLR 2018.
  4. Liu L, et al. "On the Variance of the Adaptive Learning Rate and Beyond." ICLR 2020.

作者:Jia

更多技术文章,欢迎关注我的CSDN博客!

相关推荐
Blossom.1182 小时前
AI Agent智能办公助手:从ChatGPT到真正“干活“的系统
人工智能·分布式·python·深度学习·神经网络·chatgpt·迁移学习
a努力。2 小时前
2026 AI 编程终极套装:Claude Code + Codex + Gemini CLI + Antigravity,四位一体实战指南!
java·开发语言·人工智能·分布式·python·面试
qwerasda1238522 小时前
基于cornernet_hourglass104的纸杯检测与识别模型训练与优化详解
人工智能·计算机视觉·目标跟踪
梦茹^_^2 小时前
flask框架(笔记一次性写完)
redis·python·flask·cookie·session
二川bro2 小时前
Java集合类框架的基本接口有哪些?
java·开发语言·python
抠头专注python环境配置2 小时前
解决“No module named ‘tensorflow‘”报错:从导入失败到环境配置成功
人工智能·windows·python·tensorflow·neo4j
zhangfeng11332 小时前
PowerShell 中不支持激活你选中的 Python 虚拟环境,建议切换到命令提示符(Command Prompt)
开发语言·python·prompt
好奇龙猫2 小时前
【AI学习-comfyUI学习-三十六节-黑森林-融合+扩图工作流-各个部分学习】
人工智能·学习
qh0526wy2 小时前
WINDOWS BAT 开机登录后自动启动
windows·python