PyTorch Deep Learning in Practice (23): Multi-Task Reinforcement Learning (Multi-Task RL)

I. Principles of Multi-Task Reinforcement Learning

1. Core Idea of Multi-Task Learning

Multi-task reinforcement learning (Multi-Task RL) trains a single agent on several tasks at the same time, sharing knowledge across them to improve learning efficiency and generalization. It differs from single-task reinforcement learning as follows:

Dimension          | Single-task RL                     | Multi-task RL
Objective          | Optimize the policy for one task   | Jointly optimize a shared policy for multiple tasks
Training           | Each task trained independently    | All tasks trained jointly
Knowledge transfer | None                               | Shared representations or parameters transfer knowledge across tasks
Use case           | Task-specific scenarios            | General-purpose agents in complex environments
2. A Multi-Task Framework Based on Shared Representations

Shared network layers learn what the tasks have in common, while task-specific layers handle the differences between tasks. The algorithm proceeds as follows:

  1. Task sampling: randomly select a task from the task distribution

  2. Policy execution: generate actions with the shared network

  3. Gradient update: jointly optimize the shared parameters and the task-specific parameters

Mathematical formulation: with shared parameters θ_sh and task-specific parameters {θ_i}, the joint objective over the task distribution p(T) is

$$\max_{\theta_{sh},\,\{\theta_i\}}\; J = \mathbb{E}_{\mathcal{T}_i \sim p(\mathcal{T})}\left[\, \mathbb{E}_{\tau \sim \pi_{\theta_{sh},\theta_i}}\left[\sum_{t=0}^{T} \gamma^{t}\, r^{(i)}_t\right]\right]$$

where γ is the discount factor and r_t^{(i)} is the reward of task T_i.
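As a minimal sketch of this shared-trunk plus task-specific-head structure (simplified for illustration; the layer sizes here are arbitrary, and the full implementation follows in Section III):

python
import torch
import torch.nn as nn

class TinyMultiTaskPolicy(nn.Module):
    """Illustrative only: one shared trunk plus one small head per task."""
    def __init__(self, state_dim=8, action_dim=2, num_tasks=3):
        super().__init__()
        self.trunk = nn.Sequential(nn.Linear(state_dim, 64), nn.ReLU())    # shared across all tasks
        self.heads = nn.ModuleList(
            [nn.Linear(64, action_dim) for _ in range(num_tasks)]          # one head per task
        )

    def forward(self, state, task_id):
        return self.heads[task_id](self.trunk(state))

policy = TinyMultiTaskPolicy()
print(policy(torch.randn(1, 8), task_id=1).shape)  # torch.Size([1, 2])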


II. Multi-Task PPO Implementation (Based on Gymnasium)

Taking the Meta-World multi-task robotic-arm environments as an example, we implement multi-task reinforcement learning with PPO (a quick check of the task registry follows this list):

  1. Define the task set: the reach, push, and pick-place tasks

  2. Build the shared policy network: a shared fully connected trunk + task-specific fully connected heads

  3. Implement multi-task sampling: dynamically switch tasks during training

  4. Joint gradient updates: balance the losses across tasks
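Before running the full script in Section III, a quick optional check against the same task registry that the training code imports confirms the three tasks are available in your Meta-World installation:

python
from metaworld.envs import ALL_V2_ENVIRONMENTS_GOAL_OBSERVABLE

# The registry maps task names to goal-observable environment constructors
for name in ['reach-v2-goal-observable', 'push-v2-goal-observable', 'pick-place-v2-goal-observable']:
    assert name in ALL_V2_ENVIRONMENTS_GOAL_OBSERVABLE, f"missing task: {name}"

print(f"{len(ALL_V2_ENVIRONMENTS_GOAL_OBSERVABLE)} goal-observable tasks registered")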


III. Code Implementation

python
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
from torch.distributions import Normal
from torch.amp import autocast, GradScaler  # torch >= 2.3; on older versions use torch.cuda.amp
from metaworld.envs import ALL_V2_ENVIRONMENTS_GOAL_OBSERVABLE
import time
from collections import deque

# ================== Configuration ==================
class MultiTaskPPOConfig:
    task_names = [
        'reach-v2-goal-observable',
        'push-v2-goal-observable',
        'pick-place-v2-goal-observable'
    ]
    num_tasks = 3
    hidden_dim = 512
    task_specific_dim = 128
    lr = 3e-4
    gamma = 0.99
    gae_lambda = 0.95
    clip_epsilon = 0.2
    ppo_epochs = 4
    batch_size = 512
    max_episodes = 2000
    max_steps = 500
    grad_clip = 0.5
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# ================== Shared Policy Network ==================
class SharedPolicy(nn.Module):
    def __init__(self, state_dim, action_dim):
        super().__init__()
        self.action_dim = action_dim
        self.shared_net = nn.Sequential(
            nn.Linear(state_dim, MultiTaskPPOConfig.hidden_dim),
            nn.LayerNorm(MultiTaskPPOConfig.hidden_dim),
            nn.GELU(),
            nn.Linear(MultiTaskPPOConfig.hidden_dim, MultiTaskPPOConfig.hidden_dim),
            nn.GELU()
        )
        
        # Task-specific policy heads
        self.task_heads = nn.ModuleList([
            nn.Sequential(
                nn.Linear(MultiTaskPPOConfig.hidden_dim, MultiTaskPPOConfig.task_specific_dim),
                nn.GELU(),
                nn.Linear(MultiTaskPPOConfig.task_specific_dim, action_dim)
            ) for _ in range(MultiTaskPPOConfig.num_tasks)
        ])
        
        self.value_heads = nn.ModuleList([
            nn.Sequential(
                nn.Linear(MultiTaskPPOConfig.hidden_dim, MultiTaskPPOConfig.task_specific_dim),
                nn.GELU(),
                nn.Linear(MultiTaskPPOConfig.task_specific_dim, 1)
            ) for _ in range(MultiTaskPPOConfig.num_tasks)
        ])

    def forward(self, states, task_ids):
        shared_features = self.shared_net(states)
        batch_size = states.size(0)
        
        # Allocate output tensors with the same dtype and device as the input
        action_means = torch.zeros(
            batch_size, self.action_dim,
            dtype=states.dtype,
            device=states.device
        )
        values = torch.zeros(
            batch_size, 1,
            dtype=states.dtype,
            device=states.device
        )
        
        unique_task_ids = torch.unique(task_ids)
        
        for task_id_tensor in unique_task_ids:
            task_id = task_id_tensor.item()
            mask = (task_ids == task_id_tensor)
            
            if not mask.any():
                continue
                
            selected_features = shared_features[mask]
            
            # Cast head outputs back to the input dtype (matters under mixed precision)
            task_action = self.task_heads[task_id](selected_features).to(dtype=states.dtype)
            task_value = self.value_heads[task_id](selected_features).to(dtype=states.dtype)
            
            action_means[mask] = task_action
            values[mask] = task_value
            
        return action_means, values

# ================== Training System ==================
class MultiTaskPPOTrainer:
    def __init__(self):
        # Create one environment instance per task
        self.envs = []
        self.state_dim = None
        self.action_dim = None

        # Validate each task and infer the state/action dimensions
        for task_name in MultiTaskPPOConfig.task_names:
            env = ALL_V2_ENVIRONMENTS_GOAL_OBSERVABLE[task_name]()
            obs, _ = env.reset()

            if self.state_dim is None:
                self.state_dim = obs.shape[0]
                self.action_dim = env.action_space.shape[0]
            else:
                assert obs.shape[0] == self.state_dim, f"Inconsistent state dimension: {task_name}"

            self.envs.append(env)
        
        # Policy network, optimizer, and AMP gradient scaler
        self.policy = SharedPolicy(self.state_dim, self.action_dim).to(MultiTaskPPOConfig.device)
        self.optimizer = optim.AdamW(self.policy.parameters(), lr=MultiTaskPPOConfig.lr)
        self.scaler = GradScaler('cuda')
        
        # On-policy rollout buffer: holds the transitions of one collection phase
        self.buffer = deque(maxlen=MultiTaskPPOConfig.max_steps)

        # Latest observation of each task, so rollouts can resume across steps
        self.last_obs = {}

    def collect_experience(self, num_steps):
        """Collect experience by sampling a task uniformly at random at every step"""
        for _ in range(num_steps):
            task_id = int(np.random.randint(MultiTaskPPOConfig.num_tasks))
            env = self.envs[task_id]

            # Resume from this task's latest observation, resetting it on first use
            if task_id not in self.last_obs:
                state, _ = env.reset()
            else:
                state = self.last_obs[task_id]

            with torch.no_grad():
                state_tensor = torch.FloatTensor(state).unsqueeze(0).to(MultiTaskPPOConfig.device)
                # task_id must be a long tensor so forward() can index the task heads
                task_id_tensor = torch.tensor([task_id], dtype=torch.long, device=MultiTaskPPOConfig.device)
                action_mean, value = self.policy(state_tensor, task_id_tensor)
                dist = Normal(action_mean, torch.ones_like(action_mean))
                action_tensor = dist.sample()
                # Log-probability of the sampled action, summed over action dimensions
                log_prob = dist.log_prob(action_tensor).sum(dim=-1)
                action = action_tensor.squeeze(0).cpu().numpy()

            next_state, reward, done, trunc, _ = env.step(action)
            self.buffer.append({
                'state': state,
                'action': action,
                'log_prob': log_prob.cpu(),
                'reward': float(reward),
                'done': bool(done),
                'task_id': task_id,
                'value': float(value.item())
            })

            # Remember where this task left off; reset when the episode ends
            self.last_obs[task_id] = next_state if not (done or trunc) else env.reset()[0]
    def compute_gae(self, values, rewards, dones):
        """计算广义优势估计(GAE)"""
        advantages = []
        last_advantage = 0
        next_value = 0
        
        for t in reversed(range(len(rewards))):
            delta = rewards[t] + MultiTaskPPOConfig.gamma * next_value * (1 - dones[t]) - values[t]
            last_advantage = delta + MultiTaskPPOConfig.gamma * MultiTaskPPOConfig.gae_lambda * (1 - dones[t]) * last_advantage
            advantages.append(last_advantage)
            next_value = values[t]
            
        advantages = torch.tensor(advantages[::-1], dtype=torch.float32).to(MultiTaskPPOConfig.device)
        returns = advantages + torch.tensor(values, dtype=torch.float32).to(MultiTaskPPOConfig.device)
        return (advantages - advantages.mean()) / (advantages.std() + 1e-8), returns

    def update_policy(self):
        """Optimize the shared policy with PPO on the collected rollout"""
        if not self.buffer:
            return 0.0, 0.0

        # Unpack the rollout buffer into tensors
        batch = list(self.buffer)
        states = torch.tensor(
            np.array([x['state'] for x in batch]),  # stack first to avoid slow list-of-ndarray conversion
            dtype=torch.float32,
            device=MultiTaskPPOConfig.device
        )
        actions = torch.FloatTensor(np.array([x['action'] for x in batch])).to(MultiTaskPPOConfig.device)
        old_log_probs = torch.cat([x['log_prob'] for x in batch]).to(MultiTaskPPOConfig.device)
        rewards = torch.FloatTensor([x['reward'] for x in batch]).to(MultiTaskPPOConfig.device)
        dones = torch.FloatTensor([x['done'] for x in batch]).to(MultiTaskPPOConfig.device)
        task_ids = torch.tensor(
            [x['task_id'] for x in batch],
            dtype=torch.long,  # must be long so forward() can index the task heads
            device=MultiTaskPPOConfig.device
        )
        values = torch.FloatTensor([x['value'] for x in batch]).to(MultiTaskPPOConfig.device)

        # Compute GAE advantages and returns
        advantages, returns = self.compute_gae(values.cpu().numpy(), rewards.cpu().numpy(), dones.cpu().numpy())

        total_policy_loss = 0.0
        total_value_loss = 0.0
        num_updates = 0

        for _ in range(MultiTaskPPOConfig.ppo_epochs):
            # Shuffle the rollout before splitting it into minibatches
            perm = torch.randperm(len(batch), device=MultiTaskPPOConfig.device)

            for i in range(0, len(batch), MultiTaskPPOConfig.batch_size):
                idx = perm[i:i + MultiTaskPPOConfig.batch_size]

                # Minibatch slices
                batch_states = states[idx]
                batch_actions = actions[idx]
                batch_old_log_probs = old_log_probs[idx]
                batch_returns = returns[idx]
                batch_advantages = advantages[idx]
                batch_task_ids = task_ids[idx]

                # Forward pass under automatic mixed precision
                with autocast('cuda'):
                    action_means, new_values = self.policy(batch_states, batch_task_ids)
                    dist = Normal(action_means, torch.ones_like(action_means))
                    new_log_probs = dist.log_prob(batch_actions).sum(dim=-1)

                    # Importance-sampling ratio
                    ratio = (new_log_probs - batch_old_log_probs).exp()

                    # Clipped surrogate policy loss
                    surr1 = ratio * batch_advantages
                    surr2 = torch.clamp(ratio, 1 - MultiTaskPPOConfig.clip_epsilon,
                                        1 + MultiTaskPPOConfig.clip_epsilon) * batch_advantages
                    policy_loss = -torch.min(surr1, surr2).mean()

                    # Value-function loss
                    value_loss = 0.5 * (new_values.squeeze(-1) - batch_returns).pow(2).mean()

                    # Total loss
                    loss = policy_loss + value_loss

                # Gradient clipping and one optimizer step per minibatch
                self.optimizer.zero_grad()
                self.scaler.scale(loss).backward()
                self.scaler.unscale_(self.optimizer)
                torch.nn.utils.clip_grad_norm_(self.policy.parameters(), MultiTaskPPOConfig.grad_clip)
                self.scaler.step(self.optimizer)
                self.scaler.update()

                total_policy_loss += policy_loss.item()
                total_value_loss += value_loss.item()
                num_updates += 1

        return total_policy_loss / max(num_updates, 1), total_value_loss / max(num_updates, 1)
    def train(self):
        print(f"Training started on device: {MultiTaskPPOConfig.device}")
        start_time = time.time()
        episode_rewards = {i: deque(maxlen=100) for i in range(MultiTaskPPOConfig.num_tasks)}

        for episode in range(MultiTaskPPOConfig.max_episodes):
            # Experience collection phase
            self.collect_experience(MultiTaskPPOConfig.max_steps)

            # Policy optimization phase
            policy_loss, value_loss = self.update_policy()

            # Log the reward each task accumulated during this rollout
            for task_id in range(MultiTaskPPOConfig.num_tasks):
                task_reward = sum(x['reward'] for x in self.buffer if x['task_id'] == task_id)
                episode_rewards[task_id].append(task_reward)

            # Periodic logging
            if (episode + 1) % 100 == 0:
                avg_rewards = {k: np.mean(v) if v else 0 for k, v in episode_rewards.items()}
                time_cost = time.time() - start_time
                print(f"Episode {episode+1:5d} | Time: {time_cost:6.1f}s")
                for task_id in range(MultiTaskPPOConfig.num_tasks):
                    task_name = MultiTaskPPOConfig.task_names[task_id]
                    print(f"  {task_name:25s} | Avg Reward: {avg_rewards[task_id]:7.2f}")
                print(f"  Policy Loss: {policy_loss:.4f} | Value Loss: {value_loss:.4f}\n")
                start_time = time.time()

if __name__ == "__main__":
    trainer = MultiTaskPPOTrainer()
    print(f"State dim: {trainer.state_dim}, Action dim: {trainer.action_dim}")
    trainer.train()

IV. Key Code Walkthrough

  1. Shared policy network

    • SharedPolicy combines a shared trunk with task-specific heads

    • task_heads and value_heads produce per-task action means and value estimates (a shape sanity check follows this list)

  2. Multi-task sampling

    • Every collection step samples a task at random for training

    • The environment instance is switched dynamically via env = self.envs[task_id]

  3. Joint gradient update

    • Policy and value losses are computed on minibatches that mix transitions from all tasks

    • Indexing by task_id routes each sample through its own head, so shared and task-specific parameters are optimized jointly
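As a quick sanity check of this routing (not part of the original script), the snippet below can be run in the same session as the code in Section III; the dimensions match those printed at the start of training (state_dim = 39, action_dim = 4):

python
import torch

# A mixed minibatch: 6 states belonging to tasks 0, 1 and 2
policy = SharedPolicy(state_dim=39, action_dim=4).to(MultiTaskPPOConfig.device)
states = torch.randn(6, 39, device=MultiTaskPPOConfig.device)
task_ids = torch.tensor([0, 0, 1, 2, 2, 1], dtype=torch.long, device=MultiTaskPPOConfig.device)

action_means, values = policy(states, task_ids)
print(action_means.shape, values.shape)  # torch.Size([6, 4]) torch.Size([6, 1])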


V. Example Training Output

text
State dim: 39, Action dim: 4
Training started on device: cuda
Episode   100 | Time:  931.2s
  reach-v2-goal-observable  | Avg Reward:  226.83
  push-v2-goal-observable   | Avg Reward:    8.82
  pick-place-v2-goal-observable | Avg Reward:    3.31
  Policy Loss: 0.0386 | Value Loss: 13.2587
​
Episode   200 | Time:  935.3s
  reach-v2-goal-observable  | Avg Reward:  227.12
  push-v2-goal-observable   | Avg Reward:    8.83
  pick-place-v2-goal-observable | Avg Reward:    3.23
  Policy Loss: 0.0434 | Value Loss: 14.9413
​
Episode   300 | Time:  939.4s
  reach-v2-goal-observable  | Avg Reward:  226.78
  push-v2-goal-observable   | Avg Reward:    8.82
  pick-place-v2-goal-observable | Avg Reward:    3.23
  Policy Loss: 0.0429 | Value Loss: 13.9076
​
Episode   400 | Time:  938.4s
  reach-v2-goal-observable  | Avg Reward:  225.74
  push-v2-goal-observable   | Avg Reward:    8.84
  pick-place-v2-goal-observable | Avg Reward:    3.20
  Policy Loss: 0.0378 | Value Loss: 14.7157
​
Episode   500 | Time:  938.4s
  reach-v2-goal-observable  | Avg Reward:  225.45
  push-v2-goal-observable   | Avg Reward:    8.81
  pick-place-v2-goal-observable | Avg Reward:    3.20
  Policy Loss: 0.0381 | Value Loss: 11.7940
​
Episode   600 | Time:  928.5s
  reach-v2-goal-observable  | Avg Reward:  225.39
  push-v2-goal-observable   | Avg Reward:    8.75
  pick-place-v2-goal-observable | Avg Reward:    3.20
  Policy Loss: 0.0462 | Value Loss: 14.5566
​
Episode   700 | Time:  926.6s
  reach-v2-goal-observable  | Avg Reward:  226.37
  push-v2-goal-observable   | Avg Reward:    8.65
  pick-place-v2-goal-observable | Avg Reward:    3.23
  Policy Loss: 0.0394 | Value Loss: 15.5556
​
Episode   800 | Time:  943.8s
  reach-v2-goal-observable  | Avg Reward:  224.72
  push-v2-goal-observable   | Avg Reward:    8.64
  pick-place-v2-goal-observable | Avg Reward:    3.23
  Policy Loss: 0.0361 | Value Loss: 16.0126
​
Episode   900 | Time:  937.2s
  reach-v2-goal-observable  | Avg Reward:  224.15
  push-v2-goal-observable   | Avg Reward:    8.72
  pick-place-v2-goal-observable | Avg Reward:    3.21
  Policy Loss: 0.0417 | Value Loss: 14.1907
​
Episode  1000 | Time:  940.7s
  reach-v2-goal-observable  | Avg Reward:  223.77
  push-v2-goal-observable   | Avg Reward:    8.73
  pick-place-v2-goal-observable | Avg Reward:    3.19
  Policy Loss: 0.0399 | Value Loss: 16.0540
​
Episode  1100 | Time:  937.0s
  reach-v2-goal-observable  | Avg Reward:  224.73
  push-v2-goal-observable   | Avg Reward:    8.68
  pick-place-v2-goal-observable | Avg Reward:    3.17
  Policy Loss: 0.0409 | Value Loss: 15.5525
​
Episode  1200 | Time:  933.0s
  reach-v2-goal-observable  | Avg Reward:  224.73
  push-v2-goal-observable   | Avg Reward:    8.68
  pick-place-v2-goal-observable | Avg Reward:    3.17
  Policy Loss: 0.0388 | Value Loss: 17.4549
​
Episode  1300 | Time:  942.1s
  reach-v2-goal-observable  | Avg Reward:  224.35
  push-v2-goal-observable   | Avg Reward:    8.71
  pick-place-v2-goal-observable | Avg Reward:    3.19
  Policy Loss: 0.0447 | Value Loss: 14.6700
​
Episode  1400 | Time:  966.6s
  reach-v2-goal-observable  | Avg Reward:  224.27
  push-v2-goal-observable   | Avg Reward:    8.73
  pick-place-v2-goal-observable | Avg Reward:    3.19
  Policy Loss: 0.0434 | Value Loss: 13.3487
​
Episode  1500 | Time:  943.0s
  reach-v2-goal-observable  | Avg Reward:  223.03
  push-v2-goal-observable   | Avg Reward:    8.69
  pick-place-v2-goal-observable | Avg Reward:    3.21
  Policy Loss: 0.0438 | Value Loss: 14.7557
​
Episode  1600 | Time:  929.1s
  reach-v2-goal-observable  | Avg Reward:  224.01
  push-v2-goal-observable   | Avg Reward:    8.69
  pick-place-v2-goal-observable | Avg Reward:    3.21
  Policy Loss: 0.0365 | Value Loss: 12.2506
​
Episode  1700 | Time:  937.9s
  reach-v2-goal-observable  | Avg Reward:  222.88
  push-v2-goal-observable   | Avg Reward:    8.71
  pick-place-v2-goal-observable | Avg Reward:    3.21
  Policy Loss: 0.0365 | Value Loss: 11.8954
​
Episode  1800 | Time:  930.1s
  reach-v2-goal-observable  | Avg Reward:  224.42
  push-v2-goal-observable   | Avg Reward:    8.75
  pick-place-v2-goal-observable | Avg Reward:    3.18
  Policy Loss: 0.0437 | Value Loss: 13.6396
​
Episode  1900 | Time:  927.0s
  reach-v2-goal-observable  | Avg Reward:  224.66
  push-v2-goal-observable   | Avg Reward:    8.71
  pick-place-v2-goal-observable | Avg Reward:    3.18
  Policy Loss: 0.0360 | Value Loss: 14.3216
​
Episode  2000 | Time:  934.3s
  reach-v2-goal-observable  | Avg Reward:  224.73
  push-v2-goal-observable   | Avg Reward:    8.63
  pick-place-v2-goal-observable | Avg Reward:    3.18
  Policy Loss: 0.0475 | Value Loss: 14.0712
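Across the 2000 episodes shown above, only the reach task attains a sizable average reward; push and pick-place stay nearly flat. This imbalance between easy and hard tasks is exactly what the extensions below (dynamic task weighting and curriculum learning) aim to address.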

VI. Summary and Extensions

This post implemented the core paradigm of multi-task reinforcement learning: a PPO agent with a shared policy whose common representation supports knowledge transfer across tasks. Readers can try the following extensions:

  1. Dynamic task weighting: adapt the loss weights to task difficulty (a concrete example of one possible scheme follows this list):

    python
    # Add task weights in update_policy(); calculate_task_difficulty() and losses are
    # placeholders for a per-task difficulty measure and a list of per-task losses
    task_weights = calculate_task_difficulty()
    loss = sum([weight * loss_i for weight, loss_i in zip(task_weights, losses)])
  2. Hierarchical reinforcement learning: introduce a high-level policy that schedules tasks:

    python
    class MetaController(nn.Module):
        def __init__(self, state_dim, num_tasks):
            super().__init__()
            self.net = nn.Sequential(
                nn.Linear(state_dim, 64),
                nn.ReLU(),
                nn.Linear(64, num_tasks)
            )

        def forward(self, state):
            # Logits over tasks; sample from them (or take the argmax) to pick the next task
            return self.net(state)
  3. Curriculum learning: move from easy tasks to harder ones over the course of training:

    python
    def schedule_task(episode):
        if episode < 1000:
            return 'reach-v2-goal-observable'
        elif episode < 2000:
            return 'push-v2-goal-observable'
        else:
            return 'pick-place-v2-goal-observable'
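For item 1, the snippet below is a minimal sketch of one possible weighting scheme (it is not from the original post): tasks with a lower recent average reward receive a larger loss weight. The function name calculate_task_weights and the temperature parameter are illustrative choices, and non-negative average rewards are assumed.

python
import numpy as np

def calculate_task_weights(avg_rewards, temperature=1.0):
    """Weight tasks inversely to their recent mean reward (assumed non-negative)."""
    scores = np.array([1.0 / (r + 1e-6) for r in avg_rewards]) ** (1.0 / temperature)
    return scores / scores.sum()  # normalized weights that sum to 1

# With the reward scale from the training log above, the hardest task
# (pick-place) receives the largest weight
print(calculate_task_weights([226.8, 8.8, 3.3]))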

In the next post, we will explore Hierarchical Reinforcement Learning (HRL) and implement the Option-Critic algorithm!


Notes

1. Install the dependencies:

bash
pip install metaworld gymnasium torch

2. Meta-World issues:

If the stable release has problems, try installing the latest version from GitHub:

bash
pip install git+https://github.com/rlworkgroup/metaworld.git@master
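After installation, an optional sanity check (a minimal sketch; it also reports whether a CUDA build of PyTorch is available) verifies that both packages import correctly:

python
import torch
import metaworld  # fails here if Meta-World was not installed correctly

print(torch.__version__, torch.cuda.is_available())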