编写自定义环境
我们已经写完下位机的脚本了,现在回过头来继续写上位机的内容。
还记得gym的环境要自实现四个函数step() render() close() reset()
step函数的编写
我们刚刚是从step()跳出来的,因此要回去写完step()
step()需要返回这些内容,因此我们需要在上位机把transaction的对应值return出来。
|------------------------------|------------------------------------------------------------------------------------------|------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|
| .step(action) (执行动作) | 作用 :这是强化学习的核心。将 Agent 的动作(Action)传入环境,环境反馈下一步的状态和奖励。 参数 :action (Agent 选择的动作)。 | 1. observation :执行动作后的新环境状态。 2. reward :该动作获得的奖励(浮点数)。 3. terminated :布尔值。True 表示回合正常结束(如到达目标/坠毁)。 4. truncated :布尔值。True 表示回合被强制截断(如超时/出界)。 5. info:辅助诊断信息。 |
python
def step(self,action):
while True:
try:
with open('transaction.pkl','rb') as f:#先读
transaction=pickle.load(f)
if transaction['action'] is None:#如果没有记录
transaction['action']=action
with open('transaction.pkl','wb') as f:
pickle.dump(transaction,f)
break
except Exception as e:
time.sleep(0.05)
while True:
try:
with open('transaction.pkl', 'rb') as f:
transaction = pickle.load(f)
if transaction['action'] is None:
break
except:
time.sleep(0.05)
回顾一下,step的逻辑,首先先看看transaction.pkl是否有记录,如果没有记录,就写入action,然后交给下位机读取,下位机读取后会修改action为None。因此上位机在提交action后立刻break,进入第二个while阻塞态,当发现transaction.pkl的action被下位机赋None后break,进入下一步,重新进入第一个while。
那么我们要获取observation\reward这类信息,就需要在第二个while中接收。
python
def step(self,action):
while True:
try:
with open('transaction.pkl','rb') as f:#先读
transaction=pickle.load(f)
if transaction['action'] is None:#如果没有记录
transaction['action']=action
with open('transaction.pkl','wb') as f:
pickle.dump(transaction,f)
break
except Exception as e:
time.sleep(0.05)
while True:
try:
with open('transaction.pkl', 'rb') as f:
transaction = pickle.load(f)
if transaction['action'] is None:
observation=transaction['observation']
reward=transaction['reward']
terminated=transaction['terminated']
truncated=transaction['truncated']
break
except:
time.sleep(0.05)
info={}
return observation,reward,terminated,truncated,info
reset函数的编写
下位机的代码需要由上位机启动,相当于是一个子线程,因此需要导入子线程库,并在reset中启动下位机脚本。
python
import subprocess
然后在reset函数中启动子线程:
切记要用shell=True否则无法打开星际争霸的窗口。
python
def reset(self,seed=None, options=None):
print('reset the Env')
map=np.zeros((244,244,3),dtype=np.uint8)
observation=map
transaction={'observation':map,'reward':0,'action':None,'terminated':False,'truncated':False}
with open('transaction.pkl','wb') as f:
pickle.dump(transaction,f)
#subprocess.Popen(['Python3','WorkerRushBot.py'],creationflags=subprocess.CREATE_NEW_CONSOLE)
subprocess.Popen(
[
"cmd", "/c", "start",
"python", "WorkerRushBot.py"
],
shell=True
)
print('clear')
info={}
return observation,info
为了验证我们的环境是否符合Gym规范,在SB3官网,我们可以使用check_env验证。
https://stable-baselines3.readthedocs.io/en/master/guide/custom_env.html
python
from stable_baselines3.common.env_checker import check_env
check_env(sc2)
测试没报错就删掉这行,否则会开多个shell。之后用reset调用。
封装脚本
把自定义环境的代码封装成"StarCraft2Env.py"的代码,路径和当前jupyter文件一致。
代码有所改动,完整文件在文末获取。

强化学习PPO环境的编写
我们已经完成自定义环境【上位机】和操作脚本【下位机】的编写了,接下来到了最后一步,如何利用SB3套用自定义的环境来训练。
我们在jupyter再创建一个SC2_Training的文档。

我们计划用PPO来训练,这是一种基于优化策略的模型,具体原理可以在我之前的强化学习理论找到。
可以在SB3官网找到PPO的文档
https://stable-baselines3.readthedocs.io/en/master/modules/ppo.html
我们的训练需要基于刚刚保存的Starcraft2Env,因此要把这个py文件作为库导入
python
# 1. 导入依赖库
from stable_baselines3 import PPO
import os
import time
from StarCraft2Env import StarCraft2Env
由于我们使用的是图像输入,训练量也非常大,中途可能出现断电等情况。因此我们需要每隔一定batch保存一个模型(也可以使用callbacks,往期讲过就不再提了)
python
model_name = f'{int(time.time())}'
model_dir = f'models/{model_name}/'
logs_dir = f'logs/{model_name}/'
if not os.path.exists(model_dir):
os.makedirs(model_dir)
if not os.path.exists(logs_dir):
os.makedirs(logs_dir)
- 核心设计 :用时间戳作为模型名,保证每次训练的模型 / 日志目录唯一,不会覆盖历史训练数据
model_dir:保存训练好的 PPO 模型权重,用于后续加载、测试、迭代logs_dir:保存 TensorBoard 训练日志,用于可视化训练曲线(奖励、胜率、损失等)- 自动创建目录,避免因目录不存在导致的报错
log是用于在tensorboard查看训练曲线的,不赘述。
然后要选PPO的策略了,我们选CNN,老师选了MLP,我不理解。
我们就用CNN,图像输入不用CNN难到用全连接?244*244*3都多大了,训个der?
然后CNN不能直接用,因为默认的cnnpolicy特征提取器是做了池化的,池化就是区域取平均/最值,严重损失细节,就这么小一张图,单位本来就小,再池化啥特征都没有了。
https://blog.csdn.net/2301_80226956/article/details/159908855?spm=1001.2014.3001.5501
python
import torch as th
import torch.nn as nn
from gymnasium import spaces
from stable_baselines3 import PPO
from stable_baselines3.common.torch_layers import BaseFeaturesExtractor
class CustomCNN(BaseFeaturesExtractor):
"""
:param observation_space: (gym.Space)
:param features_dim: (int) Number of features extracted.
This corresponds to the number of unit for the last layer.
"""
def __init__(self, observation_space: spaces.Box, features_dim: int = 256):
super().__init__(observation_space, features_dim)
# We assume CxHxW images (channels first)
# Re-ordering will be done by pre-preprocessing or wrapper
n_input_channels = observation_space.shape[0]
self.cnn = nn.Sequential(
nn.Conv2d(n_input_channels, 32, kernel_size=(3,3), stride=(1,1), padding=0),
nn.ReLU(),
nn.Conv2d(32, 64, kernel_size=(3,3), stride=(1,1), padding=0),
nn.ReLU(),
nn.Conv2d(64, 64, kernel_size=(3,3), stride=(1,1), padding=0),
nn.ReLU(),
nn.Flatten(),
)
# Compute shape by doing one forward pass
with th.no_grad():
n_flatten = self.cnn(
th.as_tensor(observation_space.sample()[None]).float()
).shape[1]
self.linear = nn.Sequential(nn.Linear(n_flatten, features_dim), nn.ReLU())
def forward(self, observations: th.Tensor) -> th.Tensor:
return self.linear(self.cnn(observations))
policy_kwargs = dict(
features_extractor_class=CustomCNN,
features_extractor_kwargs=dict(features_dim=128),
net_arch=dict(pi=[32, 32], vf=[64, 64]),
activation_fn=th.nn.ReLU,
)
这是我们之前讲的,如何自定义特征提取器的一节,我们参考。
首先把依赖库补上
python
# 1. 导入依赖库
from stable_baselines3 import PPO
import os
import time
from StarCraft2Env import StarCraft2Env
import torch as th
import torch.nn as nn
from stable_baselines3.common.torch_layers import BaseFeaturesExtractor
然后参考自定义特征提取器的代码构建如下架构:
python
class StarCraftCNN(BaseFeaturesExtractor):
def __init__(self, observation_space, features_dim=256):
super().__init__(observation_space, features_dim)
# 🔥 关键:无池化!小卷积!不丢失像素!
self.cnn = nn.Sequential(
nn.Conv2d(3, 32, kernel_size=3, stride=1, padding=1),
nn.ReLU(),
nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1),
nn.ReLU(),
nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1),
nn.ReLU(),
nn.Flatten()
)
# 计算最终特征维度
with th.no_grad():
n_flatten = self.cnn(th.rand(1, 3, 244, 244)).shape[1]
self.linear = nn.Sequential(
nn.Linear(n_flatten, 512),
nn.ReLU(),
nn.Linear(512, features_dim)
)
def forward(self, x):
x = x.permute(0, 3, 1, 2) # (H,W,C) → (C,H,W)
x = self.cnn(x)
return self.linear(x)
然后仔创建模型实例
python
model = PPO(
policy="CnnPolicy",
env=env,
verbose=1,
tensorboard_log=logs_dir,
policy_kwargs=dict(
features_extractor_class=StarCraftCNN,
features_extractor_kwargs=dict(features_dim=256)
)
)
可以输入这行指令,查看当前自定义的模型架构:
python
model.policy

