基于强化学习算法玩CartPole游戏

什么事CartPole游戏

CartPole(也称为倒立摆问题)是一个经典的控制理论和强化学习的基础问题,通常用于测试和验证控制算法的性能。具体来说,它是一个简单的物理模拟问题,其目标是通过在一个平衡杆(倒立摆)上安装在小车(或称为平衡车)上的水平移动,使杆子保持竖直直立的状态。

有两个动作(action):

左移(0)

右移(1)

四个状态(state): 1. 小车在轨道上的位置 2. 杆子与竖直方向的夹角 3. 小车速度 4. 角度变化率

神经网络设计

1、强化学习的训练网络cartpole_train.py

复制代码
import  gym
import pygame
import time
import random
import torch
from torch.distributions import Categorical

from torch import nn, optim
import torch.nn.functional as F

def compute_policy_loss(n, log_p):
    r = list()
    #构造奖励r列表
    for i in range(n, 0 ,-1):
        r.append(i *1.0)
    r = torch.tensor(r)
    r = (r - r.mean()) / r.std() #进行标准化处理
    loss = 0
    #计算损失函数
    for pi, ri in zip(log_p, r):
        loss += -pi * ri
    return  loss

class CartPolePolicy(nn.Module):
    def __init__(self):
        super(CartPolePolicy, self).__init__()
        self.fc1 = nn.Linear(in_features = 4, out_features = 128)
        self.fc2 = nn.Linear(128, 2) #输出为神经元个数为2表示,向左和向向右
        self.drop = nn.Dropout(p=0.6)

    def forward(self, x):
        x = self.fc1(x)
        x = self.drop(x)
        x = F.relu(x)
        x = self.fc2(x)
        #使用softmax决策最终的行动,是向左还是右
        return F.softmax(x, dim=1)


if __name__ == '__main__':
    env = gym.make("CartPole-v1") #启动环境
    env.reset(seed= 543)
    torch.manual_seed(543)
    policy = CartPolePolicy() #定义模型
    optimizer = optim.Adam(policy.parameters(), lr = 0.01) #优化器

    #我们一共最多训练1000个回合
    #每个回合最多行动10000次
    #当某一回合的游戏步数超过5000时,就认为完成训练
    max_episod = 1000 #最大游戏回合数
    max_action = 10000 #每回合最大行动数
    max_steps = 5000 #完成训练的步数
    for episod in range(1, max_episod + 1):
        # 对于每一轮循环,都要重新启动一次游戏环境
        state, _ = env.reset()
        step = 0
        log_p = list()
        for step in range(1, max_action + 1):
            state = torch.from_numpy(state).float().unsqueeze(0)
            probs = policy(state) #计算神经网络给出的行动概率
            # 基于网络给出的概率分布,随机选择行动
            m = Categorical(probs)
            # 这里并不是直接使用概率较大的行动,而是通过概率分布生成action, 这样可以进一步探索低概率行动
            action = m.sample()
            state, _, done, _, _ = env.step(action.item())
            if done:
                break #表示跳出该for循环
            log_p.append(m.log_prob(action)) #保存每次行动对应的概率分布
        if step > max_steps: #当step大于最大步数时
            print(f"Done! last episode {episod} Run steps {step}")
            break #跳出循序,结束训练

        #每一回合游戏,都会做一次梯度下降算法
        optimizer.zero_grad()
        loss = compute_policy_loss(step, log_p)
        loss.backward()
        optimizer.step()
        if episod % 10 ==0:
            print(f"Episode {episod} Run step {step}")
    #保存模型
    torch.save(policy.state_dict(), f"cartpole_policy.pth")

2、验证:cartpole_eval.py

复制代码
import  gym
import pygame
import torch.nn as nn
import torch.nn.functional as F
import time
import torch
class CartPolePolicy(nn.Module):
    def __init__(self):
        super(CartPolePolicy, self).__init__()
        self.fc1 = nn.Linear(4, 128)
        self.fc2 = nn.Linear(128, 2)
        self.drop = nn.Dropout(p=0.6)

    def forward(self, x):
        x = self.fc1(x)
        x = self.drop(x)
        x = F.relu(x)
        x = self.fc2(x)
        return F.softmax(x, dim=1)


if __name__ == '__main__':
    pygame.init() #初始化pygame
    #使用gym, 创建一个artPole游戏的运行环境,这个环境是提供给人类玩家使用的
    env = gym.make('CartPole-v1', render_mode = "human")
    state, _ =env.reset()
    #使用env.reset重置环境后,会得到CartPole游戏中关键参数state
    cart_position = state[0] #小车位置
    cart_speed = state[1] #小车速度
    pole_angle = state[2] #杆的角度
    pole_speed = state[3] #杆的尖端速度

    #加载网络
    policy = CartPolePolicy()
    policy.load_state_dict(torch.load("cartpole_policy.pth"))
    policy.eval()

    start_time =time.time()
    max_action =1000 #设置游戏最大执行次数
    #最多执行1000次方向键,游戏就可以通关结束
    step = 0
    fail = False
    for step in range(1, max_action + 1):
        #首先使用time.sleep,使游戏暂停0.3s,用于人的反应,觉得自己反应慢可以设置更长时间
        # time.sleep(0.3)
        #小车的控制方式,通过神经网络,来决定小车的运动方向
        #将环境参数state转为张量
        state = torch.from_numpy(state).float().unsqueeze(0)
        #输入至网络模型,计算行动概率probs
        probs = policy(state)
        #选取行动概率最大的行动
        action =torch.argmax(probs, dim = 1).item()
        state, _, done, _, _ = env.step(action) #done为True,表示杆倒了
        if done:
            fail = True
            break
        print(f"step = {step} action = {action} angle = {state[2]:.2f}  position = {state[0]:.2f}")

    end_time = time.time()
    game_time = end_time - start_time
    if fail:
        print(f"Game over ,you play {game_time:.2f} seconds, {step} steps.")
    else:
        print(f"Congratulations! you play  {game_time:.2f} seconds, {step} steps.")
    env.close()

视频讲解:

什么是reinforce强化学习算法,基于强化学习玩CartPole游戏_哔哩哔哩_bilibili

相关推荐
仙人掌_lz1 天前
深度理解用于多智能体强化学习的单调价值函数分解QMIX算法:基于python从零实现
python·算法·强化学习·rl·价值函数
Mr.Winter`2 天前
深度强化学习 | 图文详细推导软性演员-评论家SAC算法原理
人工智能·深度学习·神经网络·机器学习·数据挖掘·机器人·强化学习
IT猿手3 天前
基于强化学习 Q-learning 算法求解城市场景下无人机三维路径规划研究,提供完整MATLAB代码
神经网络·算法·matlab·人机交互·无人机·强化学习·无人机三维路径规划
仙人掌_lz4 天前
理解多智能体深度确定性策略梯度MADDPG算法:基于python从零实现
python·算法·强化学习·策略梯度·rl
仙人掌_lz5 天前
深入理解深度Q网络DQN:基于python从零实现
python·算法·强化学习·dqn·rl
IT猿手6 天前
基于 Q-learning 的城市场景无人机三维路径规划算法研究,可以自定义地图,提供完整MATLAB代码
深度学习·算法·matlab·无人机·强化学习·qlearning·无人机路径规划
Two summers ago6 天前
arXiv2025 | TTRL: Test-Time Reinforcement Learning
论文阅读·人工智能·机器学习·llm·强化学习
仙人掌_lz7 天前
为特定领域微调嵌入模型:打造专属的自然语言处理利器
人工智能·ai·自然语言处理·embedding·强化学习·rl·bge
碣石潇湘无限路9 天前
【AI】基于生活案例的LLM强化学习(入门帖)
人工智能·经验分享·笔记·生活·openai·强化学习
人类发明了工具9 天前
【强化学习】强化学习算法 - 多臂老虎机问题
机器学习·强化学习·多臂老虎机