PyTorch 深度学习实战(15):Twin Delayed DDPG (TD3) 算法

在上一篇文章中,我们介绍了 Deep Deterministic Policy Gradient (DDPG) 算法,并使用它解决了 Pendulum 问题。本文将深入探讨 Twin Delayed DDPG (TD3) 算法,这是一种改进的 DDPG 算法,能够有效解决 DDPG 中的过估计问题。我们将使用 PyTorch 实现 TD3 算法,并应用于经典的 Pendulum 问题。


一、TD3 算法基础

TD3 是 DDPG 的改进版本,通过引入以下三个关键技术来解决 DDPG 中的过估计问题:

  1. 双重 Critic 网络

    使用两个 Critic 网络来估计 Q 值,从而减少过估计问题。

  2. 延迟更新

    延迟 Actor 网络的更新,确保 Critic 网络更稳定地收敛。

  3. 目标策略平滑

    在目标动作中加入噪声,从而减少 Critic 网络的过拟合。

1. TD3 的核心思想

  • 双重 Critic 网络

    • 使用两个 Critic 网络来估计 Q 值,取两者中的较小值作为目标 Q 值,从而减少过估计。
  • 延迟更新

    • 每更新 Critic 网络多次,才更新一次 Actor 网络,确保 Critic 网络更稳定地收敛。
  • 目标策略平滑

    • 在目标动作中加入噪声,从而减少 Critic 网络的过拟合。

2. TD3 的优势

  • 减少过估计

    • 通过双重 Critic 网络和目标策略平滑,TD3 能够有效减少 Q 值的过估计。
  • 训练稳定

    • 延迟更新策略确保 Critic 网络更稳定地收敛。
  • 适用于连续动作空间

    • TD3 能够直接输出连续动作,适用于机器人控制、自动驾驶等任务。

3. TD3 的算法流程

  1. 使用当前策略采样一批数据。

  2. 使用目标网络计算目标 Q 值。

  3. 更新 Critic 网络以最小化 Q 值的误差。

  4. 延迟更新 Actor 网络以最大化 Q 值。

  5. 更新目标网络。

  6. 重复上述过程,直到策略收敛。


二、Pendulum 问题实战

我们将使用 PyTorch 实现 TD3 算法,并应用于 Pendulum 问题。目标是控制摆杆使其保持直立。

1. 问题描述

Pendulum 环境的状态空间包括摆杆的角度和角速度。动作空间是一个连续的扭矩值,范围在 −2,2 之间。智能体每保持摆杆直立一步,就会获得一个负的奖励,目标是最大化累积奖励。

2. 实现步骤

  1. 安装并导入必要的库。

  2. 定义 Actor 网络和 Critic 网络。

  3. 定义 TD3 训练过程。

  4. 测试模型并评估性能。

3. 代码实现

以下是完整的代码实现:

python 复制代码
import gym
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import numpy as np
import random
from collections import deque
import matplotlib.pyplot as plt
​
plt.rcParams['font.sans-serif'] = ['SimHei']
plt.rcParams['axes.unicode_minus'] = False
​
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"使用设备: {device}")
​
env = gym.make('Pendulum-v1')
state_dim = env.observation_space.shape[0]
action_dim = env.action_space.shape[0]
max_action = float(env.action_space.high[0])
​
SEED = 42
torch.manual_seed(SEED)
np.random.seed(SEED)
random.seed(SEED)
​
​
# 改进的 Actor 网络(增加层归一化)
class Actor(nn.Module):
    def __init__(self, state_dim, action_dim, max_action):
        super(Actor, self).__init__()
        self.l1 = nn.Linear(state_dim, 256)
        self.ln1 = nn.LayerNorm(256)
        self.l2 = nn.Linear(256, 256)
        self.ln2 = nn.LayerNorm(256)
        self.l3 = nn.Linear(256, action_dim)
        self.max_action = max_action
​
    def forward(self, x):
        x = F.relu(self.ln1(self.l1(x)))
        x = F.relu(self.ln2(self.l2(x)))
        x = torch.tanh(self.l3(x)) * self.max_action
        return x
​
​
# 改进的 Critic 网络(增加层归一化)
class Critic(nn.Module):
    def __init__(self, state_dim, action_dim):
        super(Critic, self).__init__()
        self.l1 = nn.Linear(state_dim + action_dim, 256)
        self.ln1 = nn.LayerNorm(256)
        self.l2 = nn.Linear(256, 256)
        self.ln2 = nn.LayerNorm(256)
        self.l3 = nn.Linear(256, 1)
​
    def forward(self, x, u):
        x = F.relu(self.ln1(self.l1(torch.cat([x, u], 1))))
        x = F.relu(self.ln2(self.l2(x)))
        x = self.l3(x)
        return x
​
​
class TD3:
    def __init__(self, state_dim, action_dim, max_action):
        self.actor = Actor(state_dim, action_dim, max_action).to(device)
        self.actor_target = Actor(state_dim, action_dim, max_action).to(device)
        self.actor_target.load_state_dict(self.actor.state_dict())
        self.actor_optimizer = optim.Adam(self.actor.parameters(), lr=3e-4)
​
        self.critic1 = Critic(state_dim, action_dim).to(device)
        self.critic2 = Critic(state_dim, action_dim).to(device)
        self.critic1_target = Critic(state_dim, action_dim).to(device)
        self.critic2_target = Critic(state_dim, action_dim).to(device)
        self.critic1_target.load_state_dict(self.critic1.state_dict())
        self.critic2_target.load_state_dict(self.critic2.state_dict())
        self.critic1_optimizer = optim.Adam(self.critic1.parameters(), lr=3e-4)
        self.critic2_optimizer = optim.Adam(self.critic2.parameters(), lr=3e-4)
​
        self.max_action = max_action
        self.replay_buffer = deque(maxlen=1000000)
        self.batch_size = 256
        self.gamma = 0.99
        self.tau = 0.005
        self.policy_noise = 0.2
        self.noise_clip = 0.5
        self.policy_freq = 2
        self.total_it = 0
        self.exploration_noise = 0.1  # 新增探索噪声
​
    def select_action(self, state, add_noise=True):
        state = torch.FloatTensor(state.reshape(1, -1)).to(device)
        action = self.actor(state).cpu().data.numpy().flatten()
        if add_noise:
            noise = np.random.normal(0, self.exploration_noise, size=action_dim)
            action = (action + noise).clip(-self.max_action, self.max_action)
        return action
​
    def train(self):
        if len(self.replay_buffer) < self.batch_size:
            return
​
        self.total_it += 1
​
        batch = random.sample(self.replay_buffer, self.batch_size)
        state = torch.FloatTensor(np.array([t[0] for t in batch])).to(device)
        action = torch.FloatTensor(np.array([t[1] for t in batch])).to(device)
        reward = torch.FloatTensor(np.array([t[2] for t in batch])).reshape(-1, 1).to(device) / 10.0  # 奖励缩放
        next_state = torch.FloatTensor(np.array([t[3] for t in batch])).to(device)
        done = torch.FloatTensor(np.array([t[4] for t in batch])).reshape(-1, 1).to(device)
​
        with torch.no_grad():
            noise = (torch.randn_like(action) * self.policy_noise).clamp(-self.noise_clip, self.noise_clip)
            next_action = (self.actor_target(next_state) + noise).clamp(-self.max_action, self.max_action)
            target_Q1 = self.critic1_target(next_state, next_action)
            target_Q2 = self.critic2_target(next_state, next_action)
            target_Q = torch.min(target_Q1, target_Q2)
            target_Q = reward + (1 - done) * self.gamma * target_Q
​
        current_Q1 = self.critic1(state, action)
        current_Q2 = self.critic2(state, action)
        critic1_loss = F.mse_loss(current_Q1, target_Q)
        critic2_loss = F.mse_loss(current_Q2, target_Q)
        self.critic1_optimizer.zero_grad()
        critic1_loss.backward()
        torch.nn.utils.clip_grad_norm_(self.critic1.parameters(), 1.0)  # 梯度裁剪
        self.critic1_optimizer.step()
        self.critic2_optimizer.zero_grad()
        critic2_loss.backward()
        torch.nn.utils.clip_grad_norm_(self.critic2.parameters(), 1.0)
        self.critic2_optimizer.step()
​
        if self.total_it % self.policy_freq == 0:
            actor_loss = -self.critic1(state, self.actor(state)).mean()
            self.actor_optimizer.zero_grad()
            actor_loss.backward()
            torch.nn.utils.clip_grad_norm_(self.actor.parameters(), 1.0)
            self.actor_optimizer.step()
​
            for param, target_param in zip(self.critic1.parameters(), self.critic1_target.parameters()):
                target_param.data.copy_(self.tau * param.data + (1 - self.tau) * target_param.data)
            for param, target_param in zip(self.critic2.parameters(), self.critic2_target.parameters()):
                target_param.data.copy_(self.tau * param.data + (1 - self.tau) * target_param.data)
            for param, target_param in zip(self.actor.parameters(), self.actor_target.parameters()):
                target_param.data.copy_(self.tau * param.data + (1 - self.tau) * target_param.data)
​
    def save(self, filename):
        torch.save(self.actor.state_dict(), filename + "_actor.pth")
​
​
def train_td3(env, agent, episodes=2000, early_stop_threshold=-150):
    rewards_history = []
    moving_avg = []
    best_avg = -np.inf
​
    for ep in range(episodes):
        state,_ = env.reset()
        episode_reward = 0
        done = False
        step = 0
​
        while not done:
            # 线性衰减探索噪声
            if ep < 300:
                agent.exploration_noise = max(0.5 * (1 - ep / 300), 0.1)
            else:
                agent.exploration_noise = 0.1
​
            action = agent.select_action(state, add_noise=(ep < 100))  # 前100轮强制探索
            next_state, reward, done, _, _ = env.step(action)
            agent.replay_buffer.append((state, action, reward, next_state, done))
            state = next_state
            episode_reward += reward
            agent.train()
            step += 1
​
        rewards_history.append(episode_reward)
        current_avg = np.mean(rewards_history[-50:])
        moving_avg.append(current_avg)
​
        if current_avg > best_avg:
            best_avg = current_avg
            agent.save("td3_pendulum_best")
​
        if (ep + 1) % 50 == 0:
            print(f"Episode: {ep + 1}, Avg Reward: {current_avg:.2f}")
​
        # 早停机制
        if current_avg >= early_stop_threshold:
            print(f"早停触发,平均奖励达到 {current_avg:.2f}")
            break
​
    return moving_avg, rewards_history
​
​
# 训练并可视化
td3_agent = TD3(state_dim, action_dim, max_action)
moving_avg, rewards_history = train_td3(env, td3_agent, episodes=2000)
​
# 可视化结果
plt.figure(figsize=(12, 6))
plt.plot(rewards_history, alpha=0.6, label='single round reward')
plt.plot(moving_avg, 'r-', linewidth=2, label='moving average (50 rounds)')
plt.xlabel('episodes')
plt.ylabel('reward')
plt.title('TD3 training performance on Pendulum-v1')
plt.legend()
plt.grid(True)
plt.show()

三、代码解析

  1. Actor 和 Critic 网络

    • Actor 网络输出连续动作,通过 tanh 函数将动作限制在 −max_action,max_action 范围内。

    • Critic 网络输出状态-动作对的 Q 值。

  2. TD3 训练过程

    • 使用当前策略采样一批数据。

    • 使用目标网络计算目标 Q 值。

    • 更新 Critic 网络以最小化 Q 值的误差。

    • 延迟更新 Actor 网络以最大化 Q 值。

    • 更新目标网络。

  3. 训练过程

    • 在训练过程中,每 50 个 episode 打印一次平均奖励。

    • 训练结束后,绘制训练过程中的总奖励曲线。


四、运行结果

运行上述代码后,你将看到以下输出:

  • 训练过程中每 50 个 episode 打印一次平均奖励。

  • 训练结束后,绘制训练过程中的总奖励曲线。


五、总结

本文介绍了 TD3 算法的基本原理,并使用 PyTorch 实现了一个简单的 TD3 模型来解决 Pendulum 问题。通过这个例子,我们学习了如何使用 TD3 算法进行连续动作空间的策略优化。

在下一篇文章中,我们将探讨更高级的强化学习算法,如 Soft Actor-Critic (SAC)。敬请期待!

代码实例说明

  • 本文代码可以直接在 Jupyter Notebook 或 Python 脚本中运行。

  • 如果你有 GPU,代码会自动检测并使用 GPU 加速。

希望这篇文章能帮助你更好地理解 TD3 算法!如果有任何问题,欢迎在评论区留言讨论。

相关推荐
胡耀超38 分钟前
大模型架构演进全景:从Transformer到下一代智能系统的技术路径(MoE、Mamba/SSM、混合架构)
人工智能·深度学习·ai·架构·大模型·transformer·技术趋势分析
Luchang-Li2 小时前
sglang pytorch NCCL hang分析
pytorch·python·nccl
金融小师妹2 小时前
基于哈塞特独立性表态的AI量化研究:美联储政策独立性的多维验证
大数据·人工智能·算法
纪元A梦5 小时前
贪心算法应用:化工反应器调度问题详解
算法·贪心算法
深圳市快瞳科技有限公司6 小时前
小场景大市场:猫狗识别算法在宠物智能设备中的应用
算法·计算机视觉·宠物
liulilittle6 小时前
OPENPPP2 —— IP标准校验和算法深度剖析:从原理到SSE2优化实现
网络·c++·网络协议·tcp/ip·算法·ip·通信
superlls8 小时前
(算法 哈希表)【LeetCode 349】两个数组的交集 思路笔记自留
java·数据结构·算法
Gyoku Mint9 小时前
提示词工程(Prompt Engineering)的崛起——为什么“会写Prompt”成了新技能?
人工智能·pytorch·深度学习·神经网络·语言模型·自然语言处理·nlp
田里的水稻9 小时前
C++_队列编码实例,从末端添加对象,同时把头部的对象剔除掉,中的队列长度为设置长度NUM_OBJ
java·c++·算法
纪元A梦9 小时前
贪心算法应用:保险理赔调度问题详解
算法·贪心算法