PyTorch强化学习实战——Atari游戏包装器

PyTorch强化学习实战------Atari游戏包装器

    • [0. 前言](#0. 前言)
    • [1. SB3 代码库](#1. SB3 代码库)
    • [2. stable-baseline3 包装器](#2. stable-baseline3 包装器)
    • 相关链接

0. 前言

从资源角度来看,用强化学习 (Reinforcement Learning, RL)处理Atari 游戏颇具挑战性。为提升效率,需要对 Atari 游戏交互应用多种变换。其中部分变换仅影响性能,另一些则针对 Atari 平台特性------这些特性会导致学习过程漫长且不稳定。这些变换可以通过 Gymnasium 库的各类包装器实现,其中stable-baselines3 (SB3) 代码库是最常用的包装器之一。

1. SB3 代码库

SB3 包含大量基于 PyTorch 实现的 RL 方法,旨在成为比较各类算法的统一基准。目前,我们并不关心这些方法的实现(后续我们将自行重写大部分算法),但其中部分封装器极具实用价值。该代码库可在https://github.com/DLR-RM/stable-baselines3获取,可以使用 pip install stable-baselines3 进行安装。常用的 Atari 变换包括:

  • 将游戏中的每条生命转为独立回合:通常一个回合包含从游戏开始到 "Game Over" 画面的所有步骤,可能长达数千个游戏步(观察与动作)。街机游戏通常给予玩家多条生命,这种变换将完整回合按生命数拆分为多个小回合。其内部通过检测模拟器的剩余生命数实现(尽管并非所有游戏都支持此功能,但 Pong 可以)。这种方式能加速收敛,因为回合长度被缩短。SB3 代码中的 EpisodicLifeEnv 包装器实现了该逻辑
  • 游戏开始时执行随机次(最多 30 次)空操作 ("no-op"):用于跳过一些 Atari 游戏中与玩法无关的片头介绍画面,由 NoopResetEnv 包装器实现
  • K (通常为 34 )步做一次动作决策:在中间帧重复执行选定动作。由于神经网络处理每帧的计算量很大,而连续帧差异通常微小,这种方式能显著加速训练。MaxAndSkipEnv 包装器实现了此功能,该封装器还包含了下一个变换(即两帧间取最大值)
  • 对最近两帧的每个像素取最大值作为观测:一些 Atari 游戏因平台限制存在画面闪烁现象。这种快速变化对人眼不可见,但会干扰神经网络判断
  • 游戏开始时按下 FIRE 键:某些游戏(包括 PongBreakout )要求用户按下 FIRE 键才能开始游戏。若不执行此操作,环境就会变成部分可观测马尔可夫决策过程 (Partially Observable Markov Decision Process, POMDP),因为智能体无法通过观测判断是否已按下 FIRE 键。该功能由 FireResetEnv 包装器实现
  • 将每帧图像从 210×160 三色帧压缩为 84×84 单色图:有多种方法可以实现。例如 DeepMind 论文描述的方法是提取 YCbCr 色彩空间的 Y 通道,再将完整图像缩放至 84×84 分辨率。也可以先进行灰度转换,裁剪图像无关区域后再缩放。SB3 代码库采用后一种方案,通过 WarpFrame 包装器实现
  • 堆叠连续多帧(通常为四帧)以提供游戏动态信息:用于弥补单个游戏帧中缺乏动态信息的问题。SB3 项目未提供现成包装器,我们在 wrappers.BufferWrapper 中实现了自定义版本
  • 将奖励值裁剪为 -1 / 0 / 1:不同游戏的得分差异悬殊。例如 Pong 每让球越过对手球拍得 1 分,而 KungFuMaster 每击败一个敌人可得 100 分。奖励值的巨大差异会导致损失函数在不同游戏间尺度迥异,难以找到通用超参数。通过 ClipRewardEnv 包装器将奖励值限制在 -11 范围内可解决此问题
  • 调整观测维度以适应 PyTorch 卷积层:由于要使用卷积运算,需按 PyTorch 要求重组张量维度。Atari 环境返回 (height, width, color) 格式的观测,而 PyTorch 卷积层要求通道维度在前。该功能由 wrappers.ImageToPyTorch 实现

stable-baseline3 库已实现大部分包装器,其提供的 AtariWrapper 类能根据构造函数参数按需应用这些包装器,并能自动检测环境特性决定是否启用 FireResetEnvPong 游戏虽不需要所有包装器,但了解这些包装器对解决其他游戏很有必要。有时深度Q网络 (Deep Q-learning, DQN)无法收敛的问题并非代码错误,而是错误包装环境所致。

2. stable-baseline3 包装器

(1) 接下来,我们具体分析 stable-baseline3 提供的包装器实现:

python 复制代码
from stable_baselines3.common.type_aliases import AtariResetReturn, AtariStepReturn

class FireResetEnv(gym.Wrapper[np.ndarray, int, np.ndarray, int]):
    def __init__(self, env: gym.Env) -> None:
        super().__init__(env)
        assert env.unwrapped.get_action_meanings()[1] == "FIRE"  # type: ignore[attr-defined]
        assert len(env.unwrapped.get_action_meanings()) >= 3  # type: ignore[attr-defined]

    def reset(self, **kwargs) -> AtariResetReturn:
        self.env.reset(**kwargs)
        obs, _, terminated, truncated, _ = self.env.step(1)
        if terminated or truncated:
            self.env.reset(**kwargs)
        obs, _, terminated, truncated, _ = self.env.step(2)
        if terminated or truncated:
            self.env.reset(**kwargs)
        return obs, {}

上述包装器会在需要按 FIRE 键启动的游戏环境中自动执行该操作。除了触发FIRE键外,该包装器还会处理某些游戏中存在的特殊边界情况。

(2) MaxAndSkipEnv 包装器整合了动作重复执行 K 帧( K 帧动作重复)与连续两帧像素取最大值(双帧像素取最大)的功能:

python 复制代码
class MaxAndSkipEnv(gym.Wrapper[np.ndarray, int, np.ndarray, int]):
    def __init__(self, env: gym.Env, skip: int = 4) -> None:
        super().__init__(env)
        # most recent raw observations (for max pooling across time steps)
        assert env.observation_space.dtype is not None, "No dtype specified for the observation space"
        assert env.observation_space.shape is not None, "No shape defined for the observation space"
        self._obs_buffer = np.zeros((2, *env.observation_space.shape), dtype=env.observation_space.dtype)
        self._skip = skip

    def step(self, action: int) -> AtariStepReturn:
        total_reward = 0.0
        terminated = truncated = False
        for i in range(self._skip):
            obs, reward, terminated, truncated, info = self.env.step(action)
            done = terminated or truncated
            if i == self._skip - 2:
                self._obs_buffer[0] = obs
            if i == self._skip - 1:
                self._obs_buffer[1] = obs
            total_reward += float(reward)
            if done:
                break
        # Note that the observation on the done=True frame
        # doesn't matter
        max_frame = self._obs_buffer.max(axis=0)

        return max_frame, total_reward, terminated, truncated, info

(3) WarpFrame 包装器的目标是将模拟器输出的 210×160 像素 RGB 彩色观测帧,转换为 84×84 灰度图像。其通过调用OpenCV 库cvtColor 函数实现------该函数采用符合人眼感知的色度学灰度转换算法(比简单的颜色通道平均值更接近人类视觉感知),随后对图像进行尺寸调整:

python 复制代码
class WarpFrame(gym.ObservationWrapper[np.ndarray, int, np.ndarray]):
    def __init__(self, env: gym.Env, width: int = 84, height: int = 84) -> None:
        super().__init__(env)
        self.width = width
        self.height = height
        assert isinstance(env.observation_space, spaces.Box), f"Expected Box space, got {env.observation_space}"

        self.observation_space = spaces.Box(
            low=0,
            high=255,
            shape=(self.height, self.width, 1),
            dtype=env.observation_space.dtype,  # type: ignore[arg-type]
        )

    def observation(self, frame: np.ndarray) -> np.ndarray:
        assert cv2 is not None, "OpenCV is not installed, you can do `pip install opencv-python`"
        frame = cv2.cvtColor(frame, cv2.COLOR_RGB2GRAY)
        frame = cv2.resize(frame, (self.width, self.height), interpolation=cv2.INTER_AREA)
        return frame[:, :, None]

(3) 以上介绍的包装器均来自 stable-baseline3,我们可以在 stable_baselines3/common/atari_wrappers.py 中找到其他可用包装器的代码。接下来,查看 wrappers.py 中的两个包装器实现:

python 复制代码
import typing as tt
import gymnasium as gym
from gymnasium import spaces
import collections
import numpy as np
from stable_baselines3.common import atari_wrappers

class BufferWrapper(gym.ObservationWrapper):
    def __init__(self, env, n_steps):
        super(BufferWrapper, self).__init__(env)
        obs = env.observation_space
        assert isinstance(obs, spaces.Box)
        new_obs = gym.spaces.Box(
            obs.low.repeat(n_steps, axis=0), obs.high.repeat(n_steps, axis=0),
            dtype=obs.dtype)
        self.observation_space = new_obs
        self.buffer = collections.deque(maxlen=n_steps)

    def reset(self, *, seed: tt.Optional[int] = None, options: tt.Optional[dict[str, tt.Any]] = None):
        for _ in range(self.buffer.maxlen-1):
            self.buffer.append(self.env.observation_space.low)
        obs, extra = self.env.reset()
        return self.observation(obs), extra

    def observation(self, observation: np.ndarray) -> np.ndarray:
        self.buffer.append(observation)
        return np.concatenate(self.buffer)

BufferWrapper 类通过 deque 队列沿着第一维度堆叠连续帧,并将其作为观测返回。其核心目的是让神经网络感知物体动态特征,例如 Pong 中球的运动速度和方向,或是敌人的移动轨迹。这些关键动态信息无法从单帧图像中获取。
BufferWrapper 包装器的 observation 方法返回的是缓冲观测的副本,由于这些观测将被存入经验回放池,副本机制能避免后续环境步骤修改缓冲数据。理论上可以通过保存每回合的观测数据及索引来避免复制(从而将内存占用降低至四分之一),但这会大幅增加数据结构管理的复杂度。当前需要特别注意的是,该包装器必须置于环境包装器链的末端。

(4) ImageToPyTorch 包装器会将观测张量从 (height, width, channel)HWC 格式,转换为 PyTorch 要求的 (channel, height, width)CHW 格式:

python 复制代码
class ImageToPyTorch(gym.ObservationWrapper):
    def __init__(self, env):
        super(ImageToPyTorch, self).__init__(env)
        obs = self.observation_space
        assert isinstance(obs, gym.spaces.Box)
        assert len(obs.shape) == 3
        new_shape = (obs.shape[-1], obs.shape[0], obs.shape[1])
        self.observation_space = gym.spaces.Box(
            low=obs.low.min(), high=obs.high.max(),
            shape=new_shape, dtype=obs.dtype)

    def observation(self, observation):
        return np.moveaxis(observation, 2, 0)

原始张量的色彩通道位于最后一维,但 PyTorch 卷积层要求色彩通道置于第一维。

(5) make_env() 函数能根据环境名称创建对应环境,并自动加载所有必要包装器:

python 复制代码
def make_env(env_name: str, **kwargs):
    env = gym.make(env_name, **kwargs)
    env = atari_wrappers.AtariWrapper(env, clip_reward=False, noop_max=0)
    env = ImageToPyTorch(env)
    env = BufferWrapper(env, n_steps=4)
    return env

我们使用了 stable-baseline3 中的 AtariWrapper 类,并禁用了部分非必要封装器。

相关链接

PyTorch强化学习实战(1)------强化学习(Reinforcement Learning,RL)详解
PyTorch强化学习实战(2)------强化学习环境库Gymnasium
PyTorch强化学习实战(3)------Gymnasium API扩展功能
PyTorch强化学习实战(4)------PyTorch基础
PyTorch强化学习实战(5)------PyTorch Ignite 事件驱动机制与实践
PyTorch强化学习实战(6)------交叉熵方法详解与实现
PyTorch强化学习实战(7)------表格学习与贝尔曼方程
PyTorch强化学习实战(8)------Q学习详解与实现
PyTorch强化学习实战(9)------深度Q学习

相关推荐
viperrrrrrrrrr74 小时前
强化学习入门笔记
人工智能·强化学习
子榆.4 小时前
CANN PyTorch适配器深度拆解:从.cuda()到.npu()到底发生了什么
人工智能·pytorch·python
renke33644 小时前
写给前端的 CANN-torchtitan-npu:昇腾PyTorch Titan适配到底是啥?
前端·人工智能·pytorch·cann
多年小白4 小时前
今日A股 拉
大数据·人工智能·深度学习·microsoft·ai
初心未改HD14 小时前
深度学习之CNN卷积层详解
人工智能·深度学习·cnn
AI医影跨模态组学14 小时前
EBioMedicine美国佐治亚理工学院与埃默里大学:基于深度学习的放射组学与病理学多模态融合预测HPV相关口咽鳞状细胞癌预后
人工智能·深度学习·论文·医学·医学影像·影像组学
人工智能培训15 小时前
大模型与传统小模型、传统NLP模型的核心差异解析
人工智能·深度学习·神经网络·机器学习·生成对抗网络
Terrence Shen18 小时前
大模型部署工具对比
人工智能·深度学习·计算机视觉