PyTorch 深度学习实战(21):元强化学习与 MAML 算法

一、元强化学习原理

1. 元学习核心思想

元强化学习(Meta-RL)旨在让智能体快速适应新任务 ,其核心是通过任务分布学习共享知识。与传统强化学习的区别在于:

对比维度 传统强化学习 元强化学习
目标 解决单一任务 快速适应任务分布中的新任务
训练方式 单任务大量交互 多任务交替训练
泛化能力 任务特定策略 跨任务可迁移策略
2. MAML 算法框架

Model-Agnostic Meta-Learning (MAML) 通过双层优化实现快速适应:

  1. 内层循环:在单个任务上执行少量梯度步

  2. 外层循环:跨任务更新初始参数

数学表达:


二、MAML 实现步骤(基于 Gymnasium)

我们将以 HalfCheetah 变体任务 为例,实现 MAML 算法:

  1. 定义任务分布:修改机器人质量参数生成不同任务

  2. 构建策略网络:基于 PyTorch 的 Actor-Critic 架构

  3. 实现双层优化:内层任务适配 + 外层元更新

  4. 快速适应测试:在新任务上验证策略性能


三、代码实现

python 复制代码
import gymnasium as gym
import torch
import numpy as np
from torch import nn, optim
from collections import deque
import time
import torch.nn.functional as F
​
# ================== 配置参数优化 ==================
class MAMLConfig:
    env_name = "HalfCheetah-v5"
    num_tasks = 20
    adaptation_steps = 10  # 增加适应步数
    adaptation_lr = 0.1  # 调整适应学习率
    hidden_dim = 256      # 增大隐藏层维度
    gamma = 0.99
    tau = 0.95           # 用于GAE计算
    meta_batch_size = 8   # 增大元批量
    meta_lr = 3e-4       # 调整元学习率
    total_epochs = 1000
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    clip_grad = 0.5      # 梯度裁剪阈值
​
# ================== 策略网络优化 ==================
class ActorCritic(nn.Module):
    def __init__(self, state_dim, action_dim):
        super().__init__()
        # 独立特征提取层(修正结构命名)
        self.actor_net = nn.Sequential(
            nn.Linear(state_dim, MAMLConfig.hidden_dim),
            nn.LayerNorm(MAMLConfig.hidden_dim),
            nn.Tanh(),
            nn.Linear(MAMLConfig.hidden_dim, MAMLConfig.hidden_dim),
            nn.LayerNorm(MAMLConfig.hidden_dim),
            nn.Tanh()
        )
        self.critic_net = nn.Sequential(
            nn.Linear(state_dim, MAMLConfig.hidden_dim),
            nn.LayerNorm(MAMLConfig.hidden_dim),
            nn.Tanh(),
            nn.Linear(MAMLConfig.hidden_dim, MAMLConfig.hidden_dim),
            nn.LayerNorm(MAMLConfig.hidden_dim),
            nn.Tanh()
        )
        self.actor_mean = nn.Linear(MAMLConfig.hidden_dim, action_dim)
        self.log_std = nn.Parameter(torch.zeros(action_dim))
        self.critic = nn.Linear(MAMLConfig.hidden_dim, 1)
        
        # 初始化参数(保持原有初始化逻辑)
        self._init_weights()
    
    def _init_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.orthogonal_(m.weight, gain=0.01)  # 正交初始化
                nn.init.constant_(m.bias, 0)
        # 策略最后一层初始化较小
        nn.init.orthogonal_(self.actor_mean.weight, gain=0.01)
        nn.init.constant_(self.actor_mean.bias, 0)
        # 价值头初始化
        nn.init.orthogonal_(self.critic.weight, gain=1.0)
        nn.init.constant_(self.critic.bias, 0)
    
    def forward(self, state, params=None):
        if params is None:
            # 正常前向传播
            actor_features = self.actor_net(state)
            critic_features = self.critic_net(state)
            mean = self.actor_mean(actor_features)
            value = self.critic(critic_features).squeeze(-1)
        else:
            # 手动参数计算时保持维度一致性
            if len(state.shape) == 1:
                state = state.unsqueeze(0)  # 添加批量维度
            # Actor网络计算
            x = F.linear(state, 
                       params['actor_net.0.weight'], 
                       params['actor_net.0.bias'])
            x = F.layer_norm(x, (MAMLConfig.hidden_dim,))
            x = torch.tanh(x)
            x = F.linear(x, 
                       params['actor_net.3.weight'], 
                       params['actor_net.3.bias'])
            x = F.layer_norm(x, (MAMLConfig.hidden_dim,))
            actor_features = torch.tanh(x)
            
            # Critic网络计算
            x = F.linear(state, 
                       params['critic_net.0.weight'], 
                       params['critic_net.0.bias'])
            x = F.layer_norm(x, (MAMLConfig.hidden_dim,))
            x = torch.tanh(x)
            x = F.linear(x, 
                       params['critic_net.3.weight'], 
                       params['critic_net.3.bias'])
            x = F.layer_norm(x, (MAMLConfig.hidden_dim,))
            critic_features = torch.tanh(x)
            
            mean = F.linear(actor_features, 
                          params['actor_mean.weight'],
                          params['actor_mean.bias'])
            value = F.linear(critic_features,
                           params['critic.weight'],
                           params['critic.bias']).squeeze(-1)
        
        log_std = self.log_std.unsqueeze(0).expand(mean.shape[0], -1)
        return mean, log_std, value
​
    def sample_action(self, state, params=None):
        mean, log_std, _ = self.forward(state, params)
        std = log_std.exp()
        dist = torch.distributions.Normal(mean, std)
        action = dist.rsample()
        
        # 新增维度检查逻辑
        if len(action.shape) > 1:
            if action.shape[0] == 1:  # 单样本批量情况
                action = action.squeeze(0)
            else:                     # 多步采样情况
                action = action.squeeze()
        
        log_prob = dist.log_prob(action).sum(-1)
        return action.detach(), log_prob
​
# ================== 任务生成器优化 ==================
class TaskGenerator:
    def __init__(self):
        self.default_params = self._get_default_params()
    
    def _get_default_params(self):
        env = gym.make(MAMLConfig.env_name)
        params = {
            'mass': env.unwrapped.model.body_mass.copy(),
            'damping': env.unwrapped.model.dof_damping.copy()
        }
        env.close()
        return params
    
    def sample_task(self):
        new_params = {
            'mass': self.default_params['mass'] * np.random.uniform(0.5, 2.0, size=self.default_params['mass'].shape),
            'damping': self.default_params['damping'] * np.random.uniform(0.5, 2.0, size=self.default_params['damping'].shape),
            'ctrlrange': self.default_params['damping'] * np.random.uniform(0.8, 1.2)  # 新增控制力范围扰动
        }
        return new_params
​
# ================== MAML 训练系统优化 ==================
class MAMLTrainer:
    def __init__(self):
        self.env = gym.make(MAMLConfig.env_name)
        self.state_dim = self.env.observation_space.shape[0]
        self.action_dim = self.env.action_space.shape[0]
        self.policy = ActorCritic(self.state_dim, self.action_dim).to(MAMLConfig.device)
        self.meta_optimizer = optim.Adam(self.policy.parameters(), lr=MAMLConfig.meta_lr, betas=(0.9, 0.999))
        self.task_gen = TaskGenerator()
        self.tasks = [self.task_gen.sample_task() for _ in range(MAMLConfig.num_tasks)]
    
    def adapt_task(self, task_params, num_steps):
        env = gym.make(MAMLConfig.env_name)
        env.unwrapped.model.body_mass[:] = task_params['mass']
        env.unwrapped.model.dof_damping[:] = task_params['damping']
        
        fast_weights = {k: v.clone().requires_grad_(True) for k, v in self.policy.named_parameters()}
        
        # 多步适应过程
        for step in range(num_steps):
            states, actions, rewards, values, dones = [], [], [], [], []
            obs, _ = env.reset()
            done = False
            while not done:
                with torch.no_grad():
                    state_tensor = torch.FloatTensor(obs).to(MAMLConfig.device)
                    action, _ = self.policy.sample_action(state_tensor, params=fast_weights)
                    _, _, value = self.policy(state_tensor, params=fast_weights)
                # next_obs, reward, terminated, truncated, _ = env.step(action.cpu().numpy())
                next_obs, reward, terminated, truncated, _ = env.step(action.cpu().numpy().astype(np.float32).flatten()  # 新增flatten()
)
                
                states.append(obs)
                actions.append(action)
                rewards.append(reward)
                values.append(value)
                dones.append(terminated or truncated)
                
                obs = next_obs
                done = terminated or truncated
            
            # 计算GAE
            with torch.no_grad():
                last_value = self.policy(torch.FloatTensor(obs).to(MAMLConfig.device), params=fast_weights)[2]
                returns, advantages = self._compute_gae(rewards, values, dones, last_value)
            
            # 计算损失
            states_tensor = torch.FloatTensor(np.array(states)).to(MAMLConfig.device)
            actions_tensor = torch.stack(actions)
            
            mean, log_std, current_values = self.policy(states_tensor, params=fast_weights)
            std = log_std.exp()
            dist = torch.distributions.Normal(mean, std)
            log_probs = dist.log_prob(actions_tensor).sum(-1)
            
            # 策略损失
            policy_loss = -(log_probs * advantages).mean()
            # 价值损失
            value_loss = F.mse_loss(current_values, returns)
            # 熵正则化
            entropy_loss = -dist.entropy().mean()
            
            total_loss = policy_loss + 0.5 * value_loss + 0.01 * entropy_loss
            
            # 计算梯度并更新快速权重
            grads = torch.autograd.grad(total_loss, fast_weights.values(), create_graph=True, allow_unused=True)
            for (name, param), grad in zip(fast_weights.items(), grads):
                if grad is not None:
                    fast_weights[name] = param - MAMLConfig.adaptation_lr * grad
        
        env.close()
        return fast_weights
    
    def _compute_gae(self, rewards, values, dones, last_value):
        values = values + [last_value]
        gae = 0
        returns = []
        advantages = []
        
        for t in reversed(range(len(rewards))):
            delta = rewards[t] + MAMLConfig.gamma * values[t+1] * (1 - dones[t]) - values[t]
            gae = delta + MAMLConfig.gamma * MAMLConfig.tau * (1 - dones[t]) * gae
            advantages.insert(0, gae)
            returns.insert(0, advantages[0] + values[t])
        
        advantages = torch.tensor(advantages, device=MAMLConfig.device, dtype=torch.float32)
        returns = torch.tensor(returns, device=MAMLConfig.device, dtype=torch.float32)
        # 标准化优势
        advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)
        return returns, advantages
    
    def meta_update(self, tasks):
        meta_loss = 0
        for task in tasks:
            fast_weights = self.adapt_task(task, MAMLConfig.adaptation_steps)
            
            # 在适应后的策略上收集轨迹
            env = gym.make(MAMLConfig.env_name)
            env.unwrapped.model.body_mass[:] = task['mass']
            env.unwrapped.model.dof_damping[:] = task['damping']
            
            states, actions, rewards, values, dones = [], [], [], [], []
            obs, _ = env.reset()
            done = False
            while not done:
                with torch.no_grad():
                    state_tensor = torch.FloatTensor(obs).to(MAMLConfig.device)
                    action, _ = self.policy.sample_action(state_tensor, params=fast_weights)
                    _, _, value = self.policy(state_tensor, params=fast_weights)
                next_obs, reward, terminated, truncated, _ = env.step(action.cpu().numpy())
                
                states.append(obs)
                actions.append(action)
                rewards.append(reward)
                values.append(value)
                dones.append(terminated or truncated)
                obs = next_obs
                done = terminated or truncated
            
            # 计算GAE和returns
            with torch.no_grad():
                last_value = self.policy(torch.FloatTensor(obs).to(MAMLConfig.device), params=fast_weights)[2]
                returns, advantages = self._compute_gae(rewards, values, dones, last_value)
            
            # 计算元损失
            states_tensor = torch.FloatTensor(np.array(states)).to(MAMLConfig.device)
            actions_tensor = torch.stack(actions).to(MAMLConfig.device)
            
            mean, log_std, current_values = self.policy(states_tensor, params=fast_weights)
            std = log_std.exp()
            dist = torch.distributions.Normal(mean, std)
            log_probs = dist.log_prob(actions_tensor).sum(-1)
            
            policy_loss = -(log_probs * advantages).mean()
            value_loss = F.mse_loss(current_values, returns)
            entropy_loss = -dist.entropy().mean()
            
            task_loss = policy_loss + 0.5 * value_loss + 0.01 * entropy_loss
            meta_loss += task_loss
            
            env.close()
        
        meta_loss /= len(tasks)
        self.meta_optimizer.zero_grad()
        meta_loss.backward()
        torch.nn.utils.clip_grad_norm_(self.policy.parameters(), MAMLConfig.clip_grad)
        self.meta_optimizer.step()
        return meta_loss.item()
    
    def train(self):
        for epoch in range(MAMLConfig.total_epochs):
            batch_tasks = np.random.choice(self.tasks, MAMLConfig.meta_batch_size)
            loss = self.meta_update(batch_tasks)
            if (epoch + 1) % 50 == 0:
                print(f"Epoch {epoch+1:04d} | Meta Loss: {loss:.1f}")
                self._evaluate()
​
    def _evaluate(self, num_tasks=3):
        total_rewards = []
        for i in range(num_tasks):
            task = self.task_gen.sample_task()
            original_params = {k: v.clone() for k, v in self.policy.named_parameters()}
            fast_weights = self.adapt_task(task, MAMLConfig.adaptation_steps)
            env = gym.make(MAMLConfig.env_name)
            env.unwrapped.model.body_mass[:] = task['mass']
            env.unwrapped.model.dof_damping[:] = task['damping']
            obs, _ = env.reset()
            total_reward = 0
            done = False
            while not done:
                with torch.no_grad():
                    action, _ = self.policy.sample_action(
                        torch.FloatTensor(obs).to(MAMLConfig.device),
                        params=fast_weights
                    )
                obs, reward, terminated, truncated, _ = env.step(action.cpu().numpy())
                total_reward += reward
                done = terminated or truncated
            total_rewards.append(total_reward)
            self.policy.load_state_dict(original_params)
            env.close()
        avg_reward = sum(total_rewards) / num_tasks
        print(f"Evaluation | Avg Reward: {avg_reward:.1f}")
​
if __name__ == "__main__":
    start = time.time()
    start_str = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(start))
    print(f"开始时间: {start_str}")
    print("初始化环境...")
    trainer = MAMLTrainer()
    trainer.train()
    end = time.time()
    end_str = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(end))
    print(f"训练完成时间: {end_str}")
    print(f"训练完成,耗时: {end - start:.2f}秒")

四、关键代码解析

  1. 任务生成器

    • 通过修改机器人质量和关节阻尼参数生成新任务

    • 每个任务对应不同的物理动力学特性

  2. 双层优化实现

    • adapt_task:内层循环在单个任务上执行策略梯度更新

    • meta_update:外层循环跨任务更新初始参数

  3. 策略快速适应

    • 使用 torch.autograd.grad 计算二阶梯度

    • 通过参数克隆实现任务特定参数更新


五、训练输出示例

python 复制代码
开始时间: 2025-03-19 12:49:54
初始化环境...
Epoch 0050 | Meta Loss: 18.0
Evaluation | Avg Reward: -299.6
Epoch 0100 | Meta Loss: 21.5
Evaluation | Avg Reward: -193.3
Epoch 0150 | Meta Loss: 14.8
Evaluation | Avg Reward: -199.7
Epoch 0200 | Meta Loss: 25.3
Evaluation | Avg Reward: -317.4
Epoch 0250 | Meta Loss: 16.7
Evaluation | Avg Reward: -174.8
Epoch 0300 | Meta Loss: 24.3
Evaluation | Avg Reward: -277.6
Epoch 0350 | Meta Loss: 12.3
Evaluation | Avg Reward: -249.0
Epoch 0400 | Meta Loss: 25.4
Evaluation | Avg Reward: -253.4
Epoch 0450 | Meta Loss: 13.6
Evaluation | Avg Reward: -222.1
Epoch 0500 | Meta Loss: 27.9
Evaluation | Avg Reward: -295.4
Epoch 0550 | Meta Loss: 23.3
Evaluation | Avg Reward: -484.5
Epoch 0600 | Meta Loss: 17.2
Evaluation | Avg Reward: -315.4
Epoch 0650 | Meta Loss: 16.0
Evaluation | Avg Reward: -250.3
Epoch 0700 | Meta Loss: 20.9
Evaluation | Avg Reward: -300.3
Epoch 0750 | Meta Loss: 33.4
Evaluation | Avg Reward: -305.0
Epoch 0800 | Meta Loss: 61.8
Evaluation | Avg Reward: -260.7
Epoch 0850 | Meta Loss: 10.9
Evaluation | Avg Reward: -311.5
Epoch 0900 | Meta Loss: 24.7
Evaluation | Avg Reward: -299.8
Epoch 0950 | Meta Loss: 14.5
Evaluation | Avg Reward: -321.9
Epoch 1000 | Meta Loss: 12.0
Evaluation | Avg Reward: -275.3
训练完成时间: 2025-03-20 09:28:03
训练完成,耗时: 74288.70秒

六、总结与扩展

本文实现了元强化学习的核心范式------MAML 算法,展示了策略快速适应新任务的能力。读者可尝试以下扩展方向:

  1. 高效探索策略 结合 Proximal Policy Optimization (PPO) 或 Soft Actor-Critic (SAC) 提升采样效率

  2. 多模态任务适应 使用条件策略网络处理离散任务类型

在下一篇文章中,我们将探索 多智能体强化学习(MARL),并实现 MADDPG 算法!


注意事项

  1. 安装依赖:

    bash 复制代码
    pip install gymnasium[mujoco] torch
  2. 完整训练需要 GPU 加速(推荐显存 ≥ 8GB)

  3. 若遇到环境初始化错误,检查 MuJoCo 许可证配置:

    bash 复制代码
    ls ~/.mujoco/mjkey.txt
相关推荐
暮雨哀尘几秒前
微信小程序开发:微信小程序组件应用研究
算法·微信·微信小程序·小程序·notepad++·微信公众平台·组件
dokii122 分钟前
leetcode199 二叉树的右视图
数据结构·算法·leetcode
UP_Continue35 分钟前
排序--归并排序--非递归
数据结构·算法·排序算法
花果山-马大帅37 分钟前
我的机器学习学习之路
人工智能·python·算法·机器学习·scikit-learn
坤小满学Java38 分钟前
【力扣刷题|第十七天】0-1 背包 完全背包
算法·leetcode
陈陈爱java1 小时前
Java算法模板
java·开发语言·算法
生信碱移2 小时前
细胞内与细胞间网络整合分析!神经网络+细胞通讯,这个单细胞分析工具一箭双雕了(scTenifoldXct)
人工智能·经验分享·深度学习·神经网络·机器学习·数据分析·数据可视化
ahahahahaha23332 小时前
相似度计算 ccf-csp 2024-2-2
数据结构·c++·算法
2402_881319302 小时前
3.28学习总结
数据结构·学习·算法