PyTorch 深度学习实战(15):Twin Delayed DDPG (TD3) 算法

在上一篇文章中,我们介绍了 Deep Deterministic Policy Gradient (DDPG) 算法,并使用它解决了 Pendulum 问题。本文将深入探讨 Twin Delayed DDPG (TD3) 算法,这是一种改进的 DDPG 算法,能够有效解决 DDPG 中的过估计问题。我们将使用 PyTorch 实现 TD3 算法,并应用于经典的 Pendulum 问题。


一、TD3 算法基础

TD3 是 DDPG 的改进版本,通过引入以下三个关键技术来解决 DDPG 中的过估计问题:

  1. 双重 Critic 网络

    使用两个 Critic 网络来估计 Q 值,从而减少过估计问题。

  2. 延迟更新

    延迟 Actor 网络的更新,确保 Critic 网络更稳定地收敛。

  3. 目标策略平滑

    在目标动作中加入噪声,从而减少 Critic 网络的过拟合。

1. TD3 的核心思想

  • 双重 Critic 网络

    • 使用两个 Critic 网络来估计 Q 值,取两者中的较小值作为目标 Q 值,从而减少过估计。
  • 延迟更新

    • 每更新 Critic 网络多次,才更新一次 Actor 网络,确保 Critic 网络更稳定地收敛。
  • 目标策略平滑

    • 在目标动作中加入噪声,从而减少 Critic 网络的过拟合。

2. TD3 的优势

  • 减少过估计

    • 通过双重 Critic 网络和目标策略平滑,TD3 能够有效减少 Q 值的过估计。
  • 训练稳定

    • 延迟更新策略确保 Critic 网络更稳定地收敛。
  • 适用于连续动作空间

    • TD3 能够直接输出连续动作,适用于机器人控制、自动驾驶等任务。

3. TD3 的算法流程

  1. 使用当前策略采样一批数据。

  2. 使用目标网络计算目标 Q 值。

  3. 更新 Critic 网络以最小化 Q 值的误差。

  4. 延迟更新 Actor 网络以最大化 Q 值。

  5. 更新目标网络。

  6. 重复上述过程,直到策略收敛。


二、Pendulum 问题实战

我们将使用 PyTorch 实现 TD3 算法,并应用于 Pendulum 问题。目标是控制摆杆使其保持直立。

1. 问题描述

Pendulum 环境的状态空间包括摆杆的角度和角速度。动作空间是一个连续的扭矩值,范围在 −2,2 之间。智能体每保持摆杆直立一步,就会获得一个负的奖励,目标是最大化累积奖励。

2. 实现步骤

  1. 安装并导入必要的库。

  2. 定义 Actor 网络和 Critic 网络。

  3. 定义 TD3 训练过程。

  4. 测试模型并评估性能。

3. 代码实现

以下是完整的代码实现:

python 复制代码
import gym
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import numpy as np
import random
from collections import deque
import matplotlib.pyplot as plt
​
plt.rcParams['font.sans-serif'] = ['SimHei']
plt.rcParams['axes.unicode_minus'] = False
​
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"使用设备: {device}")
​
env = gym.make('Pendulum-v1')
state_dim = env.observation_space.shape[0]
action_dim = env.action_space.shape[0]
max_action = float(env.action_space.high[0])
​
SEED = 42
torch.manual_seed(SEED)
np.random.seed(SEED)
random.seed(SEED)
​
​
# 改进的 Actor 网络(增加层归一化)
class Actor(nn.Module):
    def __init__(self, state_dim, action_dim, max_action):
        super(Actor, self).__init__()
        self.l1 = nn.Linear(state_dim, 256)
        self.ln1 = nn.LayerNorm(256)
        self.l2 = nn.Linear(256, 256)
        self.ln2 = nn.LayerNorm(256)
        self.l3 = nn.Linear(256, action_dim)
        self.max_action = max_action
​
    def forward(self, x):
        x = F.relu(self.ln1(self.l1(x)))
        x = F.relu(self.ln2(self.l2(x)))
        x = torch.tanh(self.l3(x)) * self.max_action
        return x
​
​
# 改进的 Critic 网络(增加层归一化)
class Critic(nn.Module):
    def __init__(self, state_dim, action_dim):
        super(Critic, self).__init__()
        self.l1 = nn.Linear(state_dim + action_dim, 256)
        self.ln1 = nn.LayerNorm(256)
        self.l2 = nn.Linear(256, 256)
        self.ln2 = nn.LayerNorm(256)
        self.l3 = nn.Linear(256, 1)
​
    def forward(self, x, u):
        x = F.relu(self.ln1(self.l1(torch.cat([x, u], 1))))
        x = F.relu(self.ln2(self.l2(x)))
        x = self.l3(x)
        return x
​
​
class TD3:
    def __init__(self, state_dim, action_dim, max_action):
        self.actor = Actor(state_dim, action_dim, max_action).to(device)
        self.actor_target = Actor(state_dim, action_dim, max_action).to(device)
        self.actor_target.load_state_dict(self.actor.state_dict())
        self.actor_optimizer = optim.Adam(self.actor.parameters(), lr=3e-4)
​
        self.critic1 = Critic(state_dim, action_dim).to(device)
        self.critic2 = Critic(state_dim, action_dim).to(device)
        self.critic1_target = Critic(state_dim, action_dim).to(device)
        self.critic2_target = Critic(state_dim, action_dim).to(device)
        self.critic1_target.load_state_dict(self.critic1.state_dict())
        self.critic2_target.load_state_dict(self.critic2.state_dict())
        self.critic1_optimizer = optim.Adam(self.critic1.parameters(), lr=3e-4)
        self.critic2_optimizer = optim.Adam(self.critic2.parameters(), lr=3e-4)
​
        self.max_action = max_action
        self.replay_buffer = deque(maxlen=1000000)
        self.batch_size = 256
        self.gamma = 0.99
        self.tau = 0.005
        self.policy_noise = 0.2
        self.noise_clip = 0.5
        self.policy_freq = 2
        self.total_it = 0
        self.exploration_noise = 0.1  # 新增探索噪声
​
    def select_action(self, state, add_noise=True):
        state = torch.FloatTensor(state.reshape(1, -1)).to(device)
        action = self.actor(state).cpu().data.numpy().flatten()
        if add_noise:
            noise = np.random.normal(0, self.exploration_noise, size=action_dim)
            action = (action + noise).clip(-self.max_action, self.max_action)
        return action
​
    def train(self):
        if len(self.replay_buffer) < self.batch_size:
            return
​
        self.total_it += 1
​
        batch = random.sample(self.replay_buffer, self.batch_size)
        state = torch.FloatTensor(np.array([t[0] for t in batch])).to(device)
        action = torch.FloatTensor(np.array([t[1] for t in batch])).to(device)
        reward = torch.FloatTensor(np.array([t[2] for t in batch])).reshape(-1, 1).to(device) / 10.0  # 奖励缩放
        next_state = torch.FloatTensor(np.array([t[3] for t in batch])).to(device)
        done = torch.FloatTensor(np.array([t[4] for t in batch])).reshape(-1, 1).to(device)
​
        with torch.no_grad():
            noise = (torch.randn_like(action) * self.policy_noise).clamp(-self.noise_clip, self.noise_clip)
            next_action = (self.actor_target(next_state) + noise).clamp(-self.max_action, self.max_action)
            target_Q1 = self.critic1_target(next_state, next_action)
            target_Q2 = self.critic2_target(next_state, next_action)
            target_Q = torch.min(target_Q1, target_Q2)
            target_Q = reward + (1 - done) * self.gamma * target_Q
​
        current_Q1 = self.critic1(state, action)
        current_Q2 = self.critic2(state, action)
        critic1_loss = F.mse_loss(current_Q1, target_Q)
        critic2_loss = F.mse_loss(current_Q2, target_Q)
        self.critic1_optimizer.zero_grad()
        critic1_loss.backward()
        torch.nn.utils.clip_grad_norm_(self.critic1.parameters(), 1.0)  # 梯度裁剪
        self.critic1_optimizer.step()
        self.critic2_optimizer.zero_grad()
        critic2_loss.backward()
        torch.nn.utils.clip_grad_norm_(self.critic2.parameters(), 1.0)
        self.critic2_optimizer.step()
​
        if self.total_it % self.policy_freq == 0:
            actor_loss = -self.critic1(state, self.actor(state)).mean()
            self.actor_optimizer.zero_grad()
            actor_loss.backward()
            torch.nn.utils.clip_grad_norm_(self.actor.parameters(), 1.0)
            self.actor_optimizer.step()
​
            for param, target_param in zip(self.critic1.parameters(), self.critic1_target.parameters()):
                target_param.data.copy_(self.tau * param.data + (1 - self.tau) * target_param.data)
            for param, target_param in zip(self.critic2.parameters(), self.critic2_target.parameters()):
                target_param.data.copy_(self.tau * param.data + (1 - self.tau) * target_param.data)
            for param, target_param in zip(self.actor.parameters(), self.actor_target.parameters()):
                target_param.data.copy_(self.tau * param.data + (1 - self.tau) * target_param.data)
​
    def save(self, filename):
        torch.save(self.actor.state_dict(), filename + "_actor.pth")
​
​
def train_td3(env, agent, episodes=2000, early_stop_threshold=-150):
    rewards_history = []
    moving_avg = []
    best_avg = -np.inf
​
    for ep in range(episodes):
        state,_ = env.reset()
        episode_reward = 0
        done = False
        step = 0
​
        while not done:
            # 线性衰减探索噪声
            if ep < 300:
                agent.exploration_noise = max(0.5 * (1 - ep / 300), 0.1)
            else:
                agent.exploration_noise = 0.1
​
            action = agent.select_action(state, add_noise=(ep < 100))  # 前100轮强制探索
            next_state, reward, done, _, _ = env.step(action)
            agent.replay_buffer.append((state, action, reward, next_state, done))
            state = next_state
            episode_reward += reward
            agent.train()
            step += 1
​
        rewards_history.append(episode_reward)
        current_avg = np.mean(rewards_history[-50:])
        moving_avg.append(current_avg)
​
        if current_avg > best_avg:
            best_avg = current_avg
            agent.save("td3_pendulum_best")
​
        if (ep + 1) % 50 == 0:
            print(f"Episode: {ep + 1}, Avg Reward: {current_avg:.2f}")
​
        # 早停机制
        if current_avg >= early_stop_threshold:
            print(f"早停触发,平均奖励达到 {current_avg:.2f}")
            break
​
    return moving_avg, rewards_history
​
​
# 训练并可视化
td3_agent = TD3(state_dim, action_dim, max_action)
moving_avg, rewards_history = train_td3(env, td3_agent, episodes=2000)
​
# 可视化结果
plt.figure(figsize=(12, 6))
plt.plot(rewards_history, alpha=0.6, label='single round reward')
plt.plot(moving_avg, 'r-', linewidth=2, label='moving average (50 rounds)')
plt.xlabel('episodes')
plt.ylabel('reward')
plt.title('TD3 training performance on Pendulum-v1')
plt.legend()
plt.grid(True)
plt.show()

三、代码解析

  1. Actor 和 Critic 网络

    • Actor 网络输出连续动作,通过 tanh 函数将动作限制在 −max_action,max_action 范围内。

    • Critic 网络输出状态-动作对的 Q 值。

  2. TD3 训练过程

    • 使用当前策略采样一批数据。

    • 使用目标网络计算目标 Q 值。

    • 更新 Critic 网络以最小化 Q 值的误差。

    • 延迟更新 Actor 网络以最大化 Q 值。

    • 更新目标网络。

  3. 训练过程

    • 在训练过程中,每 50 个 episode 打印一次平均奖励。

    • 训练结束后,绘制训练过程中的总奖励曲线。


四、运行结果

运行上述代码后,你将看到以下输出:

  • 训练过程中每 50 个 episode 打印一次平均奖励。

  • 训练结束后,绘制训练过程中的总奖励曲线。


五、总结

本文介绍了 TD3 算法的基本原理,并使用 PyTorch 实现了一个简单的 TD3 模型来解决 Pendulum 问题。通过这个例子,我们学习了如何使用 TD3 算法进行连续动作空间的策略优化。

在下一篇文章中,我们将探讨更高级的强化学习算法,如 Soft Actor-Critic (SAC)。敬请期待!

代码实例说明

  • 本文代码可以直接在 Jupyter Notebook 或 Python 脚本中运行。

  • 如果你有 GPU,代码会自动检测并使用 GPU 加速。

希望这篇文章能帮助你更好地理解 TD3 算法!如果有任何问题,欢迎在评论区留言讨论。

相关推荐
仰泳的熊猫4 分钟前
1112 Stucked Keyboard
数据结构·c++·算法·pat考试
roman_日积跬步-终至千里9 分钟前
【计算机算法与设计(14)】例题五:最小生成树:Prim算法详细解释:π的含义、更新逻辑和选点原因
算法
让学习成为一种生活方式10 分钟前
压缩文件夹下下所有文件成压缩包tar.gz--随笔016
算法
嗷嗷哦润橘_15 分钟前
AI Agent学习:MetaGPT项目之RAG
人工智能·python·学习·算法·deepseek
不忘不弃33 分钟前
指针元素的使用
算法
江上鹤.14834 分钟前
Day37 MLP神经网络的训练
人工智能·深度学习·神经网络
he___H35 分钟前
滑动窗口一题
java·数据结构·算法·滑动窗口
AI科技星36 分钟前
统一场论质量定义方程:数学验证与应用分析
开发语言·数据结构·经验分享·线性代数·算法
ULTRA??38 分钟前
KD-Tree的查询原理
python·算法
java1234_小锋1 小时前
Transformer 大语言模型(LLM)基石 - Transformer架构详解 - 残差连接(Residual Connection)详解以及算法实现
深度学习·语言模型·transformer