Learning Notes on the PPO and DDPG Reinforcement Learning Algorithms

PPO Algorithm in Detail

Core Ideas

  • PPO directly learns a policy function pi(a|s): given a state s, it outputs a probability distribution over actions a (in the discrete case, a probability for each action; in the continuous case, a Gaussian specified by a mean and std).
  • If a policy-gradient update takes too large a step, a single update can ruin the whole policy, so PPO clips the objective to keep the new policy from drifting too far from the old one, making every update "small and safe" (the clipped objective is written out after this list).
  • Multi-epoch updates: after collecting a batch of data, the policy is updated for several epochs on mini-batches of that data, improving sample efficiency.
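
For reference, the clipped surrogate objective that the code below implements can be written as

L^{CLIP}(\theta) = \hat{\mathbb{E}}_t\Big[\min\big(r_t(\theta)\,\hat{A}_t,\ \mathrm{clip}(r_t(\theta),\,1-\epsilon,\,1+\epsilon)\,\hat{A}_t\big)\Big],
\qquad r_t(\theta) = \frac{\pi_\theta(a_t \mid s_t)}{\pi_{\theta_{\mathrm{old}}}(a_t \mid s_t)}

where \hat{A}_t is the advantage estimate and \epsilon is the clipping range (0.2 in the code below).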

Reference code:

import torch
import torch.nn as nn
import torch.optim as optim

class PPO(nn.Module):
    def __init__(self, state_dim, action_dim):
        super(PPO, self).__init__()
        self.actor = nn.Sequential(
            nn.Linear(state_dim, 64),
            nn.Tanh(),
            nn.Linear(64, 64),
            nn.Tanh(),
            nn.Linear(64, action_dim),
            nn.Softmax(dim=-1)
        )
        self.critic = nn.Sequential(
            nn.Linear(state_dim, 64),
            nn.Tanh(),
            nn.Linear(64, 64),
            nn.Tanh(),
            nn.Linear(64, 1)
        )
        
    def forward(self, state):
        return self.actor(state), self.critic(state)

class PPOAgent:
    def __init__(self, state_dim, action_dim, lr=3e-4, gamma=0.99, epsilon=0.2):
        self.ppo = PPO(state_dim, action_dim)
        self.optimizer = optim.Adam(self.ppo.parameters(), lr=lr)
        self.gamma = gamma
        self.epsilon = epsilon

    def update(self, states, actions, rewards, next_states, dones):
        states = torch.FloatTensor(states)
        actions = torch.LongTensor(actions)
        rewards = torch.FloatTensor(rewards)
        next_states = torch.FloatTensor(next_states)
        dones = torch.FloatTensor(dones)

        # 1. First, compute the old action probabilities old_probs with the current policy (no gradients)
        with torch.no_grad():
            old_probs, old_state_values = self.ppo(states)
            old_probs = old_probs.gather(1, actions.unsqueeze(1)).squeeze(1)
            # Value of the next states
            _, next_state_values = self.ppo(next_states)
            # Value targets: if the episode is done, the next-state value is 0
            value_targets = rewards + self.gamma * next_state_values.squeeze(1) * (1 - dones)
            # Advantage A(s,a) = value_target - old_state_value
            advantages = value_targets - old_state_values.squeeze(1)
            # Advantages are usually normalized to reduce variance
            advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)

        for _ in range(10):  # multiple epochs over the same batch
            # 2. Recompute action probabilities and state values with the current policy
            new_probs, state_values = self.ppo(states)
            new_probs = new_probs.gather(1, actions.unsqueeze(1)).squeeze(1)

            # 3. Importance-sampling ratio
            ratio = new_probs / old_probs

            # 4. Clipped surrogate loss
            surr1 = ratio * advantages
            surr2 = torch.clamp(ratio, 1 - self.epsilon, 1 + self.epsilon) * advantages
            actor_loss = -torch.min(surr1, surr2).mean()

            # 5. Critic loss (MSE between value_targets and current value estimates)
            critic_loss = nn.MSELoss()(state_values.squeeze(1), value_targets)

            # 6. Total loss
            loss = actor_loss + 0.5 * critic_loss

            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()

    def get_action(self, state):
        state = torch.FloatTensor(state)
        probs, _ = self.ppo(state)
        return torch.multinomial(probs, 1).item()
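
To show how these pieces fit together, here is a minimal training-loop sketch. It assumes a Gymnasium discrete-action environment (CartPole-v1) and one update per collected episode; the environment choice and the episode-based batching are illustrative assumptions, not part of the original code:

import gymnasium as gym
import numpy as np

env = gym.make("CartPole-v1")
agent = PPOAgent(state_dim=env.observation_space.shape[0],
                 action_dim=env.action_space.n)

for episode in range(500):
    states, actions, rewards, next_states, dones = [], [], [], [], []
    state, _ = env.reset()
    done = False
    while not done:
        action = agent.get_action(state)
        next_state, reward, terminated, truncated, _ = env.step(action)
        done = terminated or truncated
        # Store the transition in this episode's batch
        states.append(state)
        actions.append(action)
        rewards.append(reward)
        next_states.append(next_state)
        dones.append(float(done))
        state = next_state
    # One PPO update (10 epochs inside update()) on the collected episode
    agent.update(np.array(states), actions, rewards, np.array(next_states), dones)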

PPO is an on-policy algorithm: it outputs action probabilities directly, its clipped updates keep training stable, and it is widely regarded as the culmination of policy-gradient methods.

DDPG Algorithm in Detail

Core Improvements

Compared with DQN, DDPG's core improvements are:

  • DQN's action space is discrete (e.g., up, down, left, right, fire), while DDPG works with continuous action spaces
  • DQN outputs Q-values and then picks the action with the highest Q-value, whereas DDPG outputs the action directly
  • DQN is value-based, while DDPG is policy-based (strictly speaking, an actor-critic method)
  • DQN typically has only a Q-network, while DDPG needs both an Actor network and a Critic network

In one sentence: DQN was designed for discrete control problems, while DDPG mainly targets continuous control; it is the combination of DQN and a policy network (its update equations are written out below).
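
For reference, the three updates performed in the code below can be written as (Q is the critic, \mu the actor, primes denote the target networks, and \tau is the soft-update rate):

y = r + \gamma\,(1 - d)\,Q'\big(s',\,\mu'(s')\big)
\mathcal{L}_{\mathrm{critic}} = \big(Q(s,a) - y\big)^2, \qquad \mathcal{L}_{\mathrm{actor}} = -\,Q\big(s,\,\mu(s)\big)
\theta' \leftarrow \tau\,\theta + (1-\tau)\,\theta'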

Reference code (from the original article "Deep Reinforcement Learning (DRL) 算法在 PyTorch 中的实现与应用"):

import torch
import torch.nn as nn
import torch.optim as optim

class Actor(nn.Module):
    def __init__(self, state_dim, action_dim, max_action):
        super(Actor, self).__init__()
        self.fc1 = nn.Linear(state_dim, 400)
        self.fc2 = nn.Linear(400, 300)
        self.fc3 = nn.Linear(300, action_dim)
        self.max_action = max_action
        
    def forward(self, state):
        a = torch.relu(self.fc1(state))
        a = torch.relu(self.fc2(a))
        return self.max_action * torch.tanh(self.fc3(a))

class Critic(nn.Module):
    def __init__(self, state_dim, action_dim):
        super(Critic, self).__init__()
        self.fc1 = nn.Linear(state_dim + action_dim, 400)
        self.fc2 = nn.Linear(400, 300)
        self.fc3 = nn.Linear(300, 1)
        
    def forward(self, state, action):
        q = torch.cat([state, action], 1)
        q = torch.relu(self.fc1(q))
        q = torch.relu(self.fc2(q))
        return self.fc3(q)

class DDPGAgent:
    def __init__(self, state_dim, action_dim, max_action, lr=1e-4, gamma=0.99, tau=0.001):
        self.actor = Actor(state_dim, action_dim, max_action)
        self.actor_target = Actor(state_dim, action_dim, max_action)
        self.actor_target.load_state_dict(self.actor.state_dict())
        self.actor_optimizer = optim.Adam(self.actor.parameters(), lr=lr)
        
        self.critic = Critic(state_dim, action_dim)
        self.critic_target = Critic(state_dim, action_dim)
        self.critic_target.load_state_dict(self.critic.state_dict())
        self.critic_optimizer = optim.Adam(self.critic.parameters(), lr=lr)
        
        self.gamma = gamma
        self.tau = tau

    def select_action(self, state):
        state = torch.FloatTensor(state.reshape(1, -1))
        return self.actor(state).cpu().data.numpy().flatten()

    def update(self, replay_buffer, batch_size=100):
        # Sample a batch of transitions from the replay buffer
        state, action, next_state, reward, done = replay_buffer.sample(batch_size)
        
        # Compute the target Q-value with the target networks
        target_Q = self.critic_target(next_state, self.actor_target(next_state))
        target_Q = reward + (1 - done) * self.gamma * target_Q.detach()
        
        # Update the critic
        current_Q = self.critic(state, action)
        critic_loss = nn.MSELoss()(current_Q, target_Q)
        self.critic_optimizer.zero_grad()
        critic_loss.backward()
        self.critic_optimizer.step()
        
        # Update the actor (maximize Q(s, actor(s)))
        actor_loss = -self.critic(state, self.actor(state)).mean()
        self.actor_optimizer.zero_grad()
        actor_loss.backward()
        self.actor_optimizer.step()
        
        # Soft-update the target networks
        for param, target_param in zip(self.critic.parameters(), self.critic_target.parameters()):
            target_param.data.copy_(self.tau * param.data + (1 - self.tau) * target_param.data)
        
        for param, target_param in zip(self.actor.parameters(), self.actor_target.parameters()):
            target_param.data.copy_(self.tau * param.data + (1 - self.tau) * target_param.data)
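
The update() method above assumes a replay buffer whose sample(batch_size) returns (state, action, next_state, reward, done) as batched tensors. Below is a minimal sketch of such a buffer plus a training loop with Gaussian exploration noise; the buffer implementation, the noise scale, and the Pendulum-v1 environment are illustrative assumptions, not part of the original article:

import random
import numpy as np
import gymnasium as gym

class ReplayBuffer:
    def __init__(self, capacity=100000):
        self.buffer = []
        self.capacity = capacity

    def add(self, state, action, next_state, reward, done):
        if len(self.buffer) >= self.capacity:
            self.buffer.pop(0)
        self.buffer.append((state, action, next_state, reward, done))

    def sample(self, batch_size):
        batch = random.sample(self.buffer, batch_size)
        state, action, next_state, reward, done = map(np.array, zip(*batch))
        # Rewards and dones are reshaped to (batch, 1) to match the critic's output
        return (torch.FloatTensor(state), torch.FloatTensor(action),
                torch.FloatTensor(next_state),
                torch.FloatTensor(reward).unsqueeze(1),
                torch.FloatTensor(done).unsqueeze(1))

env = gym.make("Pendulum-v1")
agent = DDPGAgent(state_dim=env.observation_space.shape[0],
                  action_dim=env.action_space.shape[0],
                  max_action=float(env.action_space.high[0]))
buffer = ReplayBuffer()

state, _ = env.reset()
for step in range(50000):
    # Deterministic action plus Gaussian exploration noise, clipped to the valid range
    action = agent.select_action(state)
    action = np.clip(action + np.random.normal(0, 0.1, size=action.shape),
                     env.action_space.low, env.action_space.high)
    next_state, reward, terminated, truncated, _ = env.step(action)
    done = terminated or truncated
    buffer.add(state, action, next_state, reward, float(done))
    state = next_state if not done else env.reset()[0]
    if len(buffer.buffer) >= 1000:
        agent.update(buffer, batch_size=100)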