PyTorch强化学习实战------Atari游戏包装器
-
- [0. 前言](#0. 前言)
- [1. SB3 代码库](#1. SB3 代码库)
- [2. stable-baseline3 包装器](#2. stable-baseline3 包装器)
- 相关链接
0. 前言
从资源角度来看,用强化学习 (Reinforcement Learning, RL)处理Atari 游戏颇具挑战性。为提升效率,需要对 Atari 游戏交互应用多种变换。其中部分变换仅影响性能,另一些则针对 Atari 平台特性------这些特性会导致学习过程漫长且不稳定。这些变换可以通过 Gymnasium 库的各类包装器实现,其中stable-baselines3 (SB3) 代码库是最常用的包装器之一。
1. SB3 代码库
SB3 包含大量基于 PyTorch 实现的 RL 方法,旨在成为比较各类算法的统一基准。目前,我们并不关心这些方法的实现(后续我们将自行重写大部分算法),但其中部分封装器极具实用价值。该代码库可在https://github.com/DLR-RM/stable-baselines3获取,可以使用 pip install stable-baselines3 进行安装。常用的 Atari 变换包括:
- 将游戏中的每条生命转为独立回合:通常一个回合包含从游戏开始到 "Game Over" 画面的所有步骤,可能长达数千个游戏步(观察与动作)。街机游戏通常给予玩家多条生命,这种变换将完整回合按生命数拆分为多个小回合。其内部通过检测模拟器的剩余生命数实现(尽管并非所有游戏都支持此功能,但
Pong可以)。这种方式能加速收敛,因为回合长度被缩短。SB3代码中的EpisodicLifeEnv包装器实现了该逻辑 - 游戏开始时执行随机次(最多
30次)空操作 ("no-op"):用于跳过一些Atari游戏中与玩法无关的片头介绍画面,由NoopResetEnv包装器实现 - 每
K(通常为3或4)步做一次动作决策:在中间帧重复执行选定动作。由于神经网络处理每帧的计算量很大,而连续帧差异通常微小,这种方式能显著加速训练。MaxAndSkipEnv包装器实现了此功能,该封装器还包含了下一个变换(即两帧间取最大值) - 对最近两帧的每个像素取最大值作为观测:一些
Atari游戏因平台限制存在画面闪烁现象。这种快速变化对人眼不可见,但会干扰神经网络判断 - 游戏开始时按下
FIRE键:某些游戏(包括Pong和Breakout)要求用户按下FIRE键才能开始游戏。若不执行此操作,环境就会变成部分可观测马尔可夫决策过程 (Partially Observable Markov Decision Process,POMDP),因为智能体无法通过观测判断是否已按下FIRE键。该功能由FireResetEnv包装器实现 - 将每帧图像从
210×160三色帧压缩为84×84单色图:有多种方法可以实现。例如DeepMind论文描述的方法是提取YCbCr色彩空间的Y通道,再将完整图像缩放至84×84分辨率。也可以先进行灰度转换,裁剪图像无关区域后再缩放。SB3代码库采用后一种方案,通过WarpFrame包装器实现 - 堆叠连续多帧(通常为四帧)以提供游戏动态信息:用于弥补单个游戏帧中缺乏动态信息的问题。
SB3项目未提供现成包装器,我们在wrappers.BufferWrapper中实现了自定义版本 - 将奖励值裁剪为
-1/0/1:不同游戏的得分差异悬殊。例如Pong每让球越过对手球拍得1分,而KungFuMaster每击败一个敌人可得100分。奖励值的巨大差异会导致损失函数在不同游戏间尺度迥异,难以找到通用超参数。通过ClipRewardEnv包装器将奖励值限制在-1到1范围内可解决此问题 - 调整观测维度以适应
PyTorch卷积层:由于要使用卷积运算,需按PyTorch要求重组张量维度。Atari环境返回(height, width, color)格式的观测,而PyTorch卷积层要求通道维度在前。该功能由wrappers.ImageToPyTorch实现
stable-baseline3 库已实现大部分包装器,其提供的 AtariWrapper 类能根据构造函数参数按需应用这些包装器,并能自动检测环境特性决定是否启用 FireResetEnv。Pong 游戏虽不需要所有包装器,但了解这些包装器对解决其他游戏很有必要。有时深度Q网络 (Deep Q-learning, DQN)无法收敛的问题并非代码错误,而是错误包装环境所致。
2. stable-baseline3 包装器
(1) 接下来,我们具体分析 stable-baseline3 提供的包装器实现:
python
from stable_baselines3.common.type_aliases import AtariResetReturn, AtariStepReturn
class FireResetEnv(gym.Wrapper[np.ndarray, int, np.ndarray, int]):
def __init__(self, env: gym.Env) -> None:
super().__init__(env)
assert env.unwrapped.get_action_meanings()[1] == "FIRE" # type: ignore[attr-defined]
assert len(env.unwrapped.get_action_meanings()) >= 3 # type: ignore[attr-defined]
def reset(self, **kwargs) -> AtariResetReturn:
self.env.reset(**kwargs)
obs, _, terminated, truncated, _ = self.env.step(1)
if terminated or truncated:
self.env.reset(**kwargs)
obs, _, terminated, truncated, _ = self.env.step(2)
if terminated or truncated:
self.env.reset(**kwargs)
return obs, {}
上述包装器会在需要按 FIRE 键启动的游戏环境中自动执行该操作。除了触发FIRE键外,该包装器还会处理某些游戏中存在的特殊边界情况。
(2) MaxAndSkipEnv 包装器整合了动作重复执行 K 帧( K 帧动作重复)与连续两帧像素取最大值(双帧像素取最大)的功能:
python
class MaxAndSkipEnv(gym.Wrapper[np.ndarray, int, np.ndarray, int]):
def __init__(self, env: gym.Env, skip: int = 4) -> None:
super().__init__(env)
# most recent raw observations (for max pooling across time steps)
assert env.observation_space.dtype is not None, "No dtype specified for the observation space"
assert env.observation_space.shape is not None, "No shape defined for the observation space"
self._obs_buffer = np.zeros((2, *env.observation_space.shape), dtype=env.observation_space.dtype)
self._skip = skip
def step(self, action: int) -> AtariStepReturn:
total_reward = 0.0
terminated = truncated = False
for i in range(self._skip):
obs, reward, terminated, truncated, info = self.env.step(action)
done = terminated or truncated
if i == self._skip - 2:
self._obs_buffer[0] = obs
if i == self._skip - 1:
self._obs_buffer[1] = obs
total_reward += float(reward)
if done:
break
# Note that the observation on the done=True frame
# doesn't matter
max_frame = self._obs_buffer.max(axis=0)
return max_frame, total_reward, terminated, truncated, info
(3) WarpFrame 包装器的目标是将模拟器输出的 210×160 像素 RGB 彩色观测帧,转换为 84×84 灰度图像。其通过调用OpenCV 库的 cvtColor 函数实现------该函数采用符合人眼感知的色度学灰度转换算法(比简单的颜色通道平均值更接近人类视觉感知),随后对图像进行尺寸调整:
python
class WarpFrame(gym.ObservationWrapper[np.ndarray, int, np.ndarray]):
def __init__(self, env: gym.Env, width: int = 84, height: int = 84) -> None:
super().__init__(env)
self.width = width
self.height = height
assert isinstance(env.observation_space, spaces.Box), f"Expected Box space, got {env.observation_space}"
self.observation_space = spaces.Box(
low=0,
high=255,
shape=(self.height, self.width, 1),
dtype=env.observation_space.dtype, # type: ignore[arg-type]
)
def observation(self, frame: np.ndarray) -> np.ndarray:
assert cv2 is not None, "OpenCV is not installed, you can do `pip install opencv-python`"
frame = cv2.cvtColor(frame, cv2.COLOR_RGB2GRAY)
frame = cv2.resize(frame, (self.width, self.height), interpolation=cv2.INTER_AREA)
return frame[:, :, None]
(3) 以上介绍的包装器均来自 stable-baseline3,我们可以在 stable_baselines3/common/atari_wrappers.py 中找到其他可用包装器的代码。接下来,查看 wrappers.py 中的两个包装器实现:
python
import typing as tt
import gymnasium as gym
from gymnasium import spaces
import collections
import numpy as np
from stable_baselines3.common import atari_wrappers
class BufferWrapper(gym.ObservationWrapper):
def __init__(self, env, n_steps):
super(BufferWrapper, self).__init__(env)
obs = env.observation_space
assert isinstance(obs, spaces.Box)
new_obs = gym.spaces.Box(
obs.low.repeat(n_steps, axis=0), obs.high.repeat(n_steps, axis=0),
dtype=obs.dtype)
self.observation_space = new_obs
self.buffer = collections.deque(maxlen=n_steps)
def reset(self, *, seed: tt.Optional[int] = None, options: tt.Optional[dict[str, tt.Any]] = None):
for _ in range(self.buffer.maxlen-1):
self.buffer.append(self.env.observation_space.low)
obs, extra = self.env.reset()
return self.observation(obs), extra
def observation(self, observation: np.ndarray) -> np.ndarray:
self.buffer.append(observation)
return np.concatenate(self.buffer)
BufferWrapper 类通过 deque 队列沿着第一维度堆叠连续帧,并将其作为观测返回。其核心目的是让神经网络感知物体动态特征,例如 Pong 中球的运动速度和方向,或是敌人的移动轨迹。这些关键动态信息无法从单帧图像中获取。
BufferWrapper 包装器的 observation 方法返回的是缓冲观测的副本,由于这些观测将被存入经验回放池,副本机制能避免后续环境步骤修改缓冲数据。理论上可以通过保存每回合的观测数据及索引来避免复制(从而将内存占用降低至四分之一),但这会大幅增加数据结构管理的复杂度。当前需要特别注意的是,该包装器必须置于环境包装器链的末端。
(4) ImageToPyTorch 包装器会将观测张量从 (height, width, channel) 的 HWC 格式,转换为 PyTorch 要求的 (channel, height, width) 的 CHW 格式:
python
class ImageToPyTorch(gym.ObservationWrapper):
def __init__(self, env):
super(ImageToPyTorch, self).__init__(env)
obs = self.observation_space
assert isinstance(obs, gym.spaces.Box)
assert len(obs.shape) == 3
new_shape = (obs.shape[-1], obs.shape[0], obs.shape[1])
self.observation_space = gym.spaces.Box(
low=obs.low.min(), high=obs.high.max(),
shape=new_shape, dtype=obs.dtype)
def observation(self, observation):
return np.moveaxis(observation, 2, 0)
原始张量的色彩通道位于最后一维,但 PyTorch 卷积层要求色彩通道置于第一维。
(5) make_env() 函数能根据环境名称创建对应环境,并自动加载所有必要包装器:
python
def make_env(env_name: str, **kwargs):
env = gym.make(env_name, **kwargs)
env = atari_wrappers.AtariWrapper(env, clip_reward=False, noop_max=0)
env = ImageToPyTorch(env)
env = BufferWrapper(env, n_steps=4)
return env
我们使用了 stable-baseline3 中的 AtariWrapper 类,并禁用了部分非必要封装器。
相关链接
PyTorch强化学习实战(1)------强化学习(Reinforcement Learning,RL)详解
PyTorch强化学习实战(2)------强化学习环境库Gymnasium
PyTorch强化学习实战(3)------Gymnasium API扩展功能
PyTorch强化学习实战(4)------PyTorch基础
PyTorch强化学习实战(5)------PyTorch Ignite 事件驱动机制与实践
PyTorch强化学习实战(6)------交叉熵方法详解与实现
PyTorch强化学习实战(7)------表格学习与贝尔曼方程
PyTorch强化学习实战(8)------Q学习详解与实现
PyTorch强化学习实战(9)------深度Q学习