Study Notes on the PPO and DDPG Reinforcement Learning Algorithms

The PPO Algorithm in Detail

Core Ideas

  • Learn a policy function pi(a|s) directly: given a state s, it outputs a probability distribution over actions a (in the discrete case, a probability for each action; in the continuous case, a Gaussian defined by a mean and a std).
  • In plain policy-gradient methods, a single overly large update step can wreck the whole policy. PPO therefore clips the objective so the new policy cannot drift too far from the old one, keeping every update "small and safe" (see the objective sketched below the list).
  • Multi-epoch updates: after collecting a batch of data, run several epochs of minibatch updates on it, which improves sample efficiency.
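For reference, the clipping rule above is the clipped surrogate objective from the PPO paper, where $r_t(\theta)$ is the probability ratio between the new and old policies and $\hat{A}_t$ is the advantage estimate:

$$
L^{\mathrm{CLIP}}(\theta) = \mathbb{E}_t\!\left[\min\!\left(r_t(\theta)\,\hat{A}_t,\ \mathrm{clip}\!\left(r_t(\theta),\,1-\epsilon,\,1+\epsilon\right)\hat{A}_t\right)\right],
\qquad r_t(\theta) = \frac{\pi_\theta(a_t \mid s_t)}{\pi_{\theta_{\mathrm{old}}}(a_t \mid s_t)}
$$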

Reference code:

python
import torch
import torch.nn as nn
import torch.optim as optim

class PPO(nn.Module):
    def __init__(self, state_dim, action_dim):
        super(PPO, self).__init__()
        self.actor = nn.Sequential(
            nn.Linear(state_dim, 64),
            nn.Tanh(),
            nn.Linear(64, 64),
            nn.Tanh(),
            nn.Linear(64, action_dim),
            nn.Softmax(dim=-1)
        )
        self.critic = nn.Sequential(
            nn.Linear(state_dim, 64),
            nn.Tanh(),
            nn.Linear(64, 64),
            nn.Tanh(),
            nn.Linear(64, 1)
        )
        
    def forward(self, state):
        return self.actor(state), self.critic(state)

class PPOAgent:
    def __init__(self, state_dim, action_dim, lr=3e-4, gamma=0.99, epsilon=0.2):
        self.ppo = PPO(state_dim, action_dim)
        self.optimizer = optim.Adam(self.ppo.parameters(), lr=lr)
        self.gamma = gamma
        self.epsilon = epsilon

    def update(self, states, actions, rewards, next_states, dones):
        states = torch.FloatTensor(states)
        actions = torch.LongTensor(actions)
        rewards = torch.FloatTensor(rewards)
        next_states = torch.FloatTensor(next_states)
        dones = torch.FloatTensor(dones)

        # 1. Use the current network (without gradients) to compute the old action
        #    probabilities old_probs and the old state values
        with torch.no_grad():
            old_probs, old_state_values = self.ppo(states)
            old_probs = old_probs.gather(1, actions.unsqueeze(1)).squeeze(1)
            # Value of the next states
            _, next_state_values = self.ppo(next_states)
            # TD value target: if the episode is done, the next state's value is zeroed out
            value_targets = rewards + self.gamma * next_state_values.squeeze(1) * (1 - dones)
            # Advantage A(s,a) = value_target - old_state_value
            advantages = value_targets - old_state_values.squeeze(1)
            # Advantages are usually normalized to reduce variance
            advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)

        for _ in range(10):  # several epochs of updates on the same batch
            # 2. Recompute action probabilities and state values with the current network
            new_probs, state_values = self.ppo(states)
            new_probs = new_probs.gather(1, actions.unsqueeze(1)).squeeze(1)

            # 3. Importance-sampling ratio between the new and old policies
            ratio = new_probs / old_probs

            # 4. Clipped surrogate (actor) loss
            surr1 = ratio * advantages
            surr2 = torch.clamp(ratio, 1 - self.epsilon, 1 + self.epsilon) * advantages
            actor_loss = -torch.min(surr1, surr2).mean()

            # 5. Critic loss: MSE between the value targets and the current value estimates
            critic_loss = nn.MSELoss()(state_values.squeeze(1), value_targets)

            # 6. Combined loss
            loss = actor_loss + 0.5 * critic_loss

            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()

    def get_action(self, state):
        # Sample an action from the categorical distribution given by the actor
        state = torch.FloatTensor(state)
        with torch.no_grad():
            probs, _ = self.ppo(state)
        return torch.multinomial(probs, 1).item()
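
A minimal usage sketch for the agent above, assuming a Gymnasium-style discrete-action environment (the environment name, episode count, and rollout handling are illustrative placeholders, not part of the original notes):

python
import gymnasium as gym  # assumed environment library

env = gym.make("CartPole-v1")                 # placeholder task: 4-dim state, 2 actions
agent = PPOAgent(state_dim=4, action_dim=2)

for episode in range(500):
    state, _ = env.reset()
    states, actions, rewards, next_states, dones = [], [], [], [], []
    done = False
    while not done:
        action = agent.get_action(state)
        next_state, reward, terminated, truncated, _ = env.step(action)
        done = terminated or truncated
        # Store one transition of the on-policy rollout
        states.append(state)
        actions.append(action)
        rewards.append(reward)
        next_states.append(next_state)
        dones.append(float(done))
        state = next_state
    # Run the clipped multi-epoch update on the freshly collected rollout
    agent.update(states, actions, rewards, next_states, dones)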

PPO is an on-policy algorithm: it outputs action probabilities, its updates are kept stable by clipping, and it is widely regarded as the culmination of policy-network (policy-gradient) methods.

The DDPG Algorithm in Detail

Core Improvements

Compared with DQN, DDPG's core improvements are:

  • DQN's action space is discrete (e.g., up, down, left, right, fire), whereas DDPG's action space is continuous.
  • DQN outputs Q-values and then picks the action with the largest Q-value; DDPG outputs the action directly.
  • DQN is value-based, while DDPG is policy-based.
  • DQN usually has a single Q-network; DDPG needs two networks, an actor and a critic.

In one sentence: DQN is built for discrete control problems, while DDPG targets continuous control; it is essentially a combination of DQN and a policy network.
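
The difference in how the two pick actions can be sketched in a few lines (a toy illustration added here, not taken from the referenced article):

python
import torch

# DQN-style selection: the network outputs one Q-value per discrete action,
# and the agent takes the argmax.
q_values = torch.tensor([0.1, 0.7, 0.2])      # Q(s, a) for 3 discrete actions
dqn_action = torch.argmax(q_values).item()    # -> 1

# DDPG-style selection: the actor outputs the continuous action itself,
# squashed by tanh and scaled to the valid range [-max_action, max_action].
max_action = 2.0
actor_output = torch.tensor([0.3])                    # raw actor output for a 1-D action
ddpg_action = max_action * torch.tanh(actor_output)   # -> roughly tensor([0.58])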

Reference code (from the original article "Deep Reinforcement Learning (DRL) 算法在 PyTorch 中的实现与应用"):

python
import torch
import torch.nn as nn
import torch.optim as optim

class Actor(nn.Module):
    def __init__(self, state_dim, action_dim, max_action):
        super(Actor, self).__init__()
        self.fc1 = nn.Linear(state_dim, 400)
        self.fc2 = nn.Linear(400, 300)
        self.fc3 = nn.Linear(300, action_dim)
        self.max_action = max_action
        
    def forward(self, state):
        a = torch.relu(self.fc1(state))
        a = torch.relu(self.fc2(a))
        return self.max_action * torch.tanh(self.fc3(a))

class Critic(nn.Module):
    def __init__(self, state_dim, action_dim):
        super(Critic, self).__init__()
        self.fc1 = nn.Linear(state_dim + action_dim, 400)
        self.fc2 = nn.Linear(400, 300)
        self.fc3 = nn.Linear(300, 1)
        
    def forward(self, state, action):
        q = torch.cat([state, action], 1)
        q = torch.relu(self.fc1(q))
        q = torch.relu(self.fc2(q))
        return self.fc3(q)

class DDPGAgent:
    def __init__(self, state_dim, action_dim, max_action, lr=1e-4, gamma=0.99, tau=0.001):
        self.actor = Actor(state_dim, action_dim, max_action)
        self.actor_target = Actor(state_dim, action_dim, max_action)
        self.actor_target.load_state_dict(self.actor.state_dict())
        self.actor_optimizer = optim.Adam(self.actor.parameters(), lr=lr)
        
        self.critic = Critic(state_dim, action_dim)
        self.critic_target = Critic(state_dim, action_dim)
        self.critic_target.load_state_dict(self.critic.state_dict())
        self.critic_optimizer = optim.Adam(self.critic.parameters(), lr=lr)
        
        self.gamma = gamma
        self.tau = tau

    def select_action(self, state):
        state = torch.FloatTensor(state.reshape(1, -1))
        return self.actor(state).cpu().data.numpy().flatten()

    def update(self, replay_buffer, batch_size=100):
        # Sample a batch of transitions from the replay buffer
        state, action, next_state, reward, done = replay_buffer.sample(batch_size)
        
        # Compute the target Q-value using the target actor and target critic
        target_Q = self.critic_target(next_state, self.actor_target(next_state))
        target_Q = reward + (1 - done) * self.gamma * target_Q.detach()
        
        # Update the critic by regressing towards the target Q-value
        current_Q = self.critic(state, action)
        critic_loss = nn.MSELoss()(current_Q, target_Q)
        self.critic_optimizer.zero_grad()
        critic_loss.backward()
        self.critic_optimizer.step()
        
        # Update the actor: maximize the critic's Q-value for the actor's actions
        actor_loss = -self.critic(state, self.actor(state)).mean()
        self.actor_optimizer.zero_grad()
        actor_loss.backward()
        self.actor_optimizer.step()
        
        # Soft-update the target networks towards the online networks
        for param, target_param in zip(self.critic.parameters(), self.critic_target.parameters()):
            target_param.data.copy_(self.tau * param.data + (1 - self.tau) * target_param.data)
        
        for param, target_param in zip(self.actor.parameters(), self.actor_target.parameters()):
            target_param.data.copy_(self.tau * param.data + (1 - self.tau) * target_param.data)
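
A minimal end-to-end sketch for the DDPG agent, assuming a Gymnasium continuous-control task; the ReplayBuffer below is a hypothetical helper (the referenced article's buffer is not shown here) that returns float32 tensors shaped [batch, dim], which is what update() expects:

python
import random
import numpy as np
import torch
import gymnasium as gym  # assumed environment library

class ReplayBuffer:
    """Hypothetical minimal buffer; sample() returns tensors shaped [batch, dim]."""
    def __init__(self, capacity=100_000):
        self.storage = []
        self.capacity = capacity

    def add(self, state, action, next_state, reward, done):
        if len(self.storage) >= self.capacity:
            self.storage.pop(0)
        self.storage.append((state, action, next_state, reward, done))

    def sample(self, batch_size):
        batch = random.sample(self.storage, batch_size)
        s, a, ns, r, d = map(np.array, zip(*batch))
        as_2d = lambda x: torch.as_tensor(x, dtype=torch.float32).reshape(batch_size, -1)
        return as_2d(s), as_2d(a), as_2d(ns), as_2d(r), as_2d(d)

env = gym.make("Pendulum-v1")  # placeholder task: 3-dim state, 1-dim action in [-2, 2]
agent = DDPGAgent(state_dim=3, action_dim=1, max_action=2.0)
buffer = ReplayBuffer()

state, _ = env.reset()
for step in range(10_000):
    # Deterministic policy plus Gaussian exploration noise, clipped to the action range
    action = np.clip(agent.select_action(state) + np.random.normal(0, 0.1, size=1), -2.0, 2.0)
    next_state, reward, terminated, truncated, _ = env.step(action)
    done = terminated or truncated
    buffer.add(state, action, next_state, reward, float(done))
    state = env.reset()[0] if done else next_state
    if len(buffer.storage) >= 1_000:
        agent.update(buffer, batch_size=100)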