摘要 :
纸上得来终觉浅,绝知此事要躬行。看懂了论文公式,不代表能写对代码。在 Offline RL 中,数据处理的细节 、网络初始化的技巧 以及Loss 的计算顺序,往往比算法原理本身更能决定成败。本文将带你从零构建一个完整的 IQL 训练流程,涵盖 D4RL 数据加载、归一化处理、核心 Loss 实现以及工业级的训练 Trick。
目录
- 准备工作:数据加载与归一化
- [网络架构:V, Q 与 Policy](#网络架构:V, Q 与 Policy)
- [核心逻辑:IQL 的三个 Loss](#核心逻辑:IQL 的三个 Loss)
- [完整的 Update Step 代码](#完整的 Update Step 代码)
- 稳定训练的工程技巧 (Tricks)
- [常见 Bug 与排查方法](#常见 Bug 与排查方法)
1. 准备工作:数据加载与归一化
这是 Offline RL 中最重要的一步! 90% 的失败案例都是因为没有对 State 进行归一化。
1.1 加载 D4RL
首先你需要安装 d4rl。D4RL 的数据集通常包含 observations, actions, rewards, terminals 等字段。
1.2 标准化 (Normalization)
由于 State 的不同维度可能有巨大的数值差异(例如位置坐标是 100,而速度是 0.01),直接训练会导致梯度爆炸或收敛极慢。我们必须把 State 归一化到 均值为 0,方差为 1。
python
import torch
import numpy as np
import d4rl
import gym
def get_dataset(env):
dataset = d4rl.qlearning_dataset(env)
# 转换为 Tensor
states = torch.from_numpy(dataset['observations']).float()
actions = torch.from_numpy(dataset['actions']).float()
rewards = torch.from_numpy(dataset['rewards']).float()
next_states = torch.from_numpy(dataset['next_observations']).float()
dones = torch.from_numpy(dataset['terminals']).float()
return states, actions, rewards, next_states, dones
def normalize_states(states, next_states):
# 计算统计量
mean = states.mean(dim=0, keepdim=True)
std = states.std(dim=0, keepdim=True) + 1e-3 # 防止除零
# 归一化
states = (states - mean) / std
next_states = (next_states - mean) / std
return states, next_states, mean, std
2. 网络架构:V, Q 与 Policy
IQL 需要三个网络:
- Q Network (Twin) :评估 ( s , a ) (s, a) (s,a) 的价值。为了稳定,通常用两个 Q 网络 ( Q 1 , Q 2 Q_1, Q_2 Q1,Q2)。
- V Network :评估状态 s s s 的价值(作为 Expectile)。
- Policy Network:输出动作分布(通常是 Gaussian)。
python
import torch.nn as nn
import torch.nn.functional as F
class MLP(nn.Module):
def __init__(self, input_dim, output_dim, hidden_dim=256):
super().__init__()
self.net = nn.Sequential(
nn.Linear(input_dim, hidden_dim),
nn.ReLU(),
nn.Linear(hidden_dim, hidden_dim),
nn.ReLU(),
nn.Linear(hidden_dim, output_dim)
)
def forward(self, x):
return self.net(x)
# 策略网络通常输出均值和方差
class GaussianPolicy(nn.Module):
def __init__(self, state_dim, action_dim):
super().__init__()
self.net = nn.Sequential(
nn.Linear(state_dim, 256),
nn.ReLU(),
nn.Linear(256, 256),
nn.ReLU()
)
self.mu = nn.Linear(256, action_dim)
self.log_std = nn.Parameter(torch.zeros(action_dim)) # 可学习的 log_std
def forward(self, state):
x = self.net(state)
mu = self.mu(x)
# 限制 log_std 范围,防止方差过大或过小(关键 Trick)
log_std = torch.clamp(self.log_std, -20, 2)
std = torch.exp(log_std)
return torch.distributions.Normal(mu, std)
def get_action(self, state, deterministic=False):
dist = self.forward(state)
if deterministic:
return torch.tanh(dist.mean) # 测试时用均值
return torch.tanh(dist.sample()) # 训练时采样
3. 核心逻辑:IQL 的三个 Loss
IQL 的核心是非对称的 Expectile Loss。
python
def expectile_loss(diff, expectile=0.7):
# diff = Q - V
# 当 Q > V 时 (diff > 0),权重为 expectile (比如 0.7)
# 当 Q < V 时 (diff < 0),权重为 1-expectile (比如 0.3)
# 这会使 V 倾向于靠近 Q 分布的上边缘
weight = torch.where(diff > 0, expectile, (1 - expectile))
return torch.mean(weight * (diff ** 2))
4. 完整的 Update Step 代码
将所有组件拼装起来。注意 Target Network 的使用和梯度的阻断。
python
class IQL_Agent:
def __init__(self, state_dim, action_dim, device):
self.q1 = MLP(state_dim + action_dim, 1).to(device)
self.q2 = MLP(state_dim + action_dim, 1).to(device)
self.target_q1 = copy.deepcopy(self.q1) # Target Q用于稳定训练
self.target_q2 = copy.deepcopy(self.q2)
self.v = MLP(state_dim, 1).to(device)
self.actor = GaussianPolicy(state_dim, action_dim).to(device)
# 优化器
self.q_optimizer = torch.optim.Adam(list(self.q1.parameters()) + list(self.q2.parameters()), lr=3e-4)
self.v_optimizer = torch.optim.Adam(self.v.parameters(), lr=3e-4)
self.actor_optimizer = torch.optim.Adam(self.actor.parameters(), lr=3e-4)
self.expectile = 0.7 # IQL 核心超参
self.temperature = 3.0 # AWR 核心超参
self.gamma = 0.99
self.tau = 0.005 # 软更新系数
def update(self, batch):
states, actions, rewards, next_states, dones = batch
# ---------------------------------------
# 1. Update V (Expectile Regression)
# ---------------------------------------
with torch.no_grad():
# 使用 Target Q 来计算 V 的目标,更稳定
q1_t = self.target_q1(torch.cat([states, actions], dim=1))
q2_t = self.target_q2(torch.cat([states, actions], dim=1))
min_q = torch.min(q1_t, q2_t)
v_pred = self.v(states)
v_loss = expectile_loss(min_q - v_pred, self.expectile)
self.v_optimizer.zero_grad()
v_loss.backward()
self.v_optimizer.step()
# ---------------------------------------
# 2. Update Q (MSE Loss)
# ---------------------------------------
with torch.no_grad():
next_v = self.v(next_states)
# 关键:IQL 的 Q target 使用 V(s'),不使用 max Q(s', a')
q_target = rewards + self.gamma * (1 - dones) * next_v
q1_pred = self.q1(torch.cat([states, actions], dim=1))
q2_pred = self.q2(torch.cat([states, actions], dim=1))
q_loss = F.mse_loss(q1_pred, q_target) + F.mse_loss(q2_pred, q_target)
self.q_optimizer.zero_grad()
q_loss.backward()
self.q_optimizer.step()
# ---------------------------------------
# 3. Update Policy (Advantage Weighted Regression)
# ---------------------------------------
with torch.no_grad():
# 计算优势函数 A(s, a) = Q(s, a) - V(s)
q1 = self.target_q1(torch.cat([states, actions], dim=1))
q2 = self.target_q2(torch.cat([states, actions], dim=1))
min_q = torch.min(q1, q2)
v = self.v(states)
advantage = min_q - v
# 计算权重 exp(A / T)
exp_adv = torch.exp(advantage / self.temperature)
# 限制权重上限,防止数值不稳定
exp_adv = torch.clamp(exp_adv, max=100.0)
# 计算 Policy 的 log_prob(a|s)
dist = self.actor(states)
log_prob = dist.log_prob(actions).sum(dim=-1, keepdim=True)
# Loss = - weights * log_prob (加权最大似然)
actor_loss = -(exp_adv * log_prob).mean()
self.actor_optimizer.zero_grad()
actor_loss.backward()
self.actor_optimizer.step()
# ---------------------------------------
# 4. Soft Update Target Networks
# ---------------------------------------
self.soft_update(self.q1, self.target_q1)
self.soft_update(self.q2, self.target_q2)
def soft_update(self, local_model, target_model):
for target_param, local_param in zip(target_model.parameters(), local_model.parameters()):
target_param.data.copy_(self.tau * local_param.data + (1.0 - self.tau) * target_param.data)
5. 稳定训练的工程技巧 (Tricks)
如果只写上面的代码,你可能只能在简单任务上跑通。想在 AntMaze 上拿分,还需要以下 Trick:
-
Cosine Learning Rate Decay :
Offline RL 容易过拟合。在训练最后阶段将学习率衰减到 0,能显著提升测试性能。pythonscheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=max_steps) -
LayerNorm :
在 MLP 的 ReLU 之前加入nn.LayerNorm(),对于防止 Q 值发散非常有用。 -
Orthogonal Initialization :
使用正交初始化网络参数,比默认的 Xavier 初始化收敛更快。 -
Target Update 频率 :
IQL 中 V 的更新很快,Q 的 Target Network 更新可以适当慢一点,或者不使用 Target Network 直接用当前的 Q 也可以(IQL 论文中有些变体是这样做的),但保留 Target Q 通常更稳。
6. 常见 Bug 与排查方法 🛠️
6.1 Q Loss 不下降 / 震荡
- 原因:State 没有归一化。
- 排查:打印 State 的 mean 和 std,如果 mean 不是 0 附近,必挂。
6.2 Policy Loss 变成 NaN
- 原因 :
exp(advantage / temperature)溢出。 - 排查 :检查 Advantage 的数值范围。如果 A 很大(比如 100),exp(30) 就会很大。一定要加
torch.clamp。
6.3 训练出来的 Agent 一动不动
- 原因:Temperature 太小,或者 Expectile 太大。
- 排查 :
- 如果
temperature太小(如 0.1),Policy 只会模仿那些极少数 Advantage 极大的样本,导致过拟合。 - 如果
expectile太大(如 0.99),V 值会估计得非常高,导致 Advantage 几乎全是负的,Policy 学不到东西。推荐默认值:Expectile=0.7, Temperature=3.0。
- 如果
6.4 测试分数极低,但 Q 值很高
- 原因:Overestimation(尽管 IQL 已经很克制了,但依然可能发生)。
- 排查:IQL 的 Q 值不应该特别大。如果发现 Q 值远超 Max Episode Return,说明 Target 计算有问题,或者 Reward Scale 太大(建议把 Reward 归一化到 [0, 1] 或做简单的 Scaling)。
结语
从零实现 Offline RL 是一个痛苦但收益巨大的过程。你会发现它不再是黑盒,而是由一个个精巧的积木(Expectile, AWR, Normalization)搭建的城堡。
现在的你,已经具备了手写 SOTA 算法的能力,去 D4RL 榜单上试试身手吧!