从 0 实现一个 Offline RL 算法 (以 IQL 为例)

摘要

纸上得来终觉浅,绝知此事要躬行。看懂了论文公式,不代表能写对代码。在 Offline RL 中,数据处理的细节网络初始化的技巧 以及Loss 的计算顺序,往往比算法原理本身更能决定成败。本文将带你从零构建一个完整的 IQL 训练流程,涵盖 D4RL 数据加载、归一化处理、核心 Loss 实现以及工业级的训练 Trick。


目录

  1. 准备工作:数据加载与归一化
  2. [网络架构:V, Q 与 Policy](#网络架构:V, Q 与 Policy)
  3. [核心逻辑:IQL 的三个 Loss](#核心逻辑:IQL 的三个 Loss)
  4. [完整的 Update Step 代码](#完整的 Update Step 代码)
  5. 稳定训练的工程技巧 (Tricks)
  6. [常见 Bug 与排查方法](#常见 Bug 与排查方法)

1. 准备工作:数据加载与归一化

这是 Offline RL 中最重要的一步! 90% 的失败案例都是因为没有对 State 进行归一化。

1.1 加载 D4RL

首先你需要安装 d4rl。D4RL 的数据集通常包含 observations, actions, rewards, terminals 等字段。

1.2 标准化 (Normalization)

由于 State 的不同维度可能有巨大的数值差异(例如位置坐标是 100,而速度是 0.01),直接训练会导致梯度爆炸或收敛极慢。我们必须把 State 归一化到 均值为 0,方差为 1

python 复制代码
import torch
import numpy as np
import d4rl
import gym

def get_dataset(env):
    dataset = d4rl.qlearning_dataset(env)
    
    # 转换为 Tensor
    states = torch.from_numpy(dataset['observations']).float()
    actions = torch.from_numpy(dataset['actions']).float()
    rewards = torch.from_numpy(dataset['rewards']).float()
    next_states = torch.from_numpy(dataset['next_observations']).float()
    dones = torch.from_numpy(dataset['terminals']).float()

    return states, actions, rewards, next_states, dones

def normalize_states(states, next_states):
    # 计算统计量
    mean = states.mean(dim=0, keepdim=True)
    std = states.std(dim=0, keepdim=True) + 1e-3 # 防止除零
    
    # 归一化
    states = (states - mean) / std
    next_states = (next_states - mean) / std
    
    return states, next_states, mean, std

2. 网络架构:V, Q 与 Policy

IQL 需要三个网络:

  1. Q Network (Twin) :评估 ( s , a ) (s, a) (s,a) 的价值。为了稳定,通常用两个 Q 网络 ( Q 1 , Q 2 Q_1, Q_2 Q1,Q2)。
  2. V Network :评估状态 s s s 的价值(作为 Expectile)。
  3. Policy Network:输出动作分布(通常是 Gaussian)。
python 复制代码
import torch.nn as nn
import torch.nn.functional as F

class MLP(nn.Module):
    def __init__(self, input_dim, output_dim, hidden_dim=256):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, output_dim)
        )
    def forward(self, x):
        return self.net(x)

# 策略网络通常输出均值和方差
class GaussianPolicy(nn.Module):
    def __init__(self, state_dim, action_dim):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(state_dim, 256),
            nn.ReLU(),
            nn.Linear(256, 256),
            nn.ReLU()
        )
        self.mu = nn.Linear(256, action_dim)
        self.log_std = nn.Parameter(torch.zeros(action_dim)) # 可学习的 log_std

    def forward(self, state):
        x = self.net(state)
        mu = self.mu(x)
        # 限制 log_std 范围,防止方差过大或过小(关键 Trick)
        log_std = torch.clamp(self.log_std, -20, 2) 
        std = torch.exp(log_std)
        return torch.distributions.Normal(mu, std)
    
    def get_action(self, state, deterministic=False):
        dist = self.forward(state)
        if deterministic:
            return torch.tanh(dist.mean) # 测试时用均值
        return torch.tanh(dist.sample()) # 训练时采样

3. 核心逻辑:IQL 的三个 Loss

IQL 的核心是非对称的 Expectile Loss

python 复制代码
def expectile_loss(diff, expectile=0.7):
    # diff = Q - V
    # 当 Q > V 时 (diff > 0),权重为 expectile (比如 0.7)
    # 当 Q < V 时 (diff < 0),权重为 1-expectile (比如 0.3)
    # 这会使 V 倾向于靠近 Q 分布的上边缘
    weight = torch.where(diff > 0, expectile, (1 - expectile))
    return torch.mean(weight * (diff ** 2))

4. 完整的 Update Step 代码

将所有组件拼装起来。注意 Target Network 的使用和梯度的阻断。

python 复制代码
class IQL_Agent:
    def __init__(self, state_dim, action_dim, device):
        self.q1 = MLP(state_dim + action_dim, 1).to(device)
        self.q2 = MLP(state_dim + action_dim, 1).to(device)
        self.target_q1 = copy.deepcopy(self.q1) # Target Q用于稳定训练
        self.target_q2 = copy.deepcopy(self.q2)
        
        self.v = MLP(state_dim, 1).to(device)
        self.actor = GaussianPolicy(state_dim, action_dim).to(device)
        
        # 优化器
        self.q_optimizer = torch.optim.Adam(list(self.q1.parameters()) + list(self.q2.parameters()), lr=3e-4)
        self.v_optimizer = torch.optim.Adam(self.v.parameters(), lr=3e-4)
        self.actor_optimizer = torch.optim.Adam(self.actor.parameters(), lr=3e-4)
        
        self.expectile = 0.7  # IQL 核心超参
        self.temperature = 3.0 # AWR 核心超参
        self.gamma = 0.99
        self.tau = 0.005 # 软更新系数

    def update(self, batch):
        states, actions, rewards, next_states, dones = batch
        
        # ---------------------------------------
        # 1. Update V (Expectile Regression)
        # ---------------------------------------
        with torch.no_grad():
            # 使用 Target Q 来计算 V 的目标,更稳定
            q1_t = self.target_q1(torch.cat([states, actions], dim=1))
            q2_t = self.target_q2(torch.cat([states, actions], dim=1))
            min_q = torch.min(q1_t, q2_t)
            
        v_pred = self.v(states)
        v_loss = expectile_loss(min_q - v_pred, self.expectile)
        
        self.v_optimizer.zero_grad()
        v_loss.backward()
        self.v_optimizer.step()
        
        # ---------------------------------------
        # 2. Update Q (MSE Loss)
        # ---------------------------------------
        with torch.no_grad():
            next_v = self.v(next_states)
            # 关键:IQL 的 Q target 使用 V(s'),不使用 max Q(s', a')
            q_target = rewards + self.gamma * (1 - dones) * next_v

        q1_pred = self.q1(torch.cat([states, actions], dim=1))
        q2_pred = self.q2(torch.cat([states, actions], dim=1))
        q_loss = F.mse_loss(q1_pred, q_target) + F.mse_loss(q2_pred, q_target)
        
        self.q_optimizer.zero_grad()
        q_loss.backward()
        self.q_optimizer.step()

        # ---------------------------------------
        # 3. Update Policy (Advantage Weighted Regression)
        # ---------------------------------------
        with torch.no_grad():
            # 计算优势函数 A(s, a) = Q(s, a) - V(s)
            q1 = self.target_q1(torch.cat([states, actions], dim=1))
            q2 = self.target_q2(torch.cat([states, actions], dim=1))
            min_q = torch.min(q1, q2)
            v = self.v(states)
            advantage = min_q - v
            
            # 计算权重 exp(A / T)
            exp_adv = torch.exp(advantage / self.temperature)
            # 限制权重上限,防止数值不稳定
            exp_adv = torch.clamp(exp_adv, max=100.0)

        # 计算 Policy 的 log_prob(a|s)
        dist = self.actor(states)
        log_prob = dist.log_prob(actions).sum(dim=-1, keepdim=True)
        
        # Loss = - weights * log_prob (加权最大似然)
        actor_loss = -(exp_adv * log_prob).mean()
        
        self.actor_optimizer.zero_grad()
        actor_loss.backward()
        self.actor_optimizer.step()
        
        # ---------------------------------------
        # 4. Soft Update Target Networks
        # ---------------------------------------
        self.soft_update(self.q1, self.target_q1)
        self.soft_update(self.q2, self.target_q2)

    def soft_update(self, local_model, target_model):
        for target_param, local_param in zip(target_model.parameters(), local_model.parameters()):
            target_param.data.copy_(self.tau * local_param.data + (1.0 - self.tau) * target_param.data)

5. 稳定训练的工程技巧 (Tricks)

如果只写上面的代码,你可能只能在简单任务上跑通。想在 AntMaze 上拿分,还需要以下 Trick:

  1. Cosine Learning Rate Decay
    Offline RL 容易过拟合。在训练最后阶段将学习率衰减到 0,能显著提升测试性能。

    python 复制代码
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=max_steps)
  2. LayerNorm
    在 MLP 的 ReLU 之前加入 nn.LayerNorm(),对于防止 Q 值发散非常有用。

  3. Orthogonal Initialization
    使用正交初始化网络参数,比默认的 Xavier 初始化收敛更快。

  4. Target Update 频率
    IQL 中 V 的更新很快,Q 的 Target Network 更新可以适当慢一点,或者不使用 Target Network 直接用当前的 Q 也可以(IQL 论文中有些变体是这样做的),但保留 Target Q 通常更稳。


6. 常见 Bug 与排查方法 🛠️

6.1 Q Loss 不下降 / 震荡

  • 原因:State 没有归一化。
  • 排查:打印 State 的 mean 和 std,如果 mean 不是 0 附近,必挂。

6.2 Policy Loss 变成 NaN

  • 原因exp(advantage / temperature) 溢出。
  • 排查 :检查 Advantage 的数值范围。如果 A 很大(比如 100),exp(30) 就会很大。一定要加 torch.clamp

6.3 训练出来的 Agent 一动不动

  • 原因:Temperature 太小,或者 Expectile 太大。
  • 排查
    • 如果 temperature 太小(如 0.1),Policy 只会模仿那些极少数 Advantage 极大的样本,导致过拟合。
    • 如果 expectile 太大(如 0.99),V 值会估计得非常高,导致 Advantage 几乎全是负的,Policy 学不到东西。推荐默认值:Expectile=0.7, Temperature=3.0

6.4 测试分数极低,但 Q 值很高

  • 原因:Overestimation(尽管 IQL 已经很克制了,但依然可能发生)。
  • 排查:IQL 的 Q 值不应该特别大。如果发现 Q 值远超 Max Episode Return,说明 Target 计算有问题,或者 Reward Scale 太大(建议把 Reward 归一化到 [0, 1] 或做简单的 Scaling)。

结语

从零实现 Offline RL 是一个痛苦但收益巨大的过程。你会发现它不再是黑盒,而是由一个个精巧的积木(Expectile, AWR, Normalization)搭建的城堡。

现在的你,已经具备了手写 SOTA 算法的能力,去 D4RL 榜单上试试身手吧!

相关推荐
rayufo2 小时前
深度学习图像复原论文《SwinIR: Image Restoration Using Swin Transformer》解读及其代码实现
人工智能·深度学习·transformer
万俟淋曦2 小时前
【论文速递】2025年第42周(Oct-12-18)(Robotics/Embodied AI/LLM)
人工智能·ai·机器人·大模型·论文·robotics·具身智能
hero_heart2 小时前
opencv和摄影测量坐标系的转换
人工智能·opencv·计算机视觉
Java后端的Ai之路2 小时前
【分析式AI】-时间序列模型一文详解
人工智能·aigc·时间序列·算法模型·分析式ai
AI即插即用2 小时前
即插即用系列 | CMPB PMFSNet:多尺度特征自注意力网络,打破轻量级医学图像分割的性能天花板
网络·图像处理·人工智能·深度学习·神经网络·计算机视觉·视觉检测
love530love2 小时前
在 PyCharm 中配置 x64 Native Tools Command Prompt for VS 2022 作为默认终端
ide·人工智能·windows·python·pycharm·prompt·comfyui
图导物联2 小时前
商场室内导航系统:政策适配 + 技术实现 + 代码示例,打通停车逛店全流程
大数据·人工智能·物联网
柒.梧.2 小时前
CSS 基础样式与盒模型详解:从入门到实战进阶
人工智能·python·tensorflow
WLJT1231231232 小时前
“人工智能+”引领数字产业迈入价值兑现新阶段
人工智能