PyTorch 深度学习实战(23):多任务强化学习(Multi-Task RL)

一、多任务强化学习原理

1. 多任务学习核心思想

多任务强化学习(Multi-Task RL)旨在让智能体同时学习多个任务 ,通过共享知识提升学习效率和泛化能力。与单任务强化学习的区别在于:

对比维度 单任务强化学习 多任务强化学习
目标 优化单一任务策略 同时优化多个任务的共享策略
训练方式 单任务独立训练 多任务联合训练
知识迁移 共享表示或参数实现跨任务知识迁移
应用场景 任务特定场景 复杂环境中的通用智能体
2. 基于共享表示的多任务框架

通过共享网络层 学习任务共性,任务特定层处理任务差异。算法流程如下:

  1. 任务采样:从任务分布中随机选择一个任务

  2. 策略执行:基于共享网络生成动作

  3. 梯度更新:联合优化共享参数和任务特定参数

数学表达:


二、多任务 PPO 算法实现(基于 Gymnasium)

我们将以 Meta-World 多任务机械臂环境 为例,实现基于 PPO 的多任务强化学习:

  1. 定义任务集合 :包含 reachpushpick-place 等任务

  2. 构建共享策略网络:共享卷积层 + 任务特定全连接层

  3. 实现多任务采样:动态切换任务训练

  4. 联合梯度更新:平衡多任务损失


三、代码实现

python 复制代码
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
from torch.distributions import Normal
from torch.cuda.amp import autocast, GradScaler
from metaworld.envs import ALL_V2_ENVIRONMENTS_GOAL_OBSERVABLE
import time
from collections import deque
​
# ================== 配置参数 ==================
class MultiTaskPPOConfig:
    task_names = [
        'reach-v2-goal-observable',
        'push-v2-goal-observable',
        'pick-place-v2-goal-observable'
    ]
    num_tasks = 3
    hidden_dim = 512
    task_specific_dim = 128
    lr = 3e-4
    gamma = 0.99
    gae_lambda = 0.95
    clip_epsilon = 0.2
    ppo_epochs = 4
    batch_size = 512
    max_episodes = 2000
    max_steps = 500
    grad_clip = 0.5
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
​
# ================== 共享策略网络 ==================
class SharedPolicy(nn.Module):
    def __init__(self, state_dim, action_dim):
        super().__init__()
        self.action_dim = action_dim
        self.shared_net = nn.Sequential(
            nn.Linear(state_dim, MultiTaskPPOConfig.hidden_dim),
            nn.LayerNorm(MultiTaskPPOConfig.hidden_dim),
            nn.GELU(),
            nn.Linear(MultiTaskPPOConfig.hidden_dim, MultiTaskPPOConfig.hidden_dim),
            nn.GELU()
        )
        
        # 多任务头部
        self.task_heads = nn.ModuleList([
            nn.Sequential(
                nn.Linear(MultiTaskPPOConfig.hidden_dim, MultiTaskPPOConfig.task_specific_dim),
                nn.GELU(),
                nn.Linear(MultiTaskPPOConfig.task_specific_dim, action_dim)
            ) for _ in range(MultiTaskPPOConfig.num_tasks)
        ])
        
        self.value_heads = nn.ModuleList([
            nn.Sequential(
                nn.Linear(MultiTaskPPOConfig.hidden_dim, MultiTaskPPOConfig.task_specific_dim),
                nn.GELU(),
                nn.Linear(MultiTaskPPOConfig.task_specific_dim, 1)
            ) for _ in range(MultiTaskPPOConfig.num_tasks)
        ])
​
    def forward(self, states, task_ids):
        shared_features = self.shared_net(states)
        batch_size = states.size(0)
        
        # 初始化与输入相同dtype的输出张量
        action_means = torch.zeros_like(
            states[:, :self.action_dim],  # 假设states维度足够
            dtype=states.dtype, 
            device=states.device
        )
        values = torch.zeros(
            batch_size, 1, 
            dtype=states.dtype, 
            device=states.device
        )
        
        unique_task_ids = torch.unique(task_ids)
        
        for task_id_tensor in unique_task_ids:
            task_id = task_id_tensor.item()
            mask = (task_ids == task_id_tensor)
            
            if not mask.any():
                continue
                
            selected_features = shared_features[mask]
            
            # 显式转换输出类型到states.dtype (通常是float32)
            task_action = self.task_heads[task_id](selected_features).to(dtype=states.dtype)
            task_value = self.value_heads[task_id](selected_features).to(dtype=states.dtype)
            
            action_means[mask] = task_action
            values[mask] = task_value
            
        return action_means, values
​
# ================== 训练系统 ==================
class MultiTaskPPOTrainer:
    def __init__(self):
        # 初始化多任务环境
        self.envs = []
        self.state_dim = None
        self.action_dim = None
        
        # 验证环境并获取维度
        for task_name in MultiTaskPPOConfig.task_names:
            env = ALL_V2_ENVIRONMENTS_GOAL_OBSERVABLE[task_name]()
            obs, _ = env.reset()
            
            if self.state_dim is None:
                self.state_dim = obs.shape[0]
                self.action_dim = env.action_space.shape[0]
            else:
                assert obs.shape[0] == self.state_dim, f"状态维度不一致: {task_name}"
                
            self.envs.append(env)
        
        # 初始化策略网络
        self.policy = SharedPolicy(self.state_dim, self.action_dim).to(MultiTaskPPOConfig.device)
        self.optimizer = optim.AdamW(self.policy.parameters(), lr=MultiTaskPPOConfig.lr)
        self.scaler = GradScaler()
        
        # 初始化经验回放缓冲
        self.buffer = deque(maxlen=MultiTaskPPOConfig.max_steps)
​
    def collect_experience(self, num_steps):
        """并行收集多任务经验"""
        for _ in range(num_steps):
            task_id = int(np.random.randint(MultiTaskPPOConfig.num_tasks))
            env = self.envs[task_id]
            
            if not hasattr(env, '_last_obs'):
                state, _ = env.reset()
            else:
                state = env._last_obs
                
            with torch.no_grad():
                state_tensor = torch.FloatTensor(state).unsqueeze(0).to(MultiTaskPPOConfig.device)
                # 将task_id转换为张量
                task_id_tensor = torch.tensor([task_id], dtype=torch.long, device=MultiTaskPPOConfig.device)
                action_mean, value = self.policy(state_tensor, task_id_tensor)
                dist = Normal(action_mean, torch.ones_like(action_mean))
                action = dist.sample().squeeze(0).cpu().numpy()
                log_prob = dist.log_prob(action_mean).detach()
            
            next_state, reward, done, trunc, _ = env.step(action)
            self.buffer.append({
                'state': state,
                'action': action,
                'log_prob': log_prob.cpu(),
                'reward': float(reward),
                'done': bool(done),
                'task_id': task_id,
                'value': float(value.item())
            })
            
            state = next_state if not (done or trunc) else env.reset()[0]
​
    def compute_gae(self, values, rewards, dones):
        """计算广义优势估计(GAE)"""
        advantages = []
        last_advantage = 0
        next_value = 0
        
        for t in reversed(range(len(rewards))):
            delta = rewards[t] + MultiTaskPPOConfig.gamma * next_value * (1 - dones[t]) - values[t]
            last_advantage = delta + MultiTaskPPOConfig.gamma * MultiTaskPPOConfig.gae_lambda * (1 - dones[t]) * last_advantage
            advantages.append(last_advantage)
            next_value = values[t]
            
        advantages = torch.tensor(advantages[::-1], dtype=torch.float32).to(MultiTaskPPOConfig.device)
        returns = advantages + torch.tensor(values, dtype=torch.float32).to(MultiTaskPPOConfig.device)
        return (advantages - advantages.mean()) / (advantages.std() + 1e-8), returns
​
    def update_policy(self):
        """策略更新阶段正确转换张量"""
        if not self.buffer:
            return 0, 0
        
        """使用PPO进行策略优化"""
        # 从缓冲中提取数据
        batch = list(self.buffer)
        states = torch.tensor(
            [x['state'] for x in batch],
            dtype=torch.float32,
            device=MultiTaskPPOConfig.device
        )
        actions = torch.FloatTensor(np.array([x['action'] for x in batch])).to(MultiTaskPPOConfig.device)
        old_log_probs = torch.cat([x['log_prob'] for x in batch]).to(MultiTaskPPOConfig.device)
        rewards = torch.FloatTensor([x['reward'] for x in batch]).to(MultiTaskPPOConfig.device)
        dones = torch.FloatTensor([x['done'] for x in batch]).to(MultiTaskPPOConfig.device)
        task_ids = torch.tensor(
            [x['task_id'] for x in batch],
            dtype=torch.long,  # 必须指定为long类型
            device=MultiTaskPPOConfig.device
        )
        values = torch.FloatTensor([x['value'] for x in batch]).to(MultiTaskPPOConfig.device)
​
        # 计算GAE和returns
        advantages, returns = self.compute_gae(values.cpu().numpy(), rewards.cpu().numpy(), dones.cpu().numpy())
​
        # 自动混合精度训练
        with autocast():
            total_policy_loss = 0
            total_value_loss = 0
            
            for _ in range(MultiTaskPPOConfig.ppo_epochs):
                # 随机打乱数据
                perm = torch.randperm(len(batch))
                
                for i in range(0, len(batch), MultiTaskPPOConfig.batch_size):
                    idx = perm[i:i+MultiTaskPPOConfig.batch_size]
                    
                    # 获取小批量数据
                    batch_states = states[idx]
                    batch_actions = actions[idx]
                    batch_old_log_probs = old_log_probs[idx]
                    batch_returns = returns[idx]
                    batch_advantages = advantages[idx]
                    batch_task_ids = task_ids[idx]
                    
                    # 前向传播
                    action_means, new_values = self.policy(states, task_ids)
                    dist = Normal(action_means, torch.ones_like(action_means))
                    new_log_probs = dist.log_prob(batch_actions)
                    
                    # 计算重要性采样比率
                    ratio = (new_log_probs - batch_old_log_probs).exp()
                    
                    # 策略损失
                    surr1 = ratio * batch_advantages.unsqueeze(-1)
                    surr2 = torch.clamp(ratio, 1-MultiTaskPPOConfig.clip_epsilon, 
                                      1+MultiTaskPPOConfig.clip_epsilon) * batch_advantages.unsqueeze(-1)
                    policy_loss = -torch.min(surr1, surr2).mean()
                    
                    # 值函数损失
                    value_loss = 0.5 * (new_values.squeeze() - batch_returns).pow(2).mean()
                    
                    # 总损失
                    loss = policy_loss + value_loss
                    
                    # 反向传播
                    self.scaler.scale(loss).backward()
                    total_policy_loss += policy_loss.item()
                    total_value_loss += value_loss.item()
​
            # 梯度裁剪和参数更新
            self.scaler.unscale_(self.optimizer)
            torch.nn.utils.clip_grad_norm_(self.policy.parameters(), MultiTaskPPOConfig.grad_clip)
            self.scaler.step(self.optimizer)
            self.scaler.update()
            self.optimizer.zero_grad()
​
        return total_policy_loss / MultiTaskPPOConfig.ppo_epochs, total_value_loss / MultiTaskPPOConfig.ppo_epochs
​
    def train(self):
        print(f"开始训练,设备:{MultiTaskPPOConfig.device}")
        start_time = time.time()
        episode_rewards = {i: deque(maxlen=100) for i in range(MultiTaskPPOConfig.num_tasks)}
        
        for episode in range(MultiTaskPPOConfig.max_episodes):
            # 经验收集阶段
            self.collect_experience(MultiTaskPPOConfig.max_steps)
            
            # 策略优化阶段
            policy_loss, value_loss = self.update_policy()
            
            # 记录统计信息
            task_id = np.random.randint(MultiTaskPPOConfig.num_tasks)
            episode_reward = sum(x['reward'] for x in self.buffer if x['task_id'] == task_id)
            episode_rewards[task_id].append(episode_reward)
            
            # 定期输出日志
            if (episode + 1) % 100 == 0:
                avg_rewards = {k: np.mean(v) if v else 0 for k, v in episode_rewards.items()}
                time_cost = time.time() - start_time
                print(f"Episode {episode+1:5d} | Time: {time_cost:6.1f}s")
                for task_id in range(MultiTaskPPOConfig.num_tasks):
                    task_name = MultiTaskPPOConfig.task_names[task_id]
                    print(f"  {task_name:25s} | Avg Reward: {avg_rewards[task_id]:7.2f}")
                print(f"  Policy Loss: {policy_loss:.4f} | Value Loss: {value_loss:.4f}\n")
                start_time = time.time()
​
if __name__ == "__main__":
    trainer = MultiTaskPPOTrainer()
    print(f"状态维度: {trainer.state_dim}, 动作维度: {trainer.action_dim}")
    trainer.train()

四、关键代码解析

  1. 共享策略网络

    • SharedPolicy 包含共享网络层和任务特定头部

    • task_headsvalue_heads 分别处理不同任务的动作和值函数

  2. 多任务采样机制

    • 每个回合随机选择一个任务进行训练

    • 动态切换环境实例 env = self.envs[task_id]

  3. 联合梯度更新

    • 计算多任务的策略损失和值函数损失

    • 通过 task_id 索引选择对应任务头部参数


五、训练输出示例

python 复制代码
状态维度: 39, 动作维度: 4
开始训练,设备:cuda
/workspace/e23.py:184: UserWarning: Creating a tensor from a list of numpy.ndarrays is extremely slow. Please consider converting the list to a single numpy.ndarray with numpy.array() before converting to a tensor. (Triggered internally at ../torch/csrc/utils/tensor_new.cpp:278.)
  states = torch.tensor(
/workspace/e23.py:204: FutureWarning: `torch.cuda.amp.autocast(args...)` is deprecated. Please use `torch.amp.autocast('cuda', args...)` instead.
  with autocast():
Episode   100 | Time:  931.2s
  reach-v2-goal-observable  | Avg Reward:  226.83
  push-v2-goal-observable   | Avg Reward:    8.82
  pick-place-v2-goal-observable | Avg Reward:    3.31
  Policy Loss: 0.0386 | Value Loss: 13.2587
​
Episode   200 | Time:  935.3s
  reach-v2-goal-observable  | Avg Reward:  227.12
  push-v2-goal-observable   | Avg Reward:    8.83
  pick-place-v2-goal-observable | Avg Reward:    3.23
  Policy Loss: 0.0434 | Value Loss: 14.9413
​
Episode   300 | Time:  939.4s
  reach-v2-goal-observable  | Avg Reward:  226.78
  push-v2-goal-observable   | Avg Reward:    8.82
  pick-place-v2-goal-observable | Avg Reward:    3.23
  Policy Loss: 0.0429 | Value Loss: 13.9076
​
Episode   400 | Time:  938.4s
  reach-v2-goal-observable  | Avg Reward:  225.74
  push-v2-goal-observable   | Avg Reward:    8.84
  pick-place-v2-goal-observable | Avg Reward:    3.20
  Policy Loss: 0.0378 | Value Loss: 14.7157
​
Episode   500 | Time:  938.4s
  reach-v2-goal-observable  | Avg Reward:  225.45
  push-v2-goal-observable   | Avg Reward:    8.81
  pick-place-v2-goal-observable | Avg Reward:    3.20
  Policy Loss: 0.0381 | Value Loss: 11.7940
​
Episode   600 | Time:  928.5s
  reach-v2-goal-observable  | Avg Reward:  225.39
  push-v2-goal-observable   | Avg Reward:    8.75
  pick-place-v2-goal-observable | Avg Reward:    3.20
  Policy Loss: 0.0462 | Value Loss: 14.5566
​
Episode   700 | Time:  926.6s
  reach-v2-goal-observable  | Avg Reward:  226.37
  push-v2-goal-observable   | Avg Reward:    8.65
  pick-place-v2-goal-observable | Avg Reward:    3.23
  Policy Loss: 0.0394 | Value Loss: 15.5556
​
Episode   800 | Time:  943.8s
  reach-v2-goal-observable  | Avg Reward:  224.72
  push-v2-goal-observable   | Avg Reward:    8.64
  pick-place-v2-goal-observable | Avg Reward:    3.23
  Policy Loss: 0.0361 | Value Loss: 16.0126
​
Episode   900 | Time:  937.2s
  reach-v2-goal-observable  | Avg Reward:  224.15
  push-v2-goal-observable   | Avg Reward:    8.72
  pick-place-v2-goal-observable | Avg Reward:    3.21
  Policy Loss: 0.0417 | Value Loss: 14.1907
​
Episode  1000 | Time:  940.7s
  reach-v2-goal-observable  | Avg Reward:  223.77
  push-v2-goal-observable   | Avg Reward:    8.73
  pick-place-v2-goal-observable | Avg Reward:    3.19
  Policy Loss: 0.0399 | Value Loss: 16.0540
​
Episode  1100 | Time:  937.0s
  reach-v2-goal-observable  | Avg Reward:  224.73
  push-v2-goal-observable   | Avg Reward:    8.68
  pick-place-v2-goal-observable | Avg Reward:    3.17
  Policy Loss: 0.0409 | Value Loss: 15.5525
​
Episode  1200 | Time:  933.0s
  reach-v2-goal-observable  | Avg Reward:  224.73
  push-v2-goal-observable   | Avg Reward:    8.68
  pick-place-v2-goal-observable | Avg Reward:    3.17
  Policy Loss: 0.0388 | Value Loss: 17.4549
​
Episode  1300 | Time:  942.1s
  reach-v2-goal-observable  | Avg Reward:  224.35
  push-v2-goal-observable   | Avg Reward:    8.71
  pick-place-v2-goal-observable | Avg Reward:    3.19
  Policy Loss: 0.0447 | Value Loss: 14.6700
​
Episode  1400 | Time:  966.6s
  reach-v2-goal-observable  | Avg Reward:  224.27
  push-v2-goal-observable   | Avg Reward:    8.73
  pick-place-v2-goal-observable | Avg Reward:    3.19
  Policy Loss: 0.0434 | Value Loss: 13.3487
​
Episode  1500 | Time:  943.0s
  reach-v2-goal-observable  | Avg Reward:  223.03
  push-v2-goal-observable   | Avg Reward:    8.69
  pick-place-v2-goal-observable | Avg Reward:    3.21
  Policy Loss: 0.0438 | Value Loss: 14.7557
​
Episode  1600 | Time:  929.1s
  reach-v2-goal-observable  | Avg Reward:  224.01
  push-v2-goal-observable   | Avg Reward:    8.69
  pick-place-v2-goal-observable | Avg Reward:    3.21
  Policy Loss: 0.0365 | Value Loss: 12.2506
​
Episode  1700 | Time:  937.9s
  reach-v2-goal-observable  | Avg Reward:  222.88
  push-v2-goal-observable   | Avg Reward:    8.71
  pick-place-v2-goal-observable | Avg Reward:    3.21
  Policy Loss: 0.0365 | Value Loss: 11.8954
​
Episode  1800 | Time:  930.1s
  reach-v2-goal-observable  | Avg Reward:  224.42
  push-v2-goal-observable   | Avg Reward:    8.75
  pick-place-v2-goal-observable | Avg Reward:    3.18
  Policy Loss: 0.0437 | Value Loss: 13.6396
​
Episode  1900 | Time:  927.0s
  reach-v2-goal-observable  | Avg Reward:  224.66
  push-v2-goal-observable   | Avg Reward:    8.71
  pick-place-v2-goal-observable | Avg Reward:    3.18
  Policy Loss: 0.0360 | Value Loss: 14.3216
​
Episode  2000 | Time:  934.3s
  reach-v2-goal-observable  | Avg Reward:  224.73
  push-v2-goal-observable   | Avg Reward:    8.63
  pick-place-v2-goal-observable | Avg Reward:    3.18
  Policy Loss: 0.0475 | Value Loss: 14.0712

六、总结与扩展

本文实现了多任务强化学习的核心范式------基于共享策略的 PPO 算法,展示了跨任务知识迁移的能力。读者可尝试以下扩展方向:

  1. 动态任务权重 根据任务难度自适应调整损失权重:

    python 复制代码
    # 在 update() 中添加任务权重
    task_weights = calculate_task_difficulty()
    loss = sum([weight * loss_i for weight, loss_i in zip(task_weights, losses)])
  2. 分层强化学习 引入高层策略调度任务:

    python 复制代码
    class MetaController(nn.Module):
        def __init__(self, num_tasks):
            super().__init__()
            self.net = nn.Sequential(
                nn.Linear(state_dim, 64),
                nn.ReLU(),
                nn.Linear(64, num_tasks)
            )
  3. 课程学习 从简单任务逐步过渡到复杂任务:

    python 复制代码
    def schedule_task():
        if episode < 1000:
            return 'reach-v2-goal-observable'
        elif episode < 2000:
            return 'push-v2-goal-observable'
        else:
            return 'pick-place-v2-goal-observable'

在下一篇文章中,我们将探索 分层强化学习(HRL),并实现 Option-Critic 算法!


注意事项

1.安装依赖:

bash 复制代码
pip install metaworld gymnasium torch

2.metaworld问题:

如果稳定版存在问题,尝试安装GitHub上的最新版:

bash 复制代码
pip install git+https://github.com/rlworkgroup/metaworld.git@master
相关推荐
腾讯云开发者23 分钟前
腾讯云TVP走进美的,共探智能制造新范式
人工智能
一水鉴天24 分钟前
整体设计 逻辑系统程序 之34七层网络的中台架构设计及链路对应讨论(含 CFR 规则与理 / 事代理界定)
人工智能·算法·公共逻辑
我星期八休息30 分钟前
C++智能指针全面解析:原理、使用场景与最佳实践
java·大数据·开发语言·jvm·c++·人工智能·python
ECT-OS-JiuHuaShan35 分钟前
《元推理框架技术白皮书》,人工智能领域的“杂交水稻“
人工智能·aigc·学习方法·量子计算·空间计算
minhuan39 分钟前
构建AI智能体:六十八、集成学习:从三个臭皮匠到AI集体智慧的深度解析
人工智能·机器学习·adaboost·集成学习·bagging
java1234_小锋1 小时前
TensorFlow2 Python深度学习 - 循环神经网络(SimpleRNN)示例
python·深度学习·tensorflow·tensorflow2
java1234_小锋1 小时前
TensorFlow2 Python深度学习 - 通俗理解池化层,卷积层以及全连接层
python·深度学习·tensorflow·tensorflow2
ssshooter1 小时前
MCP 服务 Streamable HTTP 和 SSE 的区别
人工智能·面试·程序员
rengang661 小时前
软件工程新纪元:AI协同编程架构师的修养与使命
人工智能·软件工程·ai编程·ai协同编程架构师
IT_陈寒1 小时前
Python+AI实战:用LangChain构建智能问答系统的5个核心技巧
前端·人工智能·后端