从 0 实现一个 Offline RL 算法 (以 IQL 为例)

摘要

纸上得来终觉浅,绝知此事要躬行。看懂了论文公式,不代表能写对代码。在 Offline RL 中,数据处理的细节网络初始化的技巧 以及Loss 的计算顺序,往往比算法原理本身更能决定成败。本文将带你从零构建一个完整的 IQL 训练流程,涵盖 D4RL 数据加载、归一化处理、核心 Loss 实现以及工业级的训练 Trick。


目录

  1. 准备工作:数据加载与归一化
  2. [网络架构:V, Q 与 Policy](#网络架构:V, Q 与 Policy)
  3. [核心逻辑:IQL 的三个 Loss](#核心逻辑:IQL 的三个 Loss)
  4. [完整的 Update Step 代码](#完整的 Update Step 代码)
  5. 稳定训练的工程技巧 (Tricks)
  6. [常见 Bug 与排查方法](#常见 Bug 与排查方法)

1. 准备工作:数据加载与归一化

这是 Offline RL 中最重要的一步! 90% 的失败案例都是因为没有对 State 进行归一化。

1.1 加载 D4RL

首先你需要安装 d4rl。D4RL 的数据集通常包含 observations, actions, rewards, terminals 等字段。

1.2 标准化 (Normalization)

由于 State 的不同维度可能有巨大的数值差异(例如位置坐标是 100,而速度是 0.01),直接训练会导致梯度爆炸或收敛极慢。我们必须把 State 归一化到 均值为 0,方差为 1

python 复制代码
import torch
import numpy as np
import d4rl
import gym

def get_dataset(env):
    dataset = d4rl.qlearning_dataset(env)
    
    # 转换为 Tensor
    states = torch.from_numpy(dataset['observations']).float()
    actions = torch.from_numpy(dataset['actions']).float()
    rewards = torch.from_numpy(dataset['rewards']).float()
    next_states = torch.from_numpy(dataset['next_observations']).float()
    dones = torch.from_numpy(dataset['terminals']).float()

    return states, actions, rewards, next_states, dones

def normalize_states(states, next_states):
    # 计算统计量
    mean = states.mean(dim=0, keepdim=True)
    std = states.std(dim=0, keepdim=True) + 1e-3 # 防止除零
    
    # 归一化
    states = (states - mean) / std
    next_states = (next_states - mean) / std
    
    return states, next_states, mean, std

2. 网络架构:V, Q 与 Policy

IQL 需要三个网络:

  1. Q Network (Twin) :评估 ( s , a ) (s, a) (s,a) 的价值。为了稳定,通常用两个 Q 网络 ( Q 1 , Q 2 Q_1, Q_2 Q1,Q2)。
  2. V Network :评估状态 s s s 的价值(作为 Expectile)。
  3. Policy Network:输出动作分布(通常是 Gaussian)。
python 复制代码
import torch.nn as nn
import torch.nn.functional as F

class MLP(nn.Module):
    def __init__(self, input_dim, output_dim, hidden_dim=256):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, output_dim)
        )
    def forward(self, x):
        return self.net(x)

# 策略网络通常输出均值和方差
class GaussianPolicy(nn.Module):
    def __init__(self, state_dim, action_dim):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(state_dim, 256),
            nn.ReLU(),
            nn.Linear(256, 256),
            nn.ReLU()
        )
        self.mu = nn.Linear(256, action_dim)
        self.log_std = nn.Parameter(torch.zeros(action_dim)) # 可学习的 log_std

    def forward(self, state):
        x = self.net(state)
        mu = self.mu(x)
        # 限制 log_std 范围,防止方差过大或过小(关键 Trick)
        log_std = torch.clamp(self.log_std, -20, 2) 
        std = torch.exp(log_std)
        return torch.distributions.Normal(mu, std)
    
    def get_action(self, state, deterministic=False):
        dist = self.forward(state)
        if deterministic:
            return torch.tanh(dist.mean) # 测试时用均值
        return torch.tanh(dist.sample()) # 训练时采样

3. 核心逻辑:IQL 的三个 Loss

IQL 的核心是非对称的 Expectile Loss

python 复制代码
def expectile_loss(diff, expectile=0.7):
    # diff = Q - V
    # 当 Q > V 时 (diff > 0),权重为 expectile (比如 0.7)
    # 当 Q < V 时 (diff < 0),权重为 1-expectile (比如 0.3)
    # 这会使 V 倾向于靠近 Q 分布的上边缘
    weight = torch.where(diff > 0, expectile, (1 - expectile))
    return torch.mean(weight * (diff ** 2))

4. 完整的 Update Step 代码

将所有组件拼装起来。注意 Target Network 的使用和梯度的阻断。

python 复制代码
class IQL_Agent:
    def __init__(self, state_dim, action_dim, device):
        self.q1 = MLP(state_dim + action_dim, 1).to(device)
        self.q2 = MLP(state_dim + action_dim, 1).to(device)
        self.target_q1 = copy.deepcopy(self.q1) # Target Q用于稳定训练
        self.target_q2 = copy.deepcopy(self.q2)
        
        self.v = MLP(state_dim, 1).to(device)
        self.actor = GaussianPolicy(state_dim, action_dim).to(device)
        
        # 优化器
        self.q_optimizer = torch.optim.Adam(list(self.q1.parameters()) + list(self.q2.parameters()), lr=3e-4)
        self.v_optimizer = torch.optim.Adam(self.v.parameters(), lr=3e-4)
        self.actor_optimizer = torch.optim.Adam(self.actor.parameters(), lr=3e-4)
        
        self.expectile = 0.7  # IQL 核心超参
        self.temperature = 3.0 # AWR 核心超参
        self.gamma = 0.99
        self.tau = 0.005 # 软更新系数

    def update(self, batch):
        states, actions, rewards, next_states, dones = batch
        
        # ---------------------------------------
        # 1. Update V (Expectile Regression)
        # ---------------------------------------
        with torch.no_grad():
            # 使用 Target Q 来计算 V 的目标,更稳定
            q1_t = self.target_q1(torch.cat([states, actions], dim=1))
            q2_t = self.target_q2(torch.cat([states, actions], dim=1))
            min_q = torch.min(q1_t, q2_t)
            
        v_pred = self.v(states)
        v_loss = expectile_loss(min_q - v_pred, self.expectile)
        
        self.v_optimizer.zero_grad()
        v_loss.backward()
        self.v_optimizer.step()
        
        # ---------------------------------------
        # 2. Update Q (MSE Loss)
        # ---------------------------------------
        with torch.no_grad():
            next_v = self.v(next_states)
            # 关键:IQL 的 Q target 使用 V(s'),不使用 max Q(s', a')
            q_target = rewards + self.gamma * (1 - dones) * next_v

        q1_pred = self.q1(torch.cat([states, actions], dim=1))
        q2_pred = self.q2(torch.cat([states, actions], dim=1))
        q_loss = F.mse_loss(q1_pred, q_target) + F.mse_loss(q2_pred, q_target)
        
        self.q_optimizer.zero_grad()
        q_loss.backward()
        self.q_optimizer.step()

        # ---------------------------------------
        # 3. Update Policy (Advantage Weighted Regression)
        # ---------------------------------------
        with torch.no_grad():
            # 计算优势函数 A(s, a) = Q(s, a) - V(s)
            q1 = self.target_q1(torch.cat([states, actions], dim=1))
            q2 = self.target_q2(torch.cat([states, actions], dim=1))
            min_q = torch.min(q1, q2)
            v = self.v(states)
            advantage = min_q - v
            
            # 计算权重 exp(A / T)
            exp_adv = torch.exp(advantage / self.temperature)
            # 限制权重上限,防止数值不稳定
            exp_adv = torch.clamp(exp_adv, max=100.0)

        # 计算 Policy 的 log_prob(a|s)
        dist = self.actor(states)
        log_prob = dist.log_prob(actions).sum(dim=-1, keepdim=True)
        
        # Loss = - weights * log_prob (加权最大似然)
        actor_loss = -(exp_adv * log_prob).mean()
        
        self.actor_optimizer.zero_grad()
        actor_loss.backward()
        self.actor_optimizer.step()
        
        # ---------------------------------------
        # 4. Soft Update Target Networks
        # ---------------------------------------
        self.soft_update(self.q1, self.target_q1)
        self.soft_update(self.q2, self.target_q2)

    def soft_update(self, local_model, target_model):
        for target_param, local_param in zip(target_model.parameters(), local_model.parameters()):
            target_param.data.copy_(self.tau * local_param.data + (1.0 - self.tau) * target_param.data)

5. 稳定训练的工程技巧 (Tricks)

如果只写上面的代码,你可能只能在简单任务上跑通。想在 AntMaze 上拿分,还需要以下 Trick:

  1. Cosine Learning Rate Decay
    Offline RL 容易过拟合。在训练最后阶段将学习率衰减到 0,能显著提升测试性能。

    python 复制代码
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=max_steps)
  2. LayerNorm
    在 MLP 的 ReLU 之前加入 nn.LayerNorm(),对于防止 Q 值发散非常有用。

  3. Orthogonal Initialization
    使用正交初始化网络参数,比默认的 Xavier 初始化收敛更快。

  4. Target Update 频率
    IQL 中 V 的更新很快,Q 的 Target Network 更新可以适当慢一点,或者不使用 Target Network 直接用当前的 Q 也可以(IQL 论文中有些变体是这样做的),但保留 Target Q 通常更稳。


6. 常见 Bug 与排查方法 🛠️

6.1 Q Loss 不下降 / 震荡

  • 原因:State 没有归一化。
  • 排查:打印 State 的 mean 和 std,如果 mean 不是 0 附近,必挂。

6.2 Policy Loss 变成 NaN

  • 原因exp(advantage / temperature) 溢出。
  • 排查 :检查 Advantage 的数值范围。如果 A 很大(比如 100),exp(30) 就会很大。一定要加 torch.clamp

6.3 训练出来的 Agent 一动不动

  • 原因:Temperature 太小,或者 Expectile 太大。
  • 排查
    • 如果 temperature 太小(如 0.1),Policy 只会模仿那些极少数 Advantage 极大的样本,导致过拟合。
    • 如果 expectile 太大(如 0.99),V 值会估计得非常高,导致 Advantage 几乎全是负的,Policy 学不到东西。推荐默认值:Expectile=0.7, Temperature=3.0

6.4 测试分数极低,但 Q 值很高

  • 原因:Overestimation(尽管 IQL 已经很克制了,但依然可能发生)。
  • 排查:IQL 的 Q 值不应该特别大。如果发现 Q 值远超 Max Episode Return,说明 Target 计算有问题,或者 Reward Scale 太大(建议把 Reward 归一化到 [0, 1] 或做简单的 Scaling)。

结语

从零实现 Offline RL 是一个痛苦但收益巨大的过程。你会发现它不再是黑盒,而是由一个个精巧的积木(Expectile, AWR, Normalization)搭建的城堡。

现在的你,已经具备了手写 SOTA 算法的能力,去 D4RL 榜单上试试身手吧!

相关推荐
风象南7 小时前
普通人用AI加持赚到的第一个100块
人工智能·后端
牛奶8 小时前
2026年大模型怎么选?前端人实用对比
前端·人工智能·ai编程
牛奶8 小时前
前端人为什么要学AI?
前端·人工智能·ai编程
哥布林学者9 小时前
高光谱成像(一)高光谱图像
机器学习·高光谱成像
罗西的思考11 小时前
AI Agent框架探秘:拆解 OpenHands(10)--- Runtime
人工智能·算法·机器学习
冬奇Lab11 小时前
OpenClaw 源码精读(2):Channel & Routing——一条消息如何找到它的 Agent?
人工智能·开源·源码阅读
冬奇Lab11 小时前
一天一个开源项目(第38篇):Claude Code Telegram - 用 Telegram 远程用 Claude Code,随时随地聊项目
人工智能·开源·资讯
格砸13 小时前
从入门到辞职|从ChatGPT到OpenClaw,跟上智能时代的进化
前端·人工智能·后端
可观测性用观测云13 小时前
可观测性 4.0:教系统如何思考
人工智能
sunny86513 小时前
Claude Code 跨会话上下文恢复:从 8 次纠正到 0 次的工程实践
人工智能·开源·github