In practice, the KL-divergence term matters little in this game: the action space is tiny (two discrete actions), unlike a language model where each action is a distribution over a huge vocabulary, so simply maximizing surr1 works fine, and experiments confirm this. The KL coefficient also cannot be set too large, or the penalty overwhelms the surrogate objective; in any case, the action distributions produced by the action model and the ref model differ very little here.
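As a point of reference before the full script, here is a minimal sketch of the two actor-loss variants the code toggles between, with the KL coefficient made explicit. The function name ppo_actor_loss and the beta parameter are illustrative additions that do not appear in the script below (which effectively uses beta = 1); the inputs are assumed to be 1-D tensors of per-sample log-probabilities and advantages.

import torch

def ppo_actor_loss(new_log_probs, old_log_probs, advantages,
                   beta=0.01, clip_param=0.2, use_penalty=True):
    # Importance ratio pi_new(a|s) / pi_old(a|s), computed in log space for stability.
    ratio = (new_log_probs - old_log_probs).exp()
    surr1 = ratio * advantages  # unclipped surrogate objective
    if use_penalty:
        # Sample-based KL estimate p * log(p / p'); beta controls the penalty strength.
        kl = new_log_probs.exp() * (new_log_probs - old_log_probs)
        return -(surr1 - beta * kl).mean()
    # Otherwise use the clipped surrogate (the PPO2-style objective).
    surr2 = torch.clamp(ratio, 1.0 - clip_param, 1.0 + clip_param) * advantages
    return -torch.min(surr1, surr2).mean()

The full training script: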
import gym
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import pygame
import sys
from collections import deque
# Define the policy network
class PolicyNetwork(nn.Module):
def __init__(self):
super(PolicyNetwork, self).__init__()
self.fc = nn.Sequential(
nn.Linear(4, 2),
nn.Tanh(),
            nn.Linear(2, 2),  # CartPole has 2 discrete actions
nn.Softmax(dim=-1)
)
def forward(self, x):
return self.fc(x)
# Define the value network
class ValueNetwork(nn.Module):
def __init__(self):
super(ValueNetwork, self).__init__()
self.fc = nn.Sequential(
nn.Linear(4, 2),
nn.Tanh(),
nn.Linear(2, 1)
)
def forward(self, x):
return self.fc(x)
# Rollout buffer (on-policy experience storage, cleared after each update)
class RolloutBuffer:
def __init__(self):
self.states = []
self.actions = []
self.rewards = []
self.dones = []
self.log_probs = []
def store(self, state, action, reward, done, log_prob):
self.states.append(state)
self.actions.append(action)
self.rewards.append(reward)
self.dones.append(done)
self.log_probs.append(log_prob)
def clear(self):
self.states = []
self.actions = []
self.rewards = []
self.dones = []
self.log_probs = []
def get_batch(self):
return (
            torch.tensor(np.array(self.states), dtype=torch.float),  # stack the list of arrays first to avoid a slow element-wise copy
torch.tensor(self.actions, dtype=torch.long),
torch.tensor(self.rewards, dtype=torch.float),
torch.tensor(self.dones, dtype=torch.bool),
torch.tensor(self.log_probs, dtype=torch.float)
)
# PPO update function
def ppo_update(policy_net, value_net, optimizer_policy, optimizer_value, buffer, epochs=100, gamma=0.99, clip_param=0.2):
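    # Compute Monte Carlo returns and a simple advantage estimate by walking the
    # episode backwards, then run several epochs of actor and critic updates on
    # the same batch of on-policy data.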
states, actions, rewards, dones, old_log_probs = buffer.get_batch()
    returns = []
    advantages = []
    G = 0
    adv = 0
    next_value = 0.0
    dones = dones.to(torch.int)
    values = value_net(states).squeeze(-1).detach()
    # Walk the episode backwards to accumulate returns and advantage estimates.
    for reward, done, value in zip(reversed(rewards), reversed(dones), reversed(values)):
        if done:
            G = 0
            adv = 0
            next_value = 0.0
        G = reward + gamma * G  # Monte Carlo return
        delta = reward + gamma * next_value * (1 - done) - value.item()  # TD error
        # adv = delta + gamma * 0.95 * adv * (1 - done)  # GAE with lambda = 0.95
        adv = delta + adv * (1 - done)  # simplified accumulation (gamma * lambda = 1)
        next_value = value.item()
        returns.insert(0, G)
        advantages.insert(0, adv)
    returns = torch.tensor(returns, dtype=torch.float)  # value targets
    advantages = torch.tensor(advantages, dtype=torch.float)
    advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)  # normalize advantages
for _ in range(epochs):
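        # The same rollout is reused for every epoch, so the importance ratio below
        # corrects for the gap between the updated policy and the policy that
        # collected the data.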
action_probs = policy_net(states)
dist = torch.distributions.Categorical(action_probs)
new_log_probs = dist.log_prob(actions)
ratio = (new_log_probs - old_log_probs).exp()
        KL = new_log_probs.exp() * (new_log_probs - old_log_probs)  # per-sample KL estimate: p * log(p / p')
        # The next few lines are the core of PPO.
        surr1 = ratio * advantages
        PPO1, PPO2 = True, False  # switch between the KL-penalty form and the clipped form
        if PPO1:
            actor_loss = -(surr1 - KL).mean()
        if PPO2:
            surr2 = torch.clamp(ratio, 1.0 - clip_param, 1.0 + clip_param) * advantages
            actor_loss = -torch.min(surr1, surr2).mean()
optimizer_policy.zero_grad()
actor_loss.backward()
optimizer_policy.step()
        value_loss = (returns - value_net(states).squeeze(-1)).pow(2).mean()  # MSE against Monte Carlo returns; squeeze avoids an (N, N) broadcast
optimizer_value.zero_grad()
value_loss.backward()
optimizer_value.step()
# Initialize the environment and models
env = gym.make('CartPole-v1')
policy_net = PolicyNetwork()
value_net = ValueNetwork()
optimizer_policy = optim.Adam(policy_net.parameters(), lr=3e-4)
optimizer_value = optim.Adam(value_net.parameters(), lr=1e-3)
buffer = RolloutBuffer()
# Initialize pygame for simple real-time visualization
pygame.init()
screen = pygame.display.set_mode((600, 400))
clock = pygame.time.Clock()
draw_on = False
# Training loop
state, _ = env.reset()
for episode in range(10000):  # number of training episodes
    done = False
    step = 0
while not done:
step+=1
state_tensor = torch.FloatTensor(state).unsqueeze(0)
        action_probs = policy_net(state_tensor)  # act with the current ("old") policy
dist = torch.distributions.Categorical(action_probs)
action = dist.sample()
log_prob = dist.log_prob(action)
        next_state, reward, done, _, _ = env.step(action.item())  # the truncation flag is ignored, so an episode can run past the 500-step time limit
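        # The sampled action's log-probability is stored so ppo_update can later
        # form the importance ratio against the updated policy.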
buffer.store(state, action.item(), reward, done, log_prob)
state = next_state
        # Real-time rendering with pygame
for event in pygame.event.get():
if event.type == pygame.QUIT:
pygame.quit()
sys.exit()
if draw_on:
            # Clear the screen and redraw the cart and pole
            screen.fill((0, 0, 0))
            cart_x = int(state[0] * 100 + 300)  # convert the cart position to a screen x-coordinate
pygame.draw.rect(screen, (0, 128, 255), (cart_x, 300, 50, 30))
pygame.draw.line(screen, (255, 0, 0), (cart_x + 25, 300), (cart_x + 25 - int(50 * np.sin(state[2])), 300 - int(50 * np.cos(state[2]))), 5)
pygame.display.flip()
clock.tick(60)
        if step > 2000:  # enable rendering once the agent can balance for a long stretch
draw_on = True
ppo_update(policy_net, value_net, optimizer_policy, optimizer_value, buffer)
buffer.clear()
    state, _ = env.reset()
    print(f'Episode {episode} completed, reward: {step}.')
# End of training
env.close()
pygame.quit()
Result: