Decision Transformer (DT) 是强化学习(RL)领域的一个里程碑式的工作(由伯克利、Facebook 等机构于 2021 年提出)。它的核心思想是:将强化学习问题抽象为序列建模(Sequence Modeling)问题,从而直接套用 Transformer 架构来解决 RL 任务。
在传统 RL 中,我们通常使用时间差分学习(TD-error)来拟合价值函数(如 Q-learning)或策略梯度(如 PPO)。而 Decision Transformer 彻底抛弃了这些传统的 RL 状态价值计算,把"马尔可夫决策过程"直接看作是一个语言生成任务。
核心机制:条件序列建模传统
Transformer 接收文字 Token 并预测下一个 Token。Decision Transformer 接收的则是状态(State)、动作(Action)和回报(Reward)的序列。
为了让模型能够自适应地输出不同质量的动作,DT 引入了一个关键概念:期望回报(Return-to-Go, RtR_tRt)。
Rt=∑t′=tTrt′R_t = \sum_{t'=t}^T r_{t'}Rt=t′=t∑Trt′
RtR_tRt 表示从当前时间步 ttt 开始,到轨迹结束预计能拿到的总回报。
1. 轨迹表示(Trajectory Representation)
在 DT 中,一个轨迹被表示为如下的序列,作为 Transformer 的输入:
τ=(R^1,s1,a1,R^2,s2,a2,...,R^T,sT,aT)\tau = \left( \hat{R}_1, s_1, a_1, \hat{R}_2, s_2, a_2, \dots, \hat{R}_T, s_T, a_T \right)τ=(R^1,s1,a1,R^2,s2,a2,...,R^T,sT,aT)
- R^t\hat{R}_tR^t:时间步 ttt 的期望回报(Return-to-Go)。
- sts_tst:时间步 ttt 的状态(State)。
- ata_tat:时间步 ttt 的动作(Action)。
2. 模型架构DT
采用的是类似于 GPT 的 Decoder-only 自回归架构。由于输入序列包含三种不同的模态(回报、状态、动作),模型首先会通过不同的线性层(或针对图像状态的 CNN)将它们投影到相同的嵌入维度(Embedding Dimension),并加上时间步位置编码(Time-step Embedding)。
3. 训练阶段(离线强化学习 / Offline RL)
DT 主要用于 Offline RL 场景(即从固定的历史数据集里学习,不与环境进行实时交互):
- 从离线数据集中采样一段轨迹。
- 将序列 (R^1,s1,a1,... )\left( \hat{R}_1, s_1, a_1, \dots \right)(R^1,s1,a1,...) 输入给模型。
- 预测目标 :模型利用因果掩码(Causal Mask),基于过去的所有信息以及当前的 R^t\hat{R}_tR^t 和 sts_tst,来预测当前的动作 ata_tat。
- 损失函数:如果是连续动作,使用均方误差(MSE);如果是离线离散动作,使用交叉熵损失(Cross Entropy)。
L=E∥at−DT(R\^1:t,s1:t,a1:t−1)∥2\mathcal{L} = \mathbb{E} \left \\\| a_t - \\text{DT}(\\hat{R}_{1:t}, s_{1:t}, a_{1:t-1}) \\\|\^2 \\rightL=E∥at−DT(R\^1:t,s1:t,a1:t−1)∥2
测试/推理阶段:如何通过"许愿"来调控表现
这是 Decision Transformer 最有意思的地方。在实际测试(Deployment)时,我们不需要写任何复杂的探索策略,而是直接向模型"许愿"。
- 设定初始期望:你想让模型拿满分,你就在初始步输入一个极高的回报目标 R^1\hat{R}_1R^1(比如 1.0 或该游戏的最大可能得分)。
- 生成动作:将 R^1\hat{R}_1R^1 和初始状态 s1s_1s1 输入给 DT,模型会输出对应的动作 a1a_1a1。
- 环境反馈与更新:将 a1a_1a1 作用于环境,得到真实的环境奖励 r1r_1r1 和新状态 s2s_2s2。
- 减去损耗,继续许愿:更新下一步的期望回报:R^2=R^1−r1\hat{R}_2 = \hat{R}_1 - r_1R^2=R^1−r1。
- 滚动推进:把新的条件序列输入给模型,自回归地生成 a2a_2a2,以此类推。
为什么 DT 这么有效?(对比传统 RL)
| 特性 | 传统 Offline RL (如 CQL, BEAR) | Decision Transformer (DT) |
|---|---|---|
| 核心机制 | 动态规划 (Dynamic Programming) / TD 学习 | 序列建模 (Sequence Modeling) / 注意力机制 |
| 致命弱点 | 致命三要素 (Deadly Triad):函数逼近、Bootstrap、离线数据会导致 Q 值估计严重爆炸(Overestimation Bias)。 | 没有 Q 值的概念,不使用 Bootstrap,完全不存在值爆炸问题,训练极度稳定。 |
| 长距离信用分配 | 通过 Bellman 方程一步步回传价值,速度慢,易受局部噪声影响。 | 通过 Self-Attention 直接在整个上下文窗口内建立关联,能更好地实现"长距离信用分配"。 |
| 行为泛化 | 很难直接从失败的经验中学习。 | 只要数据中包含"高回报"和"低回报"的对比,模型就能理解什么是好、什么是坏,并定向生成高质量行为。 |
局限性
- 无法超越数据集的上限(没有超越性探索) :DT 本质上是在做条件行为克隆(Conditional Behavioral Cloning)。如果你的数据集中从来没有过高回报的轨迹,你盲目输入一个极大的 R^t\hat{R}_tR^t,模型会进入"未分布区域(OOD)",表现可能会崩塌。
- 对上下文窗口(Context Window)的依赖:由于它依赖马尔可夫决策过程在历史序列中的体现,如果环境的延迟奖励极长,超出了 Transformer 的 max_len,性能会大幅下降。
- 计算资源消耗:Transformer 的自回归推理成本明显高于传统 RL 的微型 MLP 策略网络。
总结来说,Decision Transformer 证明了"大语言模型的那一套完全可以用在控制和决策任务上"。它把强化学习的问题转化为了一个纯粹的、死记硬背加联想的序列预测问题,为后来各种具身智能(Embodied AI)大模型(如 RT-1, RT-2)奠定了核心的算法范式。
核心代码实现 (PyTorch 极简版)
Decision Transformer 的关键在于:将三种模态(R,s,aR, s, aR,s,a)投影到同一维度,融合成一个交错的序列,然后送入 GPT 架构中。
import torch
import torch.nn as nn
class MinDT(nn.Module):
def __init__(self, state_dim, act_dim, hidden_dim, max_length=20):
super().__init__()
self.state_dim = state_dim
self.act_dim = act_dim
self.hidden_dim = hidden_dim
self.max_length = max_length
# 1. 模态专用的 Embedding 层
self.embed_rtg = nn.Linear(1, hidden_dim)
self.embed_state = nn.Linear(state_dim, hidden_dim)
self.embed_action = nn.Linear(act_dim, hidden_dim)
# 2. 时间步位置编码 (Time-step Embedding)
# 区别于 NLP 的 Token 位置,这里使用的是绝对时间步 t
self.embed_timestep = nn.Embedding(max_length * 5, hidden_dim)
# 3. 因果 Transformer 解码器 (用一层标准的 Transformer Encoder 配合 Causal Mask 模拟)
encoder_layer = nn.TransformerEncoderLayer(
d_model=hidden_dim, nhead=4, dim_feedforward=hidden_dim*4,
batch_first=True, activation='gelu', norm_first=True
)
self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=3)
# 4. 预测输出头 (基于当前上下文预测 action)
self.predict_action = nn.Sequential(
nn.Linear(hidden_dim, act_dim)
)
def forward(self, states, actions, rtgs, timesteps):
# states: (batch, seq_len, state_dim)
# actions: (batch, seq_len, act_dim)
# rtgs: (batch, seq_len, 1)
# timesteps: (batch, seq_len)
batch_size, seq_len, _ = states.shape
# 计算各个模态的 embedding
rtg_embeds = self.embed_rtg(rtgs)
state_embeds = self.embed_state(states)
act_embeds = self.embed_action(actions)
time_embeds = self.embed_timestep(timesteps).unsqueeze(2) # (batch, seq_len, 1, hidden_dim)
# 将时间步编码叠加到各个模态上
rtg_embeds = rtg_embeds + time_embeds.squeeze(2)
state_embeds = state_embeds + time_embeds.squeeze(2)
act_embeds = act_embeds + time_embeds.squeeze(2)
# 核心交错组装: (R1, S1, A1, R2, S2, A2...)
# 组装后的序列长度为 3 * seq_len
stacked_inputs = torch.stack((rtg_embeds, state_embeds, act_embeds), dim=2)
stacked_inputs = stacked_inputs.permute(0, 1, 3, 2).reshape(batch_size, 3 * seq_len, self.hidden_dim)
# 生成因果掩码 (Causal Mask) 防止看到未来信息
mask = torch.triu(torch.ones(3 * seq_len, 3 * seq_len), diagonal=1).bool().to(states.device)
# 喂入 Transformer
transformer_outputs = self.transformer(stacked_inputs, mask=mask)
# 提取对应位置的输出去预测 Action
# 我们希望用 (R_t, S_t) 的输出来预测 A_t。在交错序列中,S_t 处于索引 3*t + 1 的位置
x = transformer_outputs[:, 1::3, :] # 步长为3,从索引1开始切片
# 预测动作
action_preds = self.predict_action(x)
return action_preds
经典实例:Gym 悬崖寻路/走格子 (GridWorld / Gym-CartPole)
为了直观理解它在推理时如何"许愿",我们来看一个完整的离线部署与评估实例。假设我们已经在 CartPole(倒立摆) 的离线数据集(包含大量不同得分的轨迹)上训练好了上面的 MinDT 模型。
下面的实例展示了在测试阶段,如何通过自回归的方式给出高分承诺并控制游戏:
import gym
import numpy as np
def evaluate_dt_agent(env_name="CartPole-v1", target_return=500.0):
env = gym.make(env_name)
state_dim = env.observation_space.shape[0]
act_dim = env.action_space.n # 假设离散动作在实际中转为 one-hot,这里为简化用连续表示逻辑
# 实例化模型 (假设已经加载了训练好的权重)
model = MinDT(state_dim=state_dim, act_dim=1, hidden_dim=128, max_length=50)
model.eval()
# 初始化测试环境
state, _ = env.reset()
# 初始化决策序列上下文 (固定最大窗口长度,如 20)
max_window = 20
states_buffer = torch.zeros((1, max_window, state_dim), dtype=torch.float32)
actions_buffer = torch.zeros((1, max_window, 1), dtype=torch.float32)
rtgs_buffer = torch.zeros((1, max_window, 1), dtype=torch.float32)
timesteps_buffer = torch.zeros((1, max_window), dtype=torch.long)
# 1. 注入"许愿值":目标是拿到满分 500 分
current_target_return = target_return
states_buffer[0, 0] = torch.from_numpy(state)
rtgs_buffer[0, 0] = torch.tensor([current_target_return])
timesteps_buffer[0, 0] = 0
total_reward = 0
done = False
t = 0
while not done and t < 500:
# 确定当前窗口内有效切片长度
# 随着 t 增加,窗口填满后只保留最近的 max_window 步(类似于滑动窗口)
curr_len = min(t + 1, max_window)
s_input = states_buffer[:, :curr_len, :]
a_input = actions_buffer[:, :curr_len, :]
r_input = rtgs_buffer[:, :curr_len, :]
t_input = timesteps_buffer[:, :curr_len]
# 2. 模型预测
with torch.no_grad():
action_preds = model(s_input, a_input, r_input, t_input)
# 拿最新一个时间步的预测结果
pred_action = action_preds[0, -1, :].item()
# CartPole 是离散动作,这里做简单的阈值二值化
action = 1 if pred_action > 0.5 else 0
# 3. 与环境交互
next_state, reward, terminated, truncated, _ = env.step(action)
done = terminated or truncated
total_reward += reward
# 4. 更新许愿值 (Return-to-go 减去刚刚拿到的奖励)
current_target_return -= reward
# 5. 将新数据存入 buffer,推进到下一个时间步
t += 1
if t < max_window:
states_buffer[0, t] = torch.from_numpy(next_state)
actions_buffer[0, t-1] = torch.tensor([action]) # 补充上一步执行的 action
rtgs_buffer[0, t] = torch.tensor([current_target_return])
timesteps_buffer[0, t] = t
else:
# 窗口满了,将 buffer 整体前移,并在末尾覆盖新数据
states_buffer = torch.cat([states_buffer[:, 1:, :], torch.from_numpy(next_state).unsqueeze(0).unsqueeze(0)], dim=1)
actions_buffer[:, :-1, :] = actions_buffer[:, 1:, :]
actions_buffer[:, -1, :] = torch.tensor([action])
rtgs_buffer = torch.cat([rtgs_buffer[:, 1:, :], torch.tensor([current_target_return]).unsqueeze(0).unsqueeze(0)], dim=1)
timesteps_buffer = torch.cat([timesteps_buffer[:, 1:], torch.tensor([t]).unsqueeze(0)], dim=1)
state = next_state
print(f"设定目标总分: {target_return}, 最终实际得分: {total_reward}")
env.close()
# evaluate_dt_agent() # 运行示例
在工业界或研究中如何真正跑起来?
如果你要在学术研究(如 D4RL 数据集)或实际业务中去魔改 DT,完全不需要自己从头一行行撸上述所有的 Buffer 拼接。推荐使用官方或主流开源生态:
方案 A:使用 Hugging Face transformers (推荐)
Hugging Face 已经将 Decision Transformer 抽象成了标准库。
from transformers import DecisionTransformerModel, DecisionTransformerConfig
# 初始化配置
config = DecisionTransformerConfig(state_dim=17, act_dim=6, max_ep_len=1000)
# 创建模型
model = DecisionTransformerModel(config)
# 它的 forward 接收:states, actions, rewards, returns_to_go, timesteps, attention_mask
# 输出具有标准的 loss 和 action 预测
方案 B:使用 CORL 开源库 (更偏向 RL 实验)
如果你在做强化学习算法基准测试,推荐去 GitHub 搜 CORL (Clean Offline Reinforcement Learning) 。它提供了极其干净、单文件(Single-file implementation)的 Decision Transformer 实现,并且完美对齐了 D4RL(如 MuJoCo 机器人控制任务)的官方论文指标。