强化学习13——Actor-Critic算法

Actor-Critic算法结合了策略梯度和值函数的优点,我们将其分为两部分,Actor(策略网络)和Critic(价值网络)

  • Actor与环境交互,在Critic价值函数的指导下使用策略梯度学习好的策略
  • Critic通过Actor与环境交互收集的数据学习,得到一个价值函数,来判断当前状态哪些动作是好,哪些动作是坏,进而帮Actor进行策略更新。

A2C算法

AC算法的目的是为了消除策略梯度算法的高仿查问题,可以引用优势函数(advantage function) A π ( s t , a t ) A^{\pi}(s_t,a_t) Aπ(st,at) ,来表示当前当前状态-动作对 相对于平均水平的优势:
A π ( s t , a t ) = Q π ( s t , a t ) − V π ( s t ) A^{\pi}(s_t,a_t)=Q^{\pi}(s_t,a_t)-V^{\pi}(s_t) Aπ(st,at)=Qπ(st,at)−Vπ(st)

通过与平均水平相减,可以降低方差。但需要注意的是,相减的是 V π ( s t ) V^{\pi}(s_t) Vπ(st) ,即在状态 s t s_t st 下的价值,即状态 s t s_t st 的回报的均值,而不是所有状态 s s s 的回报的均值。

可以将目标函数改为:
∇ θ J ( θ ) ∝ E π θ [ A π ( s t , a t ) ∇ θ log ⁡ π θ ( a t ∣ s t ) ] \nabla_\theta J(\theta)\propto\mathbb{E}{\pi\theta}\left[A^\pi(s_t,a_t)\nabla_\theta\log\pi_\theta(a_t\mid s_t)\right] ∇θJ(θ)∝Eπθ[Aπ(st,at)∇θlogπθ(at∣st)]

这就是A2C算法(Advantage Actor-Critic)算法。脱胎于A3C算法,即增加了多个进程,每一个进程都拥有一个独立的网络和环境以供训练。

广义优势估计

时序差分能有效解决高方差问题但是是有偏估计,而蒙特卡洛是无偏估计但是会带来高方差问题,因此通常会结合这两个方法形成一种新的估计方式,即 T D ( λ ) TD(\lambda) TD(λ) 估计,通过结合多步,形成新的估计方式,成为广义优势估计(generalized advantage estimation GAE)。

A GAE ( γ , λ ) ( s t , a t ) = ∑ l = 0 ∞ ( γ λ ) l δ t + l = ∑ l = 0 ∞ ( γ λ ) l ( r t + l + γ V π ( s t + l + 1 ) − V π ( s t + l ) ) \begin{aligned} A^{\text{GAE}(\gamma,\lambda)}(s_t,a_t)& =\sum_{l=0}^\infty(\gamma\lambda)^l\delta_{t+l} \\ &=\sum_{l=0}^\infty(\gamma\lambda)^l\left(r_{t+l}+\gamma V^\pi(s_{t+l+1})-V^\pi(s_{t+l})\right) \end{aligned} AGAE(γ,λ)(st,at)=l=0∑∞(γλ)lδt+l=l=0∑∞(γλ)l(rt+l+γVπ(st+l+1)−Vπ(st+l))

其中, δ t + l \delta_{t+l} δt+l 为时步 t + l t+l t+l 的TD误差,为:
δ t + l = r t + l + γ V π ( s t + l + 1 ) − V π ( s t + l ) \delta_{t+l}=r_{t+l}+\gamma V^{\pi}(s_{t+l+1})-V^{\pi}(s_{t+l}) δt+l=rt+l+γVπ(st+l+1)−Vπ(st+l)

当 λ = 0 \lambda=0 λ=0 时,退化为单步TD误差:
A G A E ( γ , 0 ) ( s t , a t ) = δ t = r t + γ V π ( s t + 1 ) − V π ( s t ) A^{\mathrm{GAE}(\gamma,0)}(s_t,a_t)=\delta_t=r_t+\gamma V^\pi(s_{t+1})-V^\pi(s_t) AGAE(γ,0)(st,at)=δt=rt+γVπ(st+1)−Vπ(st)

当 λ = 1 \lambda=1 λ=1 时,则为蒙特卡洛估计:
A G A E ( γ , 1 ) ( s t , a t ) = ∑ l = 0 ∞ ( γ λ ) l δ t + l = ∑ l = 0 ∞ ( γ ) l δ t + l A^{\mathrm{GAE}(\gamma,1)}(s_t,a_t)=\sum_{l=0}^\infty(\gamma\lambda)^l\delta_{t+l}=\sum_{l=0}^\infty(\gamma)^l\delta_{t+l} AGAE(γ,1)(st,at)=l=0∑∞(γλ)lδt+l=l=0∑∞(γ)lδt+l

代码实操

python 复制代码
import gymnasium as gym
import torch
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
import rl_utils
python 复制代码
# 定义策略网络
class PolicyNet(torch.nn.Module):
    def __init__(self, state_dim, hidden_dim, action_dim):
        super(PolicyNet, self).__init__()
        self.fc1 = torch.nn.Linear(state_dim, hidden_dim)
        self.fc2 = torch.nn.Linear(hidden_dim, action_dim)

    def forward(self, x):
        x = F.relu(self.fc1(x))
        return F.softmax(self.fc2(x), dim=1)

# 定义价值网络,输出一个价值,为一维张量
class ValueNet(torch.nn.Module):
    def __init__(self, state_dim, hidden_dim):
        super(ValueNet, self).__init__()
        self.fc1 = torch.nn.Linear(state_dim, hidden_dim)
        self.fc2 = torch.nn.Linear(hidden_dim, 1)

    def forward(self, x):
        x = F.relu(self.fc1(x))
        return self.fc2(x)

现在定义A2C算法的主题,包括采取动作和更新网络参数的两个函数。

python 复制代码
class ActorCritic:
    def __init__(self, state_dim, hidden_dim, action_dim, actor_lr, critic_lr,
                 gamma, device):
        # 策略网络
        self.actor = PolicyNet(state_dim, hidden_dim, action_dim).to(device)
        self.critic = ValueNet(state_dim, hidden_dim).to(device)  # 价值网络
        # 策略网络优化器
        self.actor_optimizer = torch.optim.Adam(self.actor.parameters(),
                                                lr=actor_lr)
        self.critic_optimizer = torch.optim.Adam(self.critic.parameters(),
                                                 lr=critic_lr)  # 价值网络优化器
        self.gamma = gamma
        self.device = device
        
    def take_action(self, state):
        state = torch.tensor([state], dtype=torch.float).to(self.device)
        probs = self.actor(state)
        action_dist = torch.distributions.Categorical(probs)
        action = action_dist.sample()
        return action.item()
    
    def update(self,transition_dict):
        states = torch.tensor(transition_dict['states'],
                              dtype=torch.float).to(self.device)
        actions = torch.tensor(transition_dict['actions']).view(-1, 1).to(
            self.device)
        rewards = torch.tensor(transition_dict['rewards'],
                               dtype=torch.float).view(-1, 1).to(self.device)
        next_states = torch.tensor(transition_dict['next_states'],
                                   dtype=torch.float).to(self.device)
        dones = torch.tensor(transition_dict['dones'],
                             dtype=torch.float).view(-1, 1).to(self.device)
        
        # 时序差分目标
        td_target=rewards+self.gamma*self.critic(next_states)*(1-dones)
        # 进行时序擦划分
        td_delta=td_target-self.critic(states)
        log_probs=torch.log(self.actor(states).gather(1,actions))
        actor_loss=torch.mean(-log_probs*td_delta.detach())
        # 均方误差损失函数
        critic_loss = torch.mean(F.mse_loss(self.critic(states), td_target.detach()))
        self.actor_optimizer.zero_grad()
        self.critic_optimizer.zero_grad()
        actor_loss.backward()  # 计算策略网络的梯度
        critic_loss.backward()  # 计算价值网络的梯度
        self.actor_optimizer.step()  # 更新策略网络的参数
        self.critic_optimizer.step()  # 更新价值网络的参数
    
actor_lr = 1e-3
critic_lr = 1e-2
num_episodes = 1000
hidden_dim = 128
gamma = 0.98
device = torch.device("cuda") if torch.cuda.is_available() else torch.device(
    "cpu")

env_name = 'CartPole-v0'
env = gym.make(env_name)
torch.manual_seed(0)
state_dim = env.observation_space.shape[0]
action_dim = env.action_space.n
agent = ActorCritic(state_dim, hidden_dim, action_dim, actor_lr, critic_lr,
                    gamma, device)

return_list = rl_utils.train_on_policy_agent(env, agent, num_episodes)

episodes_list = list(range(len(return_list)))
plt.plot(episodes_list, return_list)
plt.xlabel('Episodes')
plt.ylabel('Returns')
plt.title('Actor-Critic on {}'.format(env_name))
plt.show()

mv_return = rl_utils.moving_average(return_list, 9)
plt.plot(episodes_list, mv_return)
plt.xlabel('Episodes')
plt.ylabel('Returns')
plt.title('Actor-Critic on {}'.format(env_name))
plt.show()
  state = torch.tensor([state], dtype=torch.float).to(self.device)
Iteration 0: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 100/100 [00:03<00:00, 25.55it/s, episode=100, return=20.400]
Iteration 1: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 100/100 [00:04<00:00, 24.48it/s, episode=200, return=51.200]
Iteration 2: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 100/100 [00:14<00:00,  6.91it/s, episode=300, return=151.500]
Iteration 3: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 100/100 [00:25<00:00,  3.88it/s, episode=400, return=256.700]
Iteration 4:  53%|███████████████████████████████████████████████████████████████████████████████▌                                                                      | 53/100 [00:17<00:10,  4.51it/s, episode=450, return=235.500]
相关推荐
chenziang131 分钟前
leetcode hot 100 二叉搜索
数据结构·算法·leetcode
single5942 小时前
【c++笔试强训】(第四十五篇)
java·开发语言·数据结构·c++·算法
呆头鹅AI工作室3 小时前
基于特征工程(pca分析)、小波去噪以及数据增强,同时采用基于注意力机制的BiLSTM、随机森林、ARIMA模型进行序列数据预测
人工智能·深度学习·神经网络·算法·随机森林·回归
一勺汤3 小时前
YOLO11改进-注意力-引入自调制特征聚合模块SMFA
人工智能·深度学习·算法·yolo·目标检测·计算机视觉·目标跟踪
每天写点bug4 小时前
【golang】map遍历注意事项
开发语言·算法·golang
程序员JerrySUN4 小时前
BitBake 执行流程深度解析:从理论到实践
linux·开发语言·嵌入式硬件·算法·架构
王老师青少年编程4 小时前
gesp(二级)(16)洛谷:B4037:[GESP202409 二级] 小杨的 N 字矩阵
数据结构·c++·算法·gesp·csp·信奥赛
robin_suli5 小时前
动态规划子序列问题系列一>等差序列划分II
算法·动态规划
cxylay6 小时前
自适应滤波算法分类及详细介绍
算法·分类·自适应滤波算法·自适应滤波·主动噪声控制·anc
茶猫_6 小时前
力扣面试题 - 40 迷路的机器人 C语言解法
c语言·数据结构·算法·leetcode·机器人·深度优先