I. Meta-Reinforcement Learning Principles
1. The Core Idea of Meta-Learning
Meta-reinforcement learning (Meta-RL) aims to let an agent adapt quickly to new tasks; its core idea is to learn knowledge shared across a distribution of tasks. It differs from standard reinforcement learning in the following ways:
| Dimension | Standard RL | Meta-RL |
|---|---|---|
| Objective | Solve a single task | Quickly adapt to new tasks drawn from the task distribution |
| Training | Extensive interaction on one task | Alternating training across multiple tasks |
| Generalization | Task-specific policy | Policy that transfers across tasks |
2. The MAML Algorithm
Model-Agnostic Meta-Learning (MAML) achieves fast adaptation through a bilevel optimization:
- Inner loop: take a few gradient steps on an individual task
- Outer loop: update the shared initial parameters across tasks
Mathematical formulation (inner-loop adaptation followed by the meta-update):

$$\theta_i' = \theta - \alpha \nabla_\theta \mathcal{L}_{\mathcal{T}_i}\big(f_\theta\big)$$

$$\theta \leftarrow \theta - \beta \nabla_\theta \sum_{\mathcal{T}_i \sim p(\mathcal{T})} \mathcal{L}_{\mathcal{T}_i}\big(f_{\theta_i'}\big)$$

where $\alpha$ is the adaptation (inner-loop) learning rate and $\beta$ is the meta (outer-loop) learning rate.
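Before moving to the RL setting, the two loops are easier to see on a toy problem. The sketch below applies the same bilevel update to 1-D sine-wave regression; the network size, learning rates, and sine-wave task family are illustrative choices and not part of this article's implementation, which appears in Section III.

```python
import torch
import torch.nn.functional as F
from torch import nn

torch.manual_seed(0)
net = nn.Sequential(nn.Linear(1, 40), nn.Tanh(), nn.Linear(40, 1))
meta_opt = torch.optim.Adam(net.parameters(), lr=1e-3)
inner_lr = 0.01

def functional_forward(x, params):
    # Manual forward pass so adapted ("fast") weights can be plugged in.
    w1, b1, w2, b2 = params
    return F.linear(torch.tanh(F.linear(x, w1, b1)), w2, b2)

def sample_task():
    # A task = a sine wave with random amplitude and phase.
    amp, phase = torch.rand(1) * 4 + 1, torch.rand(1) * 3.14
    def data(n):
        x = torch.rand(n, 1) * 10 - 5
        return x, amp * torch.sin(x + phase)
    return data

for step in range(2000):
    meta_loss = 0.0
    for _ in range(4):                                  # meta-batch of tasks
        data = sample_task()
        params = list(net.parameters())
        # Inner loop: one gradient step on the task's support set.
        xs, ys = data(10)
        grads = torch.autograd.grad(
            F.mse_loss(functional_forward(xs, params), ys),
            params, create_graph=True)
        fast = [p - inner_lr * g for p, g in zip(params, grads)]
        # Outer objective: loss of the adapted weights on a fresh query set.
        xq, yq = data(10)
        meta_loss = meta_loss + F.mse_loss(functional_forward(xq, fast), yq)
    meta_opt.zero_grad()
    (meta_loss / 4).backward()                          # second-order gradients flow here
    meta_opt.step()
```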
II. MAML Implementation Steps (with Gymnasium)
We will implement MAML on a family of HalfCheetah task variants:
- Define the task distribution: generate different tasks by modifying the robot's mass parameters (see the short sketch after this list)
- Build the policy network: a PyTorch Actor-Critic architecture
- Implement the bilevel optimization: inner-loop task adaptation + outer-loop meta-update
- Test fast adaptation: validate the policy's performance on new tasks
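As a concrete look at step 1, the snippet below shows the mechanism the implementation relies on: HalfCheetah's physical parameters are exposed through `env.unwrapped.model`, so a new task can be created by rescaling them. The `TaskGenerator` class in Section III wraps exactly this idea (and also perturbs joint damping).

```python
import gymnasium as gym
import numpy as np

# Create a HalfCheetah variant by rescaling the body masses of the MuJoCo model.
env = gym.make("HalfCheetah-v5")
default_mass = env.unwrapped.model.body_mass.copy()
env.unwrapped.model.body_mass[:] = default_mass * np.random.uniform(0.5, 2.0, size=default_mass.shape)
obs, info = env.reset()
env.close()
```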
III. Code Implementation
```python
import gymnasium as gym
import torch
import numpy as np
from torch import nn, optim
import torch.nn.functional as F
import time


# ================== Configuration ==================
class MAMLConfig:
    env_name = "HalfCheetah-v5"
    num_tasks = 20
    adaptation_steps = 10   # inner-loop adaptation steps per task
    adaptation_lr = 0.1     # inner-loop (adaptation) learning rate
    hidden_dim = 256        # hidden layer width
    gamma = 0.99
    tau = 0.95              # lambda used in the GAE computation
    meta_batch_size = 8     # tasks per meta-update
    meta_lr = 3e-4          # outer-loop (meta) learning rate
    total_epochs = 1000
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    clip_grad = 0.5         # gradient clipping threshold
# ================== Policy Network ==================
class ActorCritic(nn.Module):
    def __init__(self, state_dim, action_dim):
        super().__init__()
        # Separate feature extractors for the actor and the critic
        self.actor_net = nn.Sequential(
            nn.Linear(state_dim, MAMLConfig.hidden_dim),
            nn.LayerNorm(MAMLConfig.hidden_dim),
            nn.Tanh(),
            nn.Linear(MAMLConfig.hidden_dim, MAMLConfig.hidden_dim),
            nn.LayerNorm(MAMLConfig.hidden_dim),
            nn.Tanh()
        )
        self.critic_net = nn.Sequential(
            nn.Linear(state_dim, MAMLConfig.hidden_dim),
            nn.LayerNorm(MAMLConfig.hidden_dim),
            nn.Tanh(),
            nn.Linear(MAMLConfig.hidden_dim, MAMLConfig.hidden_dim),
            nn.LayerNorm(MAMLConfig.hidden_dim),
            nn.Tanh()
        )
        self.actor_mean = nn.Linear(MAMLConfig.hidden_dim, action_dim)
        self.log_std = nn.Parameter(torch.zeros(action_dim))
        self.critic = nn.Linear(MAMLConfig.hidden_dim, 1)
        self._init_weights()

    def _init_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.orthogonal_(m.weight, gain=0.01)  # orthogonal initialization
                nn.init.constant_(m.bias, 0)
        # Keep the policy head small at initialization
        nn.init.orthogonal_(self.actor_mean.weight, gain=0.01)
        nn.init.constant_(self.actor_mean.bias, 0)
        # Value head initialization
        nn.init.orthogonal_(self.critic.weight, gain=1.0)
        nn.init.constant_(self.critic.bias, 0)

    def forward(self, state, params=None):
        if params is None:
            # Regular forward pass through the module's own parameters
            actor_features = self.actor_net(state)
            critic_features = self.critic_net(state)
            mean = self.actor_mean(actor_features)
            value = self.critic(critic_features).squeeze(-1)
        else:
            # Functional forward pass using externally supplied ("fast") weights
            if len(state.shape) == 1:
                state = state.unsqueeze(0)  # add a batch dimension
            # Actor branch
            x = F.linear(state, params['actor_net.0.weight'], params['actor_net.0.bias'])
            x = F.layer_norm(x, (MAMLConfig.hidden_dim,),
                             params['actor_net.1.weight'], params['actor_net.1.bias'])
            x = torch.tanh(x)
            x = F.linear(x, params['actor_net.3.weight'], params['actor_net.3.bias'])
            x = F.layer_norm(x, (MAMLConfig.hidden_dim,),
                             params['actor_net.4.weight'], params['actor_net.4.bias'])
            actor_features = torch.tanh(x)
            # Critic branch
            x = F.linear(state, params['critic_net.0.weight'], params['critic_net.0.bias'])
            x = F.layer_norm(x, (MAMLConfig.hidden_dim,),
                             params['critic_net.1.weight'], params['critic_net.1.bias'])
            x = torch.tanh(x)
            x = F.linear(x, params['critic_net.3.weight'], params['critic_net.3.bias'])
            x = F.layer_norm(x, (MAMLConfig.hidden_dim,),
                             params['critic_net.4.weight'], params['critic_net.4.bias'])
            critic_features = torch.tanh(x)
            mean = F.linear(actor_features, params['actor_mean.weight'], params['actor_mean.bias'])
            value = F.linear(critic_features, params['critic.weight'], params['critic.bias']).squeeze(-1)
        # Use the fast log_std when adapted parameters are supplied
        log_std_source = self.log_std if params is None else params['log_std']
        log_std = log_std_source.unsqueeze(0).expand(mean.shape[0], -1)
        return mean, log_std, value

    def sample_action(self, state, params=None):
        mean, log_std, _ = self.forward(state, params)
        std = log_std.exp()
        dist = torch.distributions.Normal(mean, std)
        action = dist.rsample()
        # Collapse the batch dimension when acting on a single state
        if len(action.shape) > 1:
            if action.shape[0] == 1:   # single-sample batch
                action = action.squeeze(0)
            else:                      # multi-step sampling
                action = action.squeeze()
        log_prob = dist.log_prob(action).sum(-1)
        return action.detach(), log_prob
# ================== Task Generator ==================
class TaskGenerator:
    def __init__(self):
        self.default_params = self._get_default_params()

    def _get_default_params(self):
        env = gym.make(MAMLConfig.env_name)
        params = {
            'mass': env.unwrapped.model.body_mass.copy(),
            'damping': env.unwrapped.model.dof_damping.copy(),
            'ctrlrange': env.unwrapped.model.actuator_ctrlrange.copy()
        }
        env.close()
        return params

    def sample_task(self):
        # Each task rescales the default physical parameters
        new_params = {
            'mass': self.default_params['mass'] * np.random.uniform(0.5, 2.0, size=self.default_params['mass'].shape),
            'damping': self.default_params['damping'] * np.random.uniform(0.5, 2.0, size=self.default_params['damping'].shape),
            'ctrlrange': self.default_params['ctrlrange'] * np.random.uniform(0.8, 1.2)  # perturb the actuator control range
        }
        return new_params
# ================== MAML Trainer ==================
class MAMLTrainer:
    def __init__(self):
        self.env = gym.make(MAMLConfig.env_name)
        self.state_dim = self.env.observation_space.shape[0]
        self.action_dim = self.env.action_space.shape[0]
        self.policy = ActorCritic(self.state_dim, self.action_dim).to(MAMLConfig.device)
        self.meta_optimizer = optim.Adam(self.policy.parameters(), lr=MAMLConfig.meta_lr, betas=(0.9, 0.999))
        self.task_gen = TaskGenerator()
        self.tasks = [self.task_gen.sample_task() for _ in range(MAMLConfig.num_tasks)]

    def adapt_task(self, task_params, num_steps):
        # Build an environment whose physics match the sampled task
        env = gym.make(MAMLConfig.env_name)
        env.unwrapped.model.body_mass[:] = task_params['mass']
        env.unwrapped.model.dof_damping[:] = task_params['damping']
        env.unwrapped.model.actuator_ctrlrange[:] = task_params['ctrlrange']
        # Fast weights: clones of the meta-parameters that stay attached to the
        # computation graph, so second-order gradients can flow back to them
        fast_weights = {k: v.clone() for k, v in self.policy.named_parameters()}

        # Multi-step adaptation
        for step in range(num_steps):
            states, actions, rewards, values, dones = [], [], [], [], []
            obs, _ = env.reset()
            done = False
            while not done:
                with torch.no_grad():
                    state_tensor = torch.FloatTensor(obs).to(MAMLConfig.device)
                    action, _ = self.policy.sample_action(state_tensor, params=fast_weights)
                    _, _, value = self.policy(state_tensor, params=fast_weights)
                next_obs, reward, terminated, truncated, _ = env.step(
                    action.cpu().numpy().astype(np.float32).flatten()  # flatten to match the Box action shape
                )
                states.append(obs)
                actions.append(action)
                rewards.append(reward)
                values.append(value)
                dones.append(terminated or truncated)
                obs = next_obs
                done = terminated or truncated

            # Compute GAE advantages and returns
            with torch.no_grad():
                last_value = self.policy(torch.FloatTensor(obs).to(MAMLConfig.device), params=fast_weights)[2]
            returns, advantages = self._compute_gae(rewards, values, dones, last_value)

            # Recompute the losses with gradients enabled
            states_tensor = torch.FloatTensor(np.array(states)).to(MAMLConfig.device)
            actions_tensor = torch.stack(actions)
            mean, log_std, current_values = self.policy(states_tensor, params=fast_weights)
            std = log_std.exp()
            dist = torch.distributions.Normal(mean, std)
            log_probs = dist.log_prob(actions_tensor).sum(-1)
            # Policy loss
            policy_loss = -(log_probs * advantages).mean()
            # Value loss
            value_loss = F.mse_loss(current_values, returns)
            # Entropy regularization
            entropy_loss = -dist.entropy().mean()
            total_loss = policy_loss + 0.5 * value_loss + 0.01 * entropy_loss

            # One gradient step on the fast weights; create_graph=True keeps the
            # second-order path back to the meta-parameters
            grads = torch.autograd.grad(total_loss, list(fast_weights.values()), create_graph=True, allow_unused=True)
            for (name, param), grad in zip(list(fast_weights.items()), grads):
                if grad is not None:
                    fast_weights[name] = param - MAMLConfig.adaptation_lr * grad
        env.close()
        return fast_weights

    def _compute_gae(self, rewards, values, dones, last_value):
        # Convert value predictions to plain floats so the outputs are 1-D tensors of shape (T,)
        values = [float(v) for v in values] + [float(last_value)]
        gae = 0.0
        returns = []
        advantages = []
        for t in reversed(range(len(rewards))):
            delta = rewards[t] + MAMLConfig.gamma * values[t + 1] * (1 - dones[t]) - values[t]
            gae = delta + MAMLConfig.gamma * MAMLConfig.tau * (1 - dones[t]) * gae
            advantages.insert(0, gae)
            returns.insert(0, advantages[0] + values[t])
        advantages = torch.tensor(advantages, device=MAMLConfig.device, dtype=torch.float32)
        returns = torch.tensor(returns, device=MAMLConfig.device, dtype=torch.float32)
        # Normalize the advantages
        advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)
        return returns, advantages
    def meta_update(self, tasks):
        meta_loss = 0
        for task in tasks:
            fast_weights = self.adapt_task(task, MAMLConfig.adaptation_steps)
            # Collect a trajectory with the adapted policy
            env = gym.make(MAMLConfig.env_name)
            env.unwrapped.model.body_mass[:] = task['mass']
            env.unwrapped.model.dof_damping[:] = task['damping']
            env.unwrapped.model.actuator_ctrlrange[:] = task['ctrlrange']
            states, actions, rewards, values, dones = [], [], [], [], []
            obs, _ = env.reset()
            done = False
            while not done:
                with torch.no_grad():
                    state_tensor = torch.FloatTensor(obs).to(MAMLConfig.device)
                    action, _ = self.policy.sample_action(state_tensor, params=fast_weights)
                    _, _, value = self.policy(state_tensor, params=fast_weights)
                next_obs, reward, terminated, truncated, _ = env.step(action.cpu().numpy())
                states.append(obs)
                actions.append(action)
                rewards.append(reward)
                values.append(value)
                dones.append(terminated or truncated)
                obs = next_obs
                done = terminated or truncated

            # Compute GAE advantages and returns
            with torch.no_grad():
                last_value = self.policy(torch.FloatTensor(obs).to(MAMLConfig.device), params=fast_weights)[2]
            returns, advantages = self._compute_gae(rewards, values, dones, last_value)

            # Meta-loss evaluated with the adapted (fast) weights
            states_tensor = torch.FloatTensor(np.array(states)).to(MAMLConfig.device)
            actions_tensor = torch.stack(actions).to(MAMLConfig.device)
            mean, log_std, current_values = self.policy(states_tensor, params=fast_weights)
            std = log_std.exp()
            dist = torch.distributions.Normal(mean, std)
            log_probs = dist.log_prob(actions_tensor).sum(-1)
            policy_loss = -(log_probs * advantages).mean()
            value_loss = F.mse_loss(current_values, returns)
            entropy_loss = -dist.entropy().mean()
            task_loss = policy_loss + 0.5 * value_loss + 0.01 * entropy_loss
            meta_loss += task_loss
            env.close()

        meta_loss /= len(tasks)
        self.meta_optimizer.zero_grad()
        meta_loss.backward()
        torch.nn.utils.clip_grad_norm_(self.policy.parameters(), MAMLConfig.clip_grad)
        self.meta_optimizer.step()
        return meta_loss.item()

    def train(self):
        for epoch in range(MAMLConfig.total_epochs):
            task_indices = np.random.choice(len(self.tasks), MAMLConfig.meta_batch_size)
            batch_tasks = [self.tasks[i] for i in task_indices]
            loss = self.meta_update(batch_tasks)
            if (epoch + 1) % 50 == 0:
                print(f"Epoch {epoch+1:04d} | Meta Loss: {loss:.1f}")
                self._evaluate()

    def _evaluate(self, num_tasks=3):
        total_rewards = []
        for i in range(num_tasks):
            task = self.task_gen.sample_task()
            original_params = {k: v.clone() for k, v in self.policy.named_parameters()}
            fast_weights = self.adapt_task(task, MAMLConfig.adaptation_steps)
            env = gym.make(MAMLConfig.env_name)
            env.unwrapped.model.body_mass[:] = task['mass']
            env.unwrapped.model.dof_damping[:] = task['damping']
            env.unwrapped.model.actuator_ctrlrange[:] = task['ctrlrange']
            obs, _ = env.reset()
            total_reward = 0
            done = False
            while not done:
                with torch.no_grad():
                    action, _ = self.policy.sample_action(
                        torch.FloatTensor(obs).to(MAMLConfig.device),
                        params=fast_weights
                    )
                obs, reward, terminated, truncated, _ = env.step(action.cpu().numpy())
                total_reward += reward
                done = terminated or truncated
            total_rewards.append(total_reward)
            # Adaptation only touches the fast weights; restore defensively anyway
            self.policy.load_state_dict(original_params)
            env.close()
        avg_reward = sum(total_rewards) / num_tasks
        print(f"Evaluation | Avg Reward: {avg_reward:.1f}")


if __name__ == "__main__":
    start = time.time()
    start_str = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(start))
    print(f"Start time: {start_str}")
    print("Initializing environment...")
    trainer = MAMLTrainer()
    trainer.train()
    end = time.time()
    end_str = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(end))
    print(f"Training finished at: {end_str}")
    print(f"Training complete, elapsed: {end - start:.2f} seconds")
```
IV. Key Code Walkthrough
- Task generator
  - New tasks are generated by rescaling the robot's body mass, joint damping, and actuator control-range parameters
  - Each task therefore exhibits different physical dynamics
- Bilevel optimization
  - `adapt_task`: the inner loop performs policy-gradient updates on a single task
  - `meta_update`: the outer loop updates the shared initial parameters across tasks
- Fast policy adaptation
  - `torch.autograd.grad` with `create_graph=True` provides the second-order gradients (a standalone toy example follows this list)
  - Task-specific updates are applied to cloned "fast weights", so the meta-parameters themselves are never overwritten during adaptation
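The second-order mechanism is easiest to see in isolation. The toy snippet below is not taken from the code above; it only demonstrates that when the adapted weight is built with `create_graph=True`, a loss computed from it still back-propagates to the original parameter.

```python
import torch

w = torch.tensor(2.0, requires_grad=True)       # meta-parameter
x, y = torch.tensor(3.0), torch.tensor(10.0)

inner_loss = (w * x - y) ** 2                   # loss before adaptation
(g,) = torch.autograd.grad(inner_loss, w, create_graph=True)
w_fast = w - 0.1 * g                            # one inner-loop gradient step

outer_loss = (w_fast * x - y) ** 2              # loss after adaptation
outer_loss.backward()                           # differentiates through w_fast(w)
print(w.grad)                                   # nonzero: includes the second-order term
```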
V. Sample Training Output
```
Start time: 2025-03-19 12:49:54
Initializing environment...
Epoch 0050 | Meta Loss: 18.0
Evaluation | Avg Reward: -299.6
Epoch 0100 | Meta Loss: 21.5
Evaluation | Avg Reward: -193.3
Epoch 0150 | Meta Loss: 14.8
Evaluation | Avg Reward: -199.7
Epoch 0200 | Meta Loss: 25.3
Evaluation | Avg Reward: -317.4
Epoch 0250 | Meta Loss: 16.7
Evaluation | Avg Reward: -174.8
Epoch 0300 | Meta Loss: 24.3
Evaluation | Avg Reward: -277.6
Epoch 0350 | Meta Loss: 12.3
Evaluation | Avg Reward: -249.0
Epoch 0400 | Meta Loss: 25.4
Evaluation | Avg Reward: -253.4
Epoch 0450 | Meta Loss: 13.6
Evaluation | Avg Reward: -222.1
Epoch 0500 | Meta Loss: 27.9
Evaluation | Avg Reward: -295.4
Epoch 0550 | Meta Loss: 23.3
Evaluation | Avg Reward: -484.5
Epoch 0600 | Meta Loss: 17.2
Evaluation | Avg Reward: -315.4
Epoch 0650 | Meta Loss: 16.0
Evaluation | Avg Reward: -250.3
Epoch 0700 | Meta Loss: 20.9
Evaluation | Avg Reward: -300.3
Epoch 0750 | Meta Loss: 33.4
Evaluation | Avg Reward: -305.0
Epoch 0800 | Meta Loss: 61.8
Evaluation | Avg Reward: -260.7
Epoch 0850 | Meta Loss: 10.9
Evaluation | Avg Reward: -311.5
Epoch 0900 | Meta Loss: 24.7
Evaluation | Avg Reward: -299.8
Epoch 0950 | Meta Loss: 14.5
Evaluation | Avg Reward: -321.9
Epoch 1000 | Meta Loss: 12.0
Evaluation | Avg Reward: -275.3
Training finished at: 2025-03-20 09:28:03
Training complete, elapsed: 74288.70 seconds
```
VI. Summary and Extensions
This post implemented the core paradigm of meta-reinforcement learning, the MAML algorithm, and demonstrated a policy's ability to adapt quickly to new tasks. Readers may want to try the following extensions:
- Sample-efficient exploration: combine the inner loop with Proximal Policy Optimization (PPO) or Soft Actor-Critic (SAC) to improve sample efficiency (a sketch of a PPO-style surrogate follows this list)
- Multi-modal task adaptation: use a task-conditioned policy network to handle discrete task types
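As a pointer for the first extension, one option is to replace the vanilla policy-gradient term used in `adapt_task` with a PPO-style clipped surrogate. The sketch below is only an assumption about how that could look: `old_log_probs` (log-probabilities recorded at collection time) and `clip_eps` are new ingredients that the implementation above does not store.

```python
import torch

def ppo_policy_loss(log_probs, old_log_probs, advantages, clip_eps=0.2):
    # PPO clipped surrogate: limit how far the updated policy can move away
    # from the data-collecting policy within one adaptation step.
    ratio = torch.exp(log_probs - old_log_probs)
    clipped = torch.clamp(ratio, 1.0 - clip_eps, 1.0 + clip_eps)
    return -torch.min(ratio * advantages, clipped * advantages).mean()
```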
In the next post, we will explore multi-agent reinforcement learning (MARL) and implement the MADDPG algorithm!
Notes
- Install the dependencies:
  ```bash
  pip install gymnasium[mujoco] torch
  ```
- Full training requires GPU acceleration (≥ 8 GB of VRAM recommended)
- If environment initialization fails on a legacy mujoco-py setup, check the MuJoCo license configuration (MuJoCo ≥ 2.1, which `gymnasium[mujoco]` installs, no longer needs a license key):
  ```bash
  ls ~/.mujoco/mjkey.txt
  ```
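Independently of the license question, a quick way to confirm the environment itself loads is the short smoke test below (it only requires `gymnasium[mujoco]` to be installed):

```python
import gymnasium as gym

env = gym.make("HalfCheetah-v5")
obs, info = env.reset(seed=0)
obs, reward, terminated, truncated, info = env.step(env.action_space.sample())
print(obs.shape, reward)
env.close()
```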