强化学习实战8.3——用PPO打赢星际争霸【编写自定义环境GYM】

编写自定义环境

我们已经写完下位机的脚本了,现在回过头来继续写上位机的内容。

还记得gym的环境要自实现四个函数step() render() close() reset()

step函数的编写

我们刚刚是从step()跳出来的,因此要回去写完step()

step()需要返回这些内容,因此我们需要在上位机把transaction的对应值return出来。

|------------------------------|------------------------------------------------------------------------------------------|------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|
| .step(action) (执行动作) | 作用 :这是强化学习的核心。将 Agent 的动作(Action)传入环境,环境反馈下一步的状态和奖励。 参数action (Agent 选择的动作)。 | 1. observation :执行动作后的新环境状态。 2. reward :该动作获得的奖励(浮点数)。 3. terminated :布尔值。True 表示回合正常结束(如到达目标/坠毁)。 4. truncated :布尔值。True 表示回合被强制截断(如超时/出界)。 5. info:辅助诊断信息。 |

python 复制代码
    def step(self,action):
        while True:
                
            try:
                with open('transaction.pkl','rb') as f:#先读
                    transaction=pickle.load(f)
                if transaction['action'] is None:#如果没有记录
                    transaction['action']=action
                    with open('transaction.pkl','wb') as f:
                        pickle.dump(transaction,f)
                    break
            except Exception as e:
                time.sleep(0.05)


        while True:
            try:
                with open('transaction.pkl', 'rb') as f:
                    transaction = pickle.load(f)
                if transaction['action'] is None:
                    break
            except:
                time.sleep(0.05)

回顾一下,step的逻辑,首先先看看transaction.pkl是否有记录,如果没有记录,就写入action,然后交给下位机读取,下位机读取后会修改action为None。因此上位机在提交action后立刻break,进入第二个while阻塞态,当发现transaction.pkl的action被下位机赋None后break,进入下一步,重新进入第一个while。

那么我们要获取observation\reward这类信息,就需要在第二个while中接收。

python 复制代码
    def step(self,action):
        while True:
                
            try:
                with open('transaction.pkl','rb') as f:#先读
                    transaction=pickle.load(f)
                if transaction['action'] is None:#如果没有记录
                    transaction['action']=action
                    with open('transaction.pkl','wb') as f:
                        pickle.dump(transaction,f)
                    break
            except Exception as e:
                time.sleep(0.05)


        while True:
            try:
                with open('transaction.pkl', 'rb') as f:
                    transaction = pickle.load(f)
                if transaction['action'] is None:
                    observation=transaction['observation']
                    reward=transaction['reward']
                    terminated=transaction['terminated']
                    truncated=transaction['truncated']
                    break
            except:
                time.sleep(0.05)

        info={}
        return observation,reward,terminated,truncated,info

reset函数的编写

下位机的代码需要由上位机启动,相当于是一个子线程,因此需要导入子线程库,并在reset中启动下位机脚本。

python 复制代码
import subprocess

然后在reset函数中启动子线程:

切记要用shell=True否则无法打开星际争霸的窗口。

python 复制代码
    def reset(self,seed=None, options=None):
        print('reset the Env')
        map=np.zeros((244,244,3),dtype=np.uint8)
        observation=map
        transaction={'observation':map,'reward':0,'action':None,'terminated':False,'truncated':False}

        with open('transaction.pkl','wb') as f:
            pickle.dump(transaction,f)


        #subprocess.Popen(['Python3','WorkerRushBot.py'],creationflags=subprocess.CREATE_NEW_CONSOLE)  
        subprocess.Popen(
            [
                "cmd", "/c", "start",
                "python", "WorkerRushBot.py"
            ],
            shell=True
        )
        print('clear')
        info={}
        return observation,info

为了验证我们的环境是否符合Gym规范,在SB3官网,我们可以使用check_env验证。

https://stable-baselines3.readthedocs.io/en/master/guide/custom_env.html

python 复制代码
from stable_baselines3.common.env_checker import check_env
check_env(sc2)

测试没报错就删掉这行,否则会开多个shell。之后用reset调用。

封装脚本

把自定义环境的代码封装成"StarCraft2Env.py"的代码,路径和当前jupyter文件一致。

代码有所改动,完整文件在文末获取。

强化学习PPO环境的编写

我们已经完成自定义环境【上位机】和操作脚本【下位机】的编写了,接下来到了最后一步,如何利用SB3套用自定义的环境来训练。

我们在jupyter再创建一个SC2_Training的文档。

我们计划用PPO来训练,这是一种基于优化策略的模型,具体原理可以在我之前的强化学习理论找到。

可以在SB3官网找到PPO的文档

https://stable-baselines3.readthedocs.io/en/master/modules/ppo.html

我们的训练需要基于刚刚保存的Starcraft2Env,因此要把这个py文件作为库导入

python 复制代码
# 1. 导入依赖库
from stable_baselines3 import PPO
import os
import time
from StarCraft2Env import StarCraft2Env

由于我们使用的是图像输入,训练量也非常大,中途可能出现断电等情况。因此我们需要每隔一定batch保存一个模型(也可以使用callbacks,往期讲过就不再提了)

python 复制代码
model_name = f'{int(time.time())}'
model_dir = f'models/{model_name}/'
logs_dir = f'logs/{model_name}/'

if not os.path.exists(model_dir):
    os.makedirs(model_dir)
if not os.path.exists(logs_dir):
    os.makedirs(logs_dir)
  • 核心设计 :用时间戳作为模型名,保证每次训练的模型 / 日志目录唯一,不会覆盖历史训练数据
  • model_dir:保存训练好的 PPO 模型权重,用于后续加载、测试、迭代
  • logs_dir:保存 TensorBoard 训练日志,用于可视化训练曲线(奖励、胜率、损失等)
  • 自动创建目录,避免因目录不存在导致的报错

log是用于在tensorboard查看训练曲线的,不赘述。

然后要选PPO的策略了,我们选CNN,老师选了MLP,我不理解。

我们就用CNN,图像输入不用CNN难到用全连接?244*244*3都多大了,训个der?

然后CNN不能直接用,因为默认的cnnpolicy特征提取器是做了池化的,池化就是区域取平均/最值,严重损失细节,就这么小一张图,单位本来就小,再池化啥特征都没有了。

https://blog.csdn.net/2301_80226956/article/details/159908855?spm=1001.2014.3001.5501

python 复制代码
import torch as th
import torch.nn as nn
from gymnasium import spaces

from stable_baselines3 import PPO
from stable_baselines3.common.torch_layers import BaseFeaturesExtractor


class CustomCNN(BaseFeaturesExtractor):
    """
    :param observation_space: (gym.Space)
    :param features_dim: (int) Number of features extracted.
        This corresponds to the number of unit for the last layer.
    """

    def __init__(self, observation_space: spaces.Box, features_dim: int = 256):
        super().__init__(observation_space, features_dim)
        # We assume CxHxW images (channels first)
        # Re-ordering will be done by pre-preprocessing or wrapper
        n_input_channels = observation_space.shape[0]
        self.cnn = nn.Sequential(
            nn.Conv2d(n_input_channels, 32, kernel_size=(3,3), stride=(1,1), padding=0),
            nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=(3,3), stride=(1,1), padding=0),
            nn.ReLU(),
            nn.Conv2d(64, 64, kernel_size=(3,3), stride=(1,1), padding=0),
            nn.ReLU(),
            nn.Flatten(),
        )

        # Compute shape by doing one forward pass
        with th.no_grad():
            n_flatten = self.cnn(
                th.as_tensor(observation_space.sample()[None]).float()
            ).shape[1]

        self.linear = nn.Sequential(nn.Linear(n_flatten, features_dim), nn.ReLU())

    def forward(self, observations: th.Tensor) -> th.Tensor:
        return self.linear(self.cnn(observations))

policy_kwargs = dict(
    features_extractor_class=CustomCNN,
    features_extractor_kwargs=dict(features_dim=128),
    net_arch=dict(pi=[32, 32], vf=[64, 64]),
    activation_fn=th.nn.ReLU,
)

这是我们之前讲的,如何自定义特征提取器的一节,我们参考。

首先把依赖库补上

python 复制代码
# 1. 导入依赖库
from stable_baselines3 import PPO
import os
import time
from StarCraft2Env import StarCraft2Env

import torch as th
import torch.nn as nn
from stable_baselines3.common.torch_layers import BaseFeaturesExtractor

然后参考自定义特征提取器的代码构建如下架构:

python 复制代码
class StarCraftCNN(BaseFeaturesExtractor):
    def __init__(self, observation_space, features_dim=256):
        super().__init__(observation_space, features_dim)
        
        # 🔥 关键:无池化!小卷积!不丢失像素!
        self.cnn = nn.Sequential(
            nn.Conv2d(3, 32, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            
            nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            
            nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            
            nn.Flatten()
        )
        # 计算最终特征维度
        with th.no_grad():
            n_flatten = self.cnn(th.rand(1, 3, 244, 244)).shape[1]
        
        self.linear = nn.Sequential(
            nn.Linear(n_flatten, 512),
            nn.ReLU(),
            nn.Linear(512, features_dim)
        )

    def forward(self, x):
        x = x.permute(0, 3, 1, 2)  # (H,W,C) → (C,H,W)
        x = self.cnn(x)
        return self.linear(x)

然后仔创建模型实例

python 复制代码
model = PPO(
    policy="CnnPolicy",
    env=env,
    verbose=1,
    tensorboard_log=logs_dir,
    policy_kwargs=dict(
        features_extractor_class=StarCraftCNN,
        features_extractor_kwargs=dict(features_dim=256)
    )
)

可以输入这行指令,查看当前自定义的模型架构:

python 复制代码
model.policy

相关推荐
翔云1234562 小时前
大模型部署全流程深度解析
人工智能·ai·大模型
BU摆烂会噶2 小时前
【LangGraph】持久化实现的三大能力——人机交互
数据库·人工智能·python·langchain·人机交互
沐风老师3 小时前
开发AI机器人操作系统用什么编程语言?
人工智能·ai编程·机器人操作系统
念威3 小时前
弹幕互动游戏AI无人直播方案 - 可遇AI无人直播助手
人工智能·游戏
BizViewStudio3 小时前
甄选方法:2026 企业新媒体代运营的短视频精细化运营与流量转化技巧
大数据·网络·人工智能·媒体
咖啡星人k3 小时前
Vibe Coding 实践观察:从概念到云端开发工具的探索
人工智能
qq_283720053 小时前
Python+LangChain 入门到实战全教程+ 企业级案例
人工智能·langchain·#大模型·#llm·#rag·#ai 应用开发·#智能体
码点滴3 小时前
DeepSeek-V4 全景地图:两款模型、三种模式,你该怎么选?
人工智能·架构·大模型·deepseek-v4
Vane13 小时前
前端引擎开发记录
人工智能