PyTorch Deep Learning in Practice (23): Multi-Task Reinforcement Learning (Multi-Task RL)

I. Principles of Multi-Task Reinforcement Learning

1. The Core Idea of Multi-Task Learning

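Multi-task reinforcement learning (Multi-Task RL) trains a single agent on several tasks at once, sharing knowledge across tasks to improve learning efficiency and generalization. It differs from single-task reinforcement learning as follows: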

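Dimension | Single-task RL | Multi-task RL
Objective | Optimize the policy for a single task | Jointly optimize a shared policy across multiple tasks
Training | Each task is trained independently | Tasks are trained jointly
Knowledge transfer | None | Cross-task transfer through shared representations or parameters
Use case | Task-specific settings | General-purpose agents in complex environments

2. A Multi-Task Framework Based on Shared Representations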

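Shared network layers learn what the tasks have in common, while task-specific layers handle the differences between them. The algorithm proceeds as follows: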

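  1. Task sampling: randomly select a task from the task distribution

  2. Policy execution: generate actions with the shared network

  3. Gradient update: jointly optimize the shared parameters and the task-specific parameters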
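
The three steps above can be sketched in a few lines of PyTorch. This is a toy illustration under assumed names and sizes (`TinySharedNet`, 6-dimensional observations, 2-dimensional actions, two tasks), not the network used later in this article; it only shows how a shared trunk feeds per-task heads and which parameters one backward pass reaches.

```python
import torch
import torch.nn as nn

class TinySharedNet(nn.Module):
    """Hypothetical toy model: one shared trunk, one linear head per task."""
    def __init__(self, obs_dim=6, act_dim=2, num_tasks=2, hidden=16):
        super().__init__()
        self.trunk = nn.Sequential(nn.Linear(obs_dim, hidden), nn.Tanh())
        self.heads = nn.ModuleList(
            [nn.Linear(hidden, act_dim) for _ in range(num_tasks)]
        )

    def forward(self, obs, task_id):
        return self.heads[task_id](self.trunk(obs))

torch.manual_seed(0)
net = TinySharedNet()

# 1. Task sampling: draw one task index uniformly at random
task_id = int(torch.randint(0, 2, (1,)).item())

# 2. Policy execution: shared trunk + the sampled task's head produce the action
obs = torch.randn(4, 6)
action = net(obs, task_id)

# 3. Gradient update: a placeholder loss backpropagates into the trunk and
#    the sampled head; the other head is not part of the graph
loss = action.pow(2).mean()
loss.backward()

other = 1 - task_id
trunk_has_grad = net.trunk[0].weight.grad is not None
used_head_has_grad = net.heads[task_id].weight.grad is not None
other_head_has_grad = net.heads[other].weight.grad is not None
```

Only the trunk and the head of the sampled task receive gradients here, which is the sense in which the trunk accumulates knowledge from every task while each head stays task-specific.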

Mathematical formulation:
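
One common way to write the joint objective is sketched below. The notation is an assumption chosen for illustration rather than taken from the source: $N$ tasks, shared parameters $\theta_s$, task-specific parameters $\theta_i$, per-task reward $r_i$, and discount factor $\gamma$.

```latex
\max_{\theta_s,\,\theta_1,\dots,\theta_N}\;
J(\theta_s,\theta_1,\dots,\theta_N)
  = \sum_{i=1}^{N}
    \mathbb{E}_{\tau \sim \pi_{\theta_s,\theta_i}}
    \left[ \sum_{t=0}^{T} \gamma^{t}\, r_i(s_t, a_t) \right]
```

Every term depends on $\theta_s$, so each task's gradient updates the shared parameters, while $\theta_i$ is updated only by task $i$.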


II. Multi-Task PPO Implementation (Based on Gymnasium)

We use the Meta-World multi-task robotic-arm environments as the example and implement multi-task reinforcement learning on top of PPO:

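  1. Define the task set: reach, push, pick-place, and similar tasks

  2. Build the shared policy network: shared fully connected layers plus task-specific fully connected heads

  3. Implement multi-task sampling: switch between tasks dynamically during training

  4. Joint gradient update: balance the losses across tasks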
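
For step 4, one simple way to balance the losses is to average within each task first and then across tasks, so that a task contributing more samples to a batch does not dominate the update. The helper name `balanced_multitask_loss` and the numbers below are illustrative assumptions, not part of the implementation that follows.

```python
import torch

def balanced_multitask_loss(per_sample_loss, task_ids):
    """Mean over tasks of the per-task mean loss (hypothetical helper)."""
    task_means = [
        per_sample_loss[task_ids == t].mean() for t in torch.unique(task_ids)
    ]
    return torch.stack(task_means).mean()

# Task 0 contributes three samples to the batch, task 1 only one
per_sample_loss = torch.tensor([1.0, 1.0, 1.0, 5.0])
task_ids = torch.tensor([0, 0, 0, 1])

plain = per_sample_loss.mean()                                 # (1 + 1 + 1 + 5) / 4
balanced = balanced_multitask_loss(per_sample_loss, task_ids)  # (1 + 5) / 2
```

The plain mean lets the over-represented task pull the loss toward its own value; the balanced version gives each task the same say regardless of how many samples it happened to contribute.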


III. Code Implementation

python
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
from torch.distributions import Normal
from torch.cuda.amp import autocast, GradScaler
from metaworld.envs import ALL_V2_ENVIRONMENTS_GOAL_OBSERVABLE
import time
from collections import deque

# ================== Configuration ==================
class MultiTaskPPOConfig:
    task_names = [
        'reach-v2-goal-observable',
        'push-v2-goal-observable',
        'pick-place-v2-goal-observable'
    ]
    num_tasks = len(task_names)  # stays in sync with task_names
    hidden_dim = 512
    task_specific_dim = 128
    lr = 3e-4
    gamma = 0.99
    gae_lambda = 0.95
    clip_epsilon = 0.2
    ppo_epochs = 4
    batch_size = 512
    max_episodes = 2000
    max_steps = 500
    grad_clip = 0.5
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# ================== Shared policy network ==================
class SharedPolicy(nn.Module):
    def __init__(self, state_dim, action_dim):
        super().__init__()
        self.action_dim = action_dim
        self.shared_net = nn.Sequential(
            nn.Linear(state_dim, MultiTaskPPOConfig.hidden_dim),
            nn.LayerNorm(MultiTaskPPOConfig.hidden_dim),
            nn.GELU(),
            nn.Linear(MultiTaskPPOConfig.hidden_dim, MultiTaskPPOConfig.hidden_dim),
            nn.GELU()
        )
        
        # Task-specific heads: one action head and one value head per task
        self.task_heads = nn.ModuleList([
            nn.Sequential(
                nn.Linear(MultiTaskPPOConfig.hidden_dim, MultiTaskPPOConfig.task_specific_dim),
                nn.GELU(),
                nn.Linear(MultiTaskPPOConfig.task_specific_dim, action_dim)
            ) for _ in range(MultiTaskPPOConfig.num_tasks)
        ])

        self.value_heads = nn.ModuleList([
            nn.Sequential(
                nn.Linear(MultiTaskPPOConfig.hidden_dim, MultiTaskPPOConfig.task_specific_dim),
                nn.GELU(),
                nn.Linear(MultiTaskPPOConfig.task_specific_dim, 1)
            ) for _ in range(MultiTaskPPOConfig.num_tasks)
        ])

    def forward(self, states, task_ids):
        shared_features = self.shared_net(states)
        batch_size = states.size(0)

        # Output buffers in the same dtype as the input (usually float32)
        action_means = torch.zeros(
            batch_size, self.action_dim,
            dtype=states.dtype,
            device=states.device
        )
        values = torch.zeros(
            batch_size, 1,
            dtype=states.dtype,
            device=states.device
        )

        # Route every sample through the heads of its own task
        for task_id_tensor in torch.unique(task_ids):
            task_id = int(task_id_tensor.item())
            mask = (task_ids == task_id_tensor)
            selected_features = shared_features[mask]

            # Cast head outputs back to the input dtype; under autocast
            # the heads may run in half precision
            task_action = self.task_heads[task_id](selected_features).to(dtype=states.dtype)
            task_value = self.value_heads[task_id](selected_features).to(dtype=states.dtype)

            action_means[mask] = task_action
            values[mask] = task_value

        return action_means, values

# ================== Training system ==================
class MultiTaskPPOTrainer:
    def __init__(self):
        # Create one environment per task
        self.envs = []
        self.state_dim = None
        self.action_dim = None

        # Check that all tasks share one observation size and record the dimensions
        for task_name in MultiTaskPPOConfig.task_names:
            env = ALL_V2_ENVIRONMENTS_GOAL_OBSERVABLE[task_name]()
            obs, _ = env.reset()

            if self.state_dim is None:
                self.state_dim = obs.shape[0]
                self.action_dim = env.action_space.shape[0]
            else:
                assert obs.shape[0] == self.state_dim, f"State dimension mismatch: {task_name}"

            self.envs.append(env)

        # Build the policy network and optimizer
        self.policy = SharedPolicy(self.state_dim, self.action_dim).to(MultiTaskPPOConfig.device)
        self.optimizer = optim.AdamW(self.policy.parameters(), lr=MultiTaskPPOConfig.lr)
        # Loss scaling is only needed for mixed precision on CUDA
        self.scaler = GradScaler(enabled=(MultiTaskPPOConfig.device.type == "cuda"))

        # Rollout buffer: keeps only the most recent max_steps transitions
        self.buffer = deque(maxlen=MultiTaskPPOConfig.max_steps)

    def collect_experience(self, num_steps):
        """Collect experience, sampling a task at random at every step."""
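        # Keep one running observation per environment so that episodes
        # continue across steps and calls instead of restarting every step
        if not hasattr(self, 'current_states'):
            self.current_states = [env.reset()[0] for env in self.envs]

        for _ in range(num_steps):
            task_id = int(np.random.randint(MultiTaskPPOConfig.num_tasks))
            env = self.envs[task_id]
            state = self.current_states[task_id]

            with torch.no_grad():
                state_tensor = torch.as_tensor(
                    state, dtype=torch.float32, device=MultiTaskPPOConfig.device
                ).unsqueeze(0)
                # Task id as a tensor, used to select the task head
                task_id_tensor = torch.tensor([task_id], dtype=torch.long, device=MultiTaskPPOConfig.device)
                action_mean, value = self.policy(state_tensor, task_id_tensor)
                dist = Normal(action_mean, torch.ones_like(action_mean))
                action_tensor = dist.sample()
                # Log-probability of the sampled action, not of the mean
                log_prob = dist.log_prob(action_tensor)

            action = action_tensor.squeeze(0).cpu().numpy()
            next_state, reward, done, trunc, _ = env.step(action)
            episode_end = bool(done or trunc)
            self.buffer.append({
                'state': state,
                'action': action,
                'log_prob': log_prob.cpu(),
                'reward': float(reward),
                'done': episode_end,  # termination or truncation ends the trajectory
                'task_id': task_id,
                'value': float(value.item())
            })

            # Advance this task's episode; reset only when it has ended
            self.current_states[task_id] = env.reset()[0] if episode_end else next_state

    def compute_gae(self, values, rewards, dones):
        """Compute Generalized Advantage Estimation (GAE)."""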
        advantages = []
        last_advantage = 0.0
        next_value = 0.0  # no bootstrap value beyond the last stored step

        # Backward recursion: A_t = delta_t + gamma * lambda * A_{t+1}
        for t in reversed(range(len(rewards))):
            not_done = 1.0 - dones[t]
            delta = rewards[t] + MultiTaskPPOConfig.gamma * next_value * not_done - values[t]
            last_advantage = delta + MultiTaskPPOConfig.gamma * MultiTaskPPOConfig.gae_lambda * not_done * last_advantage
            advantages.append(last_advantage)
            next_value = values[t]

        advantages = torch.tensor(
            np.asarray(advantages[::-1], dtype=np.float32), device=MultiTaskPPOConfig.device
        )
        # Returns use the raw advantages; only the returned advantages are normalized
        returns = advantages + torch.as_tensor(values, dtype=torch.float32, device=MultiTaskPPOConfig.device)
        return (advantages - advantages.mean()) / (advantages.std() + 1e-8), returns

    def update_policy(self):
        """Optimize the policy with PPO on the transitions currently in the buffer."""
        if not self.buffer:
            return 0, 0

        # Pull the rollout out of the buffer
        batch = list(self.buffer)
        device = MultiTaskPPOConfig.device
        states = torch.as_tensor(np.array([x['state'] for x in batch]), dtype=torch.float32, device=device)
        actions = torch.as_tensor(np.array([x['action'] for x in batch]), dtype=torch.float32, device=device)
        old_log_probs = torch.cat([x['log_prob'] for x in batch]).to(device)
        rewards = np.array([x['reward'] for x in batch], dtype=np.float32)
        dones = np.array([x['done'] for x in batch], dtype=np.float32)
        values = np.array([x['value'] for x in batch], dtype=np.float32)
        task_ids_np = np.array([x['task_id'] for x in batch])
        # Task ids must be long: they select the task heads
        task_ids = torch.as_tensor(task_ids_np, dtype=torch.long, device=device)

        # Compute GAE and returns per task: transitions of different tasks are
        # interleaved in the buffer, and the recursion is only valid along
        # one task's own trajectory
        advantages = torch.zeros(len(batch), device=device)
        returns = torch.zeros(len(batch), device=device)
        for task_id in np.unique(task_ids_np):
            sel = np.flatnonzero(task_ids_np == task_id)
            task_adv, task_ret = self.compute_gae(values[sel], rewards[sel], dones[sel])
            sel_t = torch.as_tensor(sel, device=device)
            advantages[sel_t] = task_adv
            returns[sel_t] = task_ret

        use_amp = device.type == "cuda"
        total_policy_loss = 0.0
        total_value_loss = 0.0
        num_updates = 0

        for _ in range(MultiTaskPPOConfig.ppo_epochs):
            # Shuffle the data every epoch
            perm = torch.randperm(len(batch), device=device)

            for i in range(0, len(batch), MultiTaskPPOConfig.batch_size):
                idx = perm[i:i + MultiTaskPPOConfig.batch_size]

                # Mixed-precision forward pass on the mini-batch only
                with torch.autocast(device_type=device.type, enabled=use_amp):
                    action_means, new_values = self.policy(states[idx], task_ids[idx])
                    dist = Normal(action_means, torch.ones_like(action_means))
                    # Joint log-probability: sum over action dimensions
                    new_log_probs = dist.log_prob(actions[idx]).sum(-1)
                    batch_old_log_probs = old_log_probs[idx].sum(-1)

                    # Importance sampling ratio
                    ratio = (new_log_probs - batch_old_log_probs).exp()

                    # Clipped policy loss
                    surr1 = ratio * advantages[idx]
                    surr2 = torch.clamp(ratio, 1 - MultiTaskPPOConfig.clip_epsilon,
                                        1 + MultiTaskPPOConfig.clip_epsilon) * advantages[idx]
                    policy_loss = -torch.min(surr1, surr2).mean()

                    # Value loss
                    value_loss = 0.5 * (new_values.squeeze(-1) - returns[idx]).pow(2).mean()

                    loss = policy_loss + value_loss

                # One optimizer step per mini-batch, with gradient clipping
                self.optimizer.zero_grad()
                self.scaler.scale(loss).backward()
                self.scaler.unscale_(self.optimizer)
                torch.nn.utils.clip_grad_norm_(self.policy.parameters(), MultiTaskPPOConfig.grad_clip)
                self.scaler.step(self.optimizer)
                self.scaler.update()

                total_policy_loss += policy_loss.item()
                total_value_loss += value_loss.item()
                num_updates += 1

        return total_policy_loss / num_updates, total_value_loss / num_updates

    def train(self):
        print(f"Starting training, device: {MultiTaskPPOConfig.device}")
        start_time = time.time()
        episode_rewards = {i: deque(maxlen=100) for i in range(MultiTaskPPOConfig.num_tasks)}

        for episode in range(MultiTaskPPOConfig.max_episodes):
            # Experience collection phase
            self.collect_experience(MultiTaskPPOConfig.max_steps)

            # Policy optimization phase
            policy_loss, value_loss = self.update_policy()

            # Record the summed reward of every task present in the buffer
            for task_id in {x['task_id'] for x in self.buffer}:
                task_reward = sum(x['reward'] for x in self.buffer if x['task_id'] == task_id)
                episode_rewards[task_id].append(task_reward)

            # Periodic logging
            if (episode + 1) % 100 == 0:
                avg_rewards = {k: np.mean(v) if v else 0 for k, v in episode_rewards.items()}
                time_cost = time.time() - start_time
                print(f"Episode {episode+1:5d} | Time: {time_cost:6.1f}s")
                for task_id in range(MultiTaskPPOConfig.num_tasks):
                    task_name = MultiTaskPPOConfig.task_names[task_id]
                    print(f"  {task_name:25s} | Avg Reward: {avg_rewards[task_id]:7.2f}")
                print(f"  Policy Loss: {policy_loss:.4f} | Value Loss: {value_loss:.4f}\n")
                start_time = time.time()

if __name__ == "__main__":
    trainer = MultiTaskPPOTrainer()
    print(f"State dim: {trainer.state_dim}, action dim: {trainer.action_dim}")
    trainer.train()
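
The GAE recursion in `compute_gae` can be checked by hand on a tiny trajectory. The standalone function below mirrors that recursion in plain Python, including the choice of no bootstrap value after the last step; the name `gae` and the three-step rewards and values are made-up inputs chosen so the arithmetic is easy to follow.

```python
def gae(rewards, values, dones, gamma=0.99, lam=0.95):
    """Plain-Python GAE; returns (advantages, returns) as lists."""
    advantages = [0.0] * len(rewards)
    last_adv, next_value = 0.0, 0.0
    for t in reversed(range(len(rewards))):
        not_done = 1.0 - dones[t]
        # TD error: r_t + gamma * V(s_{t+1}) - V(s_t)
        delta = rewards[t] + gamma * next_value * not_done - values[t]
        last_adv = delta + gamma * lam * not_done * last_adv
        advantages[t] = last_adv
        next_value = values[t]
    returns = [a + v for a, v in zip(advantages, values)]
    return advantages, returns

# With gamma = lam = 1, GAE reduces to "reward-to-go minus value estimate"
adv, ret = gae(
    rewards=[1.0, 1.0, 1.0],
    values=[0.5, 0.5, 0.5],
    dones=[0.0, 0.0, 1.0],
    gamma=1.0,
    lam=1.0,
)
print(adv)  # [2.5, 1.5, 0.5]
print(ret)  # [3.0, 2.0, 1.0]
```

The returns equal the remaining reward at each step (3, 2, 1), and each advantage is that quantity minus the value estimate of 0.5, which is what the recursion should produce in this limiting case.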

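IV. Key Code Walkthrough

  1. Shared policy network

    • SharedPolicy consists of shared network layers and task-specific heads

    • task_heads and value_heads produce the action means and value estimates for each task

  2. Multi-task sampling

    • A task is sampled at random at every environment step

    • The environment instance is switched dynamically via env = self.envs[task_id]

  3. Joint gradient update

    • Policy and value losses are computed over transitions from all tasks

    • task_id indexes the head parameters of the corresponding task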
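
The policy loss in `update_policy` follows the standard PPO clipped surrogate. As a small numeric illustration (the function name `clipped_surrogate` and the sample ratios are assumptions made for this example), the snippet below shows how clipping caps the objective once the probability ratio leaves the interval [1 - ε, 1 + ε]:

```python
def clipped_surrogate(ratio, advantage, eps=0.2):
    """PPO clipped objective for a single sample (a quantity to maximize)."""
    clipped_ratio = max(1.0 - eps, min(ratio, 1.0 + eps))
    return min(ratio * advantage, clipped_ratio * advantage)

# Positive advantage: the gain is capped once the ratio exceeds 1 + eps
inside = clipped_surrogate(ratio=1.1, advantage=2.0)   # 1.1 * 2.0
capped = clipped_surrogate(ratio=1.5, advantage=2.0)   # 1.2 * 2.0

# Negative advantage: the smaller (more pessimistic) value is kept,
# so a ratio far above 1 is still penalized in full
penalized = clipped_surrogate(ratio=1.5, advantage=-2.0)  # 1.5 * -2.0
```

The training code negates the mean of this quantity to obtain a loss, which is why `policy_loss` there is written as `-torch.min(surr1, surr2).mean()`.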


V. Sample Training Output

State dim: 39, action dim: 4
Starting training, device: cuda
/workspace/e23.py:184: UserWarning: Creating a tensor from a list of numpy.ndarrays is extremely slow. Please consider converting the list to a single numpy.ndarray with numpy.array() before converting to a tensor. (Triggered internally at ../torch/csrc/utils/tensor_new.cpp:278.)
  states = torch.tensor(
/workspace/e23.py:204: FutureWarning: `torch.cuda.amp.autocast(args...)` is deprecated. Please use `torch.amp.autocast('cuda', args...)` instead.
  with autocast():
Episode   100 | Time:  931.2s
  reach-v2-goal-observable  | Avg Reward:  226.83
  push-v2-goal-observable   | Avg Reward:    8.82
  pick-place-v2-goal-observable | Avg Reward:    3.31
  Policy Loss: 0.0386 | Value Loss: 13.2587

Episode   200 | Time:  935.3s
  reach-v2-goal-observable  | Avg Reward:  227.12
  push-v2-goal-observable   | Avg Reward:    8.83
  pick-place-v2-goal-observable | Avg Reward:    3.23
  Policy Loss: 0.0434 | Value Loss: 14.9413

Episode   300 | Time:  939.4s
  reach-v2-goal-observable  | Avg Reward:  226.78
  push-v2-goal-observable   | Avg Reward:    8.82
  pick-place-v2-goal-observable | Avg Reward:    3.23
  Policy Loss: 0.0429 | Value Loss: 13.9076

Episode   400 | Time:  938.4s
  reach-v2-goal-observable  | Avg Reward:  225.74
  push-v2-goal-observable   | Avg Reward:    8.84
  pick-place-v2-goal-observable | Avg Reward:    3.20
  Policy Loss: 0.0378 | Value Loss: 14.7157

Episode   500 | Time:  938.4s
  reach-v2-goal-observable  | Avg Reward:  225.45
  push-v2-goal-observable   | Avg Reward:    8.81
  pick-place-v2-goal-observable | Avg Reward:    3.20
  Policy Loss: 0.0381 | Value Loss: 11.7940

Episode   600 | Time:  928.5s
  reach-v2-goal-observable  | Avg Reward:  225.39
  push-v2-goal-observable   | Avg Reward:    8.75
  pick-place-v2-goal-observable | Avg Reward:    3.20
  Policy Loss: 0.0462 | Value Loss: 14.5566

Episode   700 | Time:  926.6s
  reach-v2-goal-observable  | Avg Reward:  226.37
  push-v2-goal-observable   | Avg Reward:    8.65
  pick-place-v2-goal-observable | Avg Reward:    3.23
  Policy Loss: 0.0394 | Value Loss: 15.5556

Episode   800 | Time:  943.8s
  reach-v2-goal-observable  | Avg Reward:  224.72
  push-v2-goal-observable   | Avg Reward:    8.64
  pick-place-v2-goal-observable | Avg Reward:    3.23
  Policy Loss: 0.0361 | Value Loss: 16.0126

Episode   900 | Time:  937.2s
  reach-v2-goal-observable  | Avg Reward:  224.15
  push-v2-goal-observable   | Avg Reward:    8.72
  pick-place-v2-goal-observable | Avg Reward:    3.21
  Policy Loss: 0.0417 | Value Loss: 14.1907

Episode  1000 | Time:  940.7s
  reach-v2-goal-observable  | Avg Reward:  223.77
  push-v2-goal-observable   | Avg Reward:    8.73
  pick-place-v2-goal-observable | Avg Reward:    3.19
  Policy Loss: 0.0399 | Value Loss: 16.0540

Episode  1100 | Time:  937.0s
  reach-v2-goal-observable  | Avg Reward:  224.73
  push-v2-goal-observable   | Avg Reward:    8.68
  pick-place-v2-goal-observable | Avg Reward:    3.17
  Policy Loss: 0.0409 | Value Loss: 15.5525

Episode  1200 | Time:  933.0s
  reach-v2-goal-observable  | Avg Reward:  224.73
  push-v2-goal-observable   | Avg Reward:    8.68
  pick-place-v2-goal-observable | Avg Reward:    3.17
  Policy Loss: 0.0388 | Value Loss: 17.4549

Episode  1300 | Time:  942.1s
  reach-v2-goal-observable  | Avg Reward:  224.35
  push-v2-goal-observable   | Avg Reward:    8.71
  pick-place-v2-goal-observable | Avg Reward:    3.19
  Policy Loss: 0.0447 | Value Loss: 14.6700

Episode  1400 | Time:  966.6s
  reach-v2-goal-observable  | Avg Reward:  224.27
  push-v2-goal-observable   | Avg Reward:    8.73
  pick-place-v2-goal-observable | Avg Reward:    3.19
  Policy Loss: 0.0434 | Value Loss: 13.3487

Episode  1500 | Time:  943.0s
  reach-v2-goal-observable  | Avg Reward:  223.03
  push-v2-goal-observable   | Avg Reward:    8.69
  pick-place-v2-goal-observable | Avg Reward:    3.21
  Policy Loss: 0.0438 | Value Loss: 14.7557

Episode  1600 | Time:  929.1s
  reach-v2-goal-observable  | Avg Reward:  224.01
  push-v2-goal-observable   | Avg Reward:    8.69
  pick-place-v2-goal-observable | Avg Reward:    3.21
  Policy Loss: 0.0365 | Value Loss: 12.2506

Episode  1700 | Time:  937.9s
  reach-v2-goal-observable  | Avg Reward:  222.88
  push-v2-goal-observable   | Avg Reward:    8.71
  pick-place-v2-goal-observable | Avg Reward:    3.21
  Policy Loss: 0.0365 | Value Loss: 11.8954

Episode  1800 | Time:  930.1s
  reach-v2-goal-observable  | Avg Reward:  224.42
  push-v2-goal-observable   | Avg Reward:    8.75
  pick-place-v2-goal-observable | Avg Reward:    3.18
  Policy Loss: 0.0437 | Value Loss: 13.6396

Episode  1900 | Time:  927.0s
  reach-v2-goal-observable  | Avg Reward:  224.66
  push-v2-goal-observable   | Avg Reward:    8.71
  pick-place-v2-goal-observable | Avg Reward:    3.18
  Policy Loss: 0.0360 | Value Loss: 14.3216

Episode  2000 | Time:  934.3s
  reach-v2-goal-observable  | Avg Reward:  224.73
  push-v2-goal-observable   | Avg Reward:    8.63
  pick-place-v2-goal-observable | Avg Reward:    3.18
  Policy Loss: 0.0475 | Value Loss: 14.0712

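VI. Summary and Extensions

This article implemented a core pattern of multi-task reinforcement learning: PPO with a shared policy, in which shared layers allow knowledge to carry over between tasks. Readers can try the following extensions: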

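  1. Dynamic task weights: adapt the loss weights to task difficulty:

    python
    # Add task weights inside update_policy()
    # calculate_task_difficulty() and losses are placeholders the reader must supply
    task_weights = calculate_task_difficulty()
    loss = sum(weight * loss_i for weight, loss_i in zip(task_weights, losses))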
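  2. Hierarchical reinforcement learning: introduce a high-level policy that schedules tasks:

    python
    class MetaController(nn.Module):
        def __init__(self, state_dim, num_tasks):
            super().__init__()
            self.net = nn.Sequential(
                nn.Linear(state_dim, 64),
                nn.ReLU(),
                nn.Linear(64, num_tasks)
            )

        def forward(self, state):
            # One logit per task; how a task is chosen from them is left open
            return self.net(state)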
  3. Curriculum learning: move gradually from simple tasks to complex ones:

    python
    def schedule_task(episode):
        # Pick the task from the current episode index
        if episode < 1000:
            return 'reach-v2-goal-observable'
        elif episode < 2000:
            return 'push-v2-goal-observable'
        else:
            return 'pick-place-v2-goal-observable'
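
As one possible concrete form of the dynamic task weights in item 1, the weights could be derived from recent per-task average rewards so that tasks with lower rewards receive more weight. The function `task_weights_from_returns`, the softmax form, and the `temperature` parameter are all assumptions made for this sketch; the source leaves `calculate_task_difficulty` unspecified.

```python
import math

def task_weights_from_returns(avg_returns, temperature=1.0):
    """Softmax over negated, scale-normalized returns: lower return, larger weight."""
    scale = max(abs(r) for r in avg_returns) or 1.0
    scores = [-(r / scale) / temperature for r in avg_returns]
    peak = max(scores)  # subtract the max for numerical stability
    exps = [math.exp(s - peak) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]

# Average rewards from the last block of the sample log (reach, push, pick-place)
weights = task_weights_from_returns([224.73, 8.63, 3.18])

# Placeholder per-task losses, combined into one weighted loss
per_task_losses = [0.5, 0.8, 1.2]
loss = sum(w * l for w, l in zip(weights, per_task_losses))
```

A lower `temperature` sharpens the weighting toward the hardest task, while a high one approaches uniform weights. Raw reward is only a rough proxy for difficulty when tasks use different reward scales, so a success-rate signal may be a better input where one is available.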

In the next article, we will explore hierarchical reinforcement learning (HRL) and implement the Option-Critic algorithm!


Notes

1. Install the dependencies:

bash
pip install metaworld gymnasium torch

2. Meta-World issues:

If the stable release causes problems, try installing the latest version from GitHub:

bash
pip install git+https://github.com/rlworkgroup/metaworld.git@master