PPO 算法简单实现

导入必要的库

python 复制代码
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import gym
import numpy as np

超参数

python 复制代码
# --- Hyperparameters ---
learning_rate = 0.002  # Adam learning rate
gamma = 0.99           # discount factor used in the TD target
lmbda = 0.95           # GAE lambda (advantage smoothing)
eps_clip = 0.1         # PPO clipping range for the probability ratio
K_epochs = 3           # optimisation passes over each collected batch
T_horizon = 20         # environment steps collected before each update

定义PPO模型

python 复制代码
class PPO(nn.Module):
    """Actor-critic network trained with PPO's clipped surrogate objective.

    A shared hidden layer feeds two heads: a policy head that outputs action
    probabilities (actor) and a value head that outputs a scalar state value
    (critic). Transitions are buffered with ``put_data`` and consumed by
    ``train_net``, which clears the buffer.

    Args:
        state_dim: Length of the observation vector (4 for CartPole).
        action_dim: Number of discrete actions (2 for CartPole).
        hidden_dim: Width of the shared hidden layer.
        lr, discount, gae_lambda, clip_eps, epochs: Optional overrides for the
            module-level ``learning_rate``, ``gamma``, ``lmbda``, ``eps_clip``
            and ``K_epochs``. ``None`` means "use the module-level value".
    """

    def __init__(self, state_dim=4, action_dim=2, hidden_dim=256, *,
                 lr=None, discount=None, gae_lambda=None, clip_eps=None,
                 epochs=None):
        super().__init__()
        self.data = []  # buffered transitions, consumed by train_net()

        # Explicit overrides win; the module-level globals are only looked
        # up when the corresponding override is None.
        self.gamma = gamma if discount is None else discount
        self.lmbda = lmbda if gae_lambda is None else gae_lambda
        self.eps_clip = eps_clip if clip_eps is None else clip_eps
        self.k_epochs = K_epochs if epochs is None else epochs

        self.fc1 = nn.Linear(state_dim, hidden_dim)
        self.fc_pi = nn.Linear(hidden_dim, action_dim)  # policy head (actor)
        self.fc_v = nn.Linear(hidden_dim, 1)            # value head (critic)
        self.optimizer = optim.Adam(
            self.parameters(), lr=learning_rate if lr is None else lr
        )

    def pi(self, x, softmax_dim=0):
        """Return action probabilities for state(s) ``x``.

        ``softmax_dim`` selects the action axis: 0 for a single 1-D state,
        1 for a batch of states stacked along dim 0.
        """
        x = F.relu(self.fc1(x))
        return F.softmax(self.fc_pi(x), dim=softmax_dim)

    def v(self, x):
        """Return the critic's value estimate(s); the last dim has size 1."""
        x = F.relu(self.fc1(x))
        return self.fc_v(x)

    def put_data(self, transition):
        """Buffer one ``(s, a, r, s_prime, prob_a, done)`` transition."""
        self.data.append(transition)

    def make_batch(self):
        """Stack the buffered transitions into tensors and clear the buffer.

        Returns:
            Tuple ``(s, a, r, s_prime, done, prob_a)``. Note that ``done``
            precedes ``prob_a``, unlike the stored transition order. Each
            tensor has one row per transition. ``a`` is ``torch.long`` because
            it is used as a ``gather`` index; all others are ``torch.float``.
        """
        s_lst, a_lst, r_lst, s_prime_lst, prob_a_lst, done_lst = [], [], [], [], [], []
        for s, a, r, s_prime, prob_a, done in self.data:
            s_lst.append(s)
            a_lst.append([a])
            r_lst.append([r])
            s_prime_lst.append(s_prime)
            prob_a_lst.append([prob_a])
            done_lst.append([done])

        # Stack states with NumPy first: PyTorch warns that building a tensor
        # from a list of ndarrays is extremely slow.
        s = torch.tensor(np.array(s_lst), dtype=torch.float)
        s_prime = torch.tensor(np.array(s_prime_lst), dtype=torch.float)
        # Explicit dtypes keep everything float32 even if the env yields NumPy
        # float64 scalars, avoiding float32/float64 mixing in the loss.
        a = torch.tensor(a_lst, dtype=torch.long)
        r = torch.tensor(r_lst, dtype=torch.float)
        done = torch.tensor(done_lst, dtype=torch.float)
        prob_a = torch.tensor(prob_a_lst, dtype=torch.float)
        self.data = []
        return s, a, r, s_prime, done, prob_a

    def _gae(self, delta, done):
        """Return GAE advantages for TD errors ``delta`` (one row per step).

        The running sum is reset after any step flagged in ``done``, so
        advantages never leak across episode boundaries within one batch.
        """
        advantage = torch.zeros_like(delta)
        running = torch.zeros_like(delta[0])
        for t in reversed(range(delta.shape[0])):
            running = delta[t] + self.gamma * self.lmbda * (1.0 - done[t]) * running
            advantage[t] = running
        return advantage

    def train_net(self):
        """Run ``k_epochs`` PPO updates on the buffered transitions.

        Consumes and clears the buffer. Does nothing if the buffer is empty.
        """
        if not self.data:
            return
        s, a, r, s_prime, done, prob_a = self.make_batch()

        for _ in range(self.k_epochs):
            # Targets and advantages are re-estimated each epoch with the
            # current critic, but treated as constants for the gradient step.
            with torch.no_grad():
                td_target = r + self.gamma * self.v(s_prime) * (1 - done)
                delta = td_target - self.v(s)
                advantage = self._gae(delta, done)

            # Importance ratio pi_new(a|s) / pi_old(a|s), computed in log space.
            pi = self.pi(s, softmax_dim=1)
            pi_a = pi.gather(1, a)
            ratio = torch.exp(torch.log(pi_a) - torch.log(prob_a))

            # Clipped surrogate (actor) plus smooth-L1 value regression (critic).
            surr1 = ratio * advantage
            surr2 = torch.clamp(ratio, 1 - self.eps_clip, 1 + self.eps_clip) * advantage
            loss = -torch.min(surr1, surr2) + F.smooth_l1_loss(self.v(s), td_target)

            self.optimizer.zero_grad()
            loss.mean().backward()
            self.optimizer.step()

主循环

python 复制代码
# --- Main training loop ---
def main():
    """Train a PPO agent on CartPole-v1 for 1000 episodes.

    Prints the average (unscaled) episode score every 20 episodes. Supports
    both step APIs: the legacy Gym 4-tuple ``(obs, r, done, info)`` and the
    newer 5-tuple ``(obs, r, terminated, truncated, info)``.
    """
    print_interval = 20
    env = gym.make('CartPole-v1')
    model = PPO()
    score = 0.0

    try:
        for n_epi in range(1000):
            # Call reset() exactly once. Newer Gym returns (obs, info), while
            # older Gym returns obs alone.
            reset_out = env.reset()
            s = reset_out[0] if isinstance(reset_out, tuple) else reset_out
            done = False
            while not done:
                for _ in range(T_horizon):
                    # Rollout only needs the probabilities, not a gradient graph.
                    with torch.no_grad():
                        prob = model.pi(torch.from_numpy(s).float())
                    a = torch.distributions.Categorical(prob).sample().item()

                    step_result = env.step(a)
                    if len(step_result) == 5:
                        s_prime, r, terminated, truncated, info = step_result
                        done = terminated or truncated
                    else:
                        s_prime, r, done, info = step_result
                        # Legacy TimeLimit wrapper marks truncation in info.
                        terminated = done and not info.get("TimeLimit.truncated", False)

                    # Store `terminated` rather than `done` so the TD target
                    # still bootstraps through time-limit truncation.
                    model.put_data((s, a, r / 100.0, s_prime, prob[a].item(), terminated))
                    s = s_prime
                    score += r
                    if done:
                        break

                model.train_net()

            # Report once every `print_interval` complete episodes.
            if (n_epi + 1) % print_interval == 0:
                print(f"# Episode: {n_epi + 1}, Avg Score: {score / print_interval}")
                score = 0.0
    finally:
        env.close()

if __name__ == '__main__':
    main()
相关推荐
Hello:CodeWorld几秒前
Dify 从入门到实战:部署、模型对接与企业级 AI 应用开发全教程
人工智能·python·架构·ai编程
AllData公司负责人6 分钟前
大模型赋能AllData数据中台,系列升级|通过联合智谱大模型与Chat2DB开源项目,建设Text2SQL生产场景全新体验的数据源平台!
数据库·人工智能·text2sql·数据中台·数据源·chat2db·智谱大模型
xinlianyq11 分钟前
2026 电商视觉红海突围:核心 AI 视频与海报创作工具实战选型指南
人工智能·aigc
Deepoch14 分钟前
Deepoc VLA开发板:除草机器人的持续学习与协同作业系统
人工智能·学习·机器人·开发板·具身模型·deepoc
生成论实验室19 分钟前
判断力与六十四卦:AI的第三块基石
人工智能·语言模型·机器人·自动驾驶·安全架构
xixixi7777722 分钟前
空天地通信、高速光模块、AI 智能体攻击、同态加密芯片四大事件解读:AI 算力底座攻防与全域通信同步升级
大数据·人工智能·深度学习·ai·大模型·光模块·智能体
水木流年追梦29 分钟前
大模型入门-大模型优化方法13- MTP 多 token 输出、DCA 双块注意力
人工智能·分布式·算法·正则表达式·prompt
雪隐29 分钟前
AI股票小助手06-Backtrader 量化回测
人工智能·后端
蓝桉~MLGT33 分钟前
语音陪伴助手
人工智能·语音识别
数据皮皮侠34 分钟前
全国消协智慧 315 平台投诉信息数据库
大数据·人工智能·算法·百度·制造