强化学习[page14]【chapter7】Temporal-Difference Learning (TD learning)

时间差分学习是蒙特卡罗(MC)方法和动态规划(DP)方法的结合:

  • 与 MC 方法类似,TD 方法可以从经验中学习,而无需 环境动力学模型。
  • 与 DP 方法类似,TD 方法在每一步之后都会根据其他学习到的估计值 更新估计值, 而无需等待结果 (这被称为引导)。

TD 方法的一个特点是,它在每个时间步更新其值估计,而 MC 方法则要等到回合结束才更新。

实际上,这两种方法的更新目标不同 。MC 方法旨在更新收益Gt,而 Gt 仅在每轮迭代结束时可用。相比之下,TD 方法的目标是:

其中V真实值函数 Vπ 的估计值

因此,TD 方法结合了 MC 的抽样 通过使用真实值的估计)和DP自助法(通过基于依赖于进一步估计的估计来更新 V)。

最后给出一个TD0 实现世界环境探索的例子


目录

  1. TD 算法例子
  2. TD 收敛性定理
  3. TD 和 MC 的比较
  4. TD0 Python 例子(TD0 只有策略评估)

TD 算法例子

如下机器人走迷宫的例子:

环境变量:初始状态值

1: 当Robot 从 移动到,我们有

TD target :

TD error:

state value

2: 当robot 从 移动到,我们有

TD target :

TD error:

state value

3: 按照上面的步骤反复迭代更新state value


二 TD 收敛性定理

2.1 说明

TD 算法实质上是求解给定策略的Bellman 公式

其中G 是 discount return

其中代表 next state,公式(4)可以写作

公式5也称为 Bellman equation , 也称为 Bellman expectation equation

前面一章知道该公式可以RM算法求解

2.2 定理解读

  • 结论 :在满足特定条件下,TD算法(公式7.1,本质上是TD(0))几乎必然(almost surely) 收敛到真实的状态价值函数。这意味着随着时间推移,我们的估计将以概率1无限逼近真值。(证明略)

  • 关键条件:收敛需要两个关于学习率 αt​(s) 的经典条件:

    :保证有"足够长"的学习时间,最终能克服初始误差和随机性。:保证学习步长最终会变得足够小,以平息更新中的随机噪声,使估计能稳定下来。

  • 与状态的关联 :注意学习率是与状态相关 。条件要求对每个状态 s 都成立,这引出了下一个关键点。

这部分连接了理论条件和实际应用。

  • "无限次访问"要求 :因为 αt(s)>0 仅当状态 s 在时刻 t 被访问,所以条件 等价于要求 状态 s 在无限的时间步中被访问无限次

    • 理论保障 :这解释了为什么强化学习需要充分探索 。无论通过 探索性起点 还是采用随机性策略(如 ε-greedy),目标都是确保所有状态能被无限次访问,避免因"没见过"而无法学习。
  • 常数学习率的妥协 :在实践中,常使用小的常数学习率 α ,但这违反了的条件(因为常数平方和发散)。

    • 理论与实践的桥梁 :文本指出,即使使用常数学习率,算法仍能收敛,但收敛形式变为 "在期望意义上收敛"。这意味着估计值会在真实值附近波动,而波动的期望为零。这是一种更弱但更实用的收敛保证。

三 TD 和 MC 的比较

1. Incremental vs. Non-incremental:
  • TD(Incremental ): 在每一步 之后都可以立即更新价值估计。这使它适用于在线学习,即边交互边学习,比如自动驾驶汽车的实时学习。

  • MC(非增量): 必须等待整个episode(从开始到终止状态的一次完整经历)结束后,得到累计回报 Gt,才能进行更新。这意味着学习是离线的、批处理的。

2. Continuing tasks vs Episodic tasks:
  • TD(适用于Continuing tasks): 因其增量性,TD方法不依赖于"Episodic end "的概念。对于没有明确终点的连续任务(如长期运行的机器人控制),TD可以通过持续更新来学习。

  • MC(Episodic tasks:): 其核心依赖于计算从某个状态到Episodic end 的总回报。如果没有Episodic end ,Gt 就无法定义,因此MC无法直接应用于连续任务。

3. Bootstrapping vs Non-Bootstrapping:
  • TD(自举): TD的更新目标(如)依赖于当前的价值估计值。这就像"拽着自己的鞋带把自己拉起来"------用现有(可能不准确)的估计来更新自己。

    • 优点: 可以更快地利用已有知识,传播信息效率高。

    • 缺点: **对初始值敏感(**如果初始猜测很差,需要时间纠正),且最终估计是有偏的。

  • MC(非自举): MC的更新目标 Gt 是完全通过实际采样得到的回报,不依赖于任何价值函数的初始估计。

    • 优点: 估计最终是无偏的,收敛到的解不受初始值影响。

    • 缺点: 必须等幕结束,且初期信息传播慢。

4. Low estimation variance vs High estimation variance:

这是最关键、最实际的差异之一。

  • TD(Low estimation variance ): TD的更新只涉及少量随机变量。以Sarsa为例,更新 q(St, At) 只依赖于下一个即时奖励 Rt+1、下一个状态 St+1 和下一个动作 At+1。这些随机性的来源较少,因此单次更新的波动较小,学习过程更平稳。

  • MC(High estimation variance): MC的更新目标 Gt 是未来一长串随机变量(奖励、状态转移、动作选择)的累加和。每一个环节的随机性都会累积到最终回报中。如文中所述,一个长度为 L 的幕,在软策略下有 |A|^L 种可能路径。仅用少数几条路径的回报来估计价值,其波动性(方差)必然会非常高,导致学习不稳定、收敛慢。


TD0 Python 例子(TD0 只有策略评估)

本实验将使用一个网格世界环境,该环境具有以下特征:

  • 网格为 12×8 个单元格。
  • 智能体从网格的左下角开始,目标是到达位于右上角的宝藏(奖励为 1 的终止状态)。
  • 蓝色传送门是连通的,穿过位于单元格(6, 10)的传送门 会到达单元格(0, 11)。智能体在第一次传送后不能再次穿过该传送门。
  • 紫色传送门 仅在100集后出现,但可以让特工更快地找到宝藏。这鼓励玩家不断探索环境。
  • 红色传送门是陷阱(奖励为 0 的终止状态),会结束本集。
  • 撞到墙壁会导致智能体保持原状。

#配色方案:

开始位置 - 金色标记(珍贵起点) RGB: (255, 215, 0)

目标位置 - 安全绿(安心到达) RGB: (50, 205, 50)

陷阱位置 - 警示红(危险警告) RGB: (220, 20, 60)

蓝色传送门 - 科技蓝(稳定通道) RGB: (0, 191, 255)

紫色传送门 - 神秘紫(随机通道) RGB: (147, 112, 219)

墙壁 - 暗灰石墙(环境屏障) RGB: (105, 105, 105)

动作定义

环境说明

复制代码
# -*- coding: utf-8 -*-
"""
基于TD(0)算法的网格世界强化学习实现
Grid World Reinforcement Learning using TD(0) Algorithm

作者: chengxf2
日期: 2025年12月
描述: 实现TD(0)算法在复杂网格世界环境中的策略评估
"""

import numpy as np
import matplotlib.pyplot as plt
import matplotlib.patches as patches
from typing import Tuple, List, Dict, Any, Optional
import random
from enum import IntEnum


class Action(IntEnum):
    """动作枚举类"""
    UP = 0
    DOWN = 1
    LEFT = 2
    RIGHT = 3

# 颜色定义常量
COLORS = {
        'start': (255, 215, 0),      # 金色 - 开始位置
        'goal': (50, 205, 50),       # 安全绿 - 目标位置
        'trap': (220, 20, 60),       # 警示红 - 陷阱位置
        'blue_portal': (0, 191, 255),  # 科技蓝 - 蓝色传送门
        'purple_portal': (147, 112, 219),  # 神秘紫 - 紫色传送门
        'wall': (105, 105, 105),     # 暗灰石墙 - 墙壁
}
    


class GridWorldEnvironment:
    """
    网格世界环境类
    
    8×12网格,包含传送门、陷阱和动态元素
    坐标系: (行, 列) 或 (x, y),其中x从上到下,y从左到右
    """
    

    def __init__(self) -> None:
        """初始化网格世界环境参数"""
        # 网格尺寸
        self.rows: int = 8
        self.cols: int = 12
        
        # 特殊位置定义
        self.start_position: Tuple[int, int] = (7, 0)  # 开始位置
        self.goal_position: Tuple[int, int] = (2, 9)   # 目标位置
        
        # 陷阱位置
        self.trap_positions: List[Tuple[int, int]] = [
            (0, 8), (1, 8), (2, 8), (3, 8),
            (3, 9), (3, 10), (3, 11)
        ]
        
        # 蓝色传送门
        self.blue_portal_entrance: Tuple[int, int] = (6, 10)
        self.blue_portal_exit: Tuple[int, int] = (0, 11)
        self.is_blue_portal_used: bool = False
        
        # 紫色传送门(100轮后激活)
        self.is_purple_portal_active: bool = False
        self.purple_portal_entrance: Tuple[int, int] = (2, 1)
        self.purple_portal_exit: List[Tuple[int, int]] = [
            (0, 9), (0, 10), (0, 11),
            (1, 9), (1, 10), (1, 11),
            (2, 9), (2, 10), (2, 11)
        ]
        
        # 墙壁位置
        self.wall_positions: List[Tuple[int, int]] = [
            (5, 3), (6, 3), (7, 3),  # 第一组障碍物
            (3, 0), (3, 1), (3, 2),  # 第二组障碍物
        ]
        
        # 当前状态和环境统计
        self.current_state: Tuple[int, int] = self.start_position
        self.total_episodes: int = 0
        self.step_count: int = 0
        
        # 动作映射
        self.action_mapping = {
            Action.UP: self._move_up,
            Action.DOWN: self._move_down,
            Action.LEFT: self._move_left,
            Action.RIGHT: self._move_right,
        }
    
    def reset(self) -> Tuple[int, int]:
        """
        重置环境到初始状态
        
        Returns:
            Tuple[int, int]: 初始状态坐标
        """
        self.current_state = self.start_position
        self.is_blue_portal_used = False
        self.step_count = 0
        return self.current_state
    
    def step(self, action: Action) -> Tuple[Tuple[int, int], float, bool]:
        """
        执行动作并返回新的状态、奖励和终止标志
        
        Args:
            action: 动作枚举
            
        Returns:
            Tuple[状态, 奖励, 是否终止]
        """
        current_x, current_y = self.current_state
        self.step_count += 1
        
        # 计算移动后的新位置
        new_x, new_y = self._calculate_new_position(current_x, current_y, action)
        
        # 检查墙壁碰撞
        if (new_x, new_y) in self.wall_positions:
            new_x, new_y = current_x, current_y
        
        # 检查陷阱
        if (new_x, new_y) in self.trap_positions:
            self.current_state = (new_x, new_y)
            return self.current_state, -1.0, True
        
        # 检查目标
        if (new_x, new_y) == self.goal_position:
            self.current_state = (new_x, new_y)
            return self.current_state, 1.0, True
        
        # 检查传送门
        new_x, new_y = self._process_portals(new_x, new_y)
        
        # 更新当前状态
        self.current_state = (new_x, new_y)
        
        # 默认奖励为0(稀疏奖励设置)
        return self.current_state, 0.0, False
    
    def _calculate_new_position(self, x: int, y: int, action: Action) -> Tuple[int, int]:
        """
        计算移动后的新位置
        
        Args:
            x: 当前x坐标
            y: 当前y坐标
            action: 动作
            
        Returns:
            Tuple[int, int]: 新位置坐标
        """
        move_func = self.action_mapping.get(action)
        if move_func is None:
            raise ValueError(f"无效动作: {action}")
        return move_func(x, y)
    
    def _move_up(self, x: int, y: int) -> Tuple[int, int]:
        """向上移动"""
        new_x = max(x - 1, 0)
        return new_x, y
    
    def _move_down(self, x: int, y: int) -> Tuple[int, int]:
        """向下移动"""
        new_x = min(x + 1, self.rows - 1)
        return new_x, y
    
    def _move_left(self, x: int, y: int) -> Tuple[int, int]:
        """向左移动"""
        new_y = max(y - 1, 0)
        return x, new_y
    
    def _move_right(self, x: int, y: int) -> Tuple[int, int]:
        """向右移动"""
        new_y = min(y + 1, self.cols - 1)
        return x, new_y
    
    def _process_portals(self, x: int, y: int) -> Tuple[int, int]:
        """
        检查并处理传送门
        
        Args:
            x: 当前x坐标
            y: 当前y坐标
            
        Returns:
            Tuple[int, int]: 处理后的位置
        """
        # 蓝色传送门(仅可使用一次)
        if (not self.is_blue_portal_used) and ((x, y) == self.blue_portal_entrance):
            self.is_blue_portal_used = True
            return self.blue_portal_exit
        
        # 紫色传送门(100轮后激活)
        elif self.is_purple_portal_active and ((x, y) == self.purple_portal_entrance):
            return random.choice(self.purple_portal_exit)
        
        # 不是传送门
        return (x, y)
    
    def step_simulate(self, state: Tuple[int, int], action: Action) -> Tuple[Tuple[int, int], float, bool]:
        """
        模拟执行动作而不改变环境状态(用于规划)
        
        Args:
            state: 当前状态
            action: 动作
            
        Returns:
            Tuple[下一状态, 奖励, 是否终止]
        """
        x, y = state
        
        # 计算新位置
        new_x, new_y = self._calculate_new_position(x, y, action)
        
        # 检查墙壁
        if (new_x, new_y) in self.wall_positions:
            new_x, new_y = x, y
        
        # 检查陷阱
        if (new_x, new_y) in self.trap_positions:
            return (new_x, new_y), -1.0, True
        
        # 检查目标
        if (new_x, new_y) == self.goal_position:
            return (new_x, new_y), 1.0, True
        
        # 检查传送门
        if (not self.is_blue_portal_used) and ((new_x, new_y) == self.blue_portal_entrance):
            new_x, new_y = self.blue_portal_exit
        
        if self.is_purple_portal_active and ((new_x, new_y) == self.purple_portal_entrance):
            new_x, new_y = random.choice(self.purple_portal_exit)
        
        # 再次检查是否撞墙(传送后可能撞墙)
        if (new_x, new_y) in self.wall_positions:
            new_x, new_y = x, y
        
        return (new_x, new_y), 0.0, False
    
    def update_portal_activation(self, episode_count: int) -> None:
        """
        根据训练轮数更新传送门激活状态
        
        Args:
            episode_count: 当前训练轮数
        """
        self.is_purple_portal_active = (episode_count >= 100)
    
    def render(self, figsize: Tuple[int, int] = (10, 8), show_legend: bool = True) -> None:
        """
        可视化环境状态
        
        Args:
            figsize: 图形大小
            show_legend: 是否显示图例
        """
        self._setup_matplotlib_fonts()
        
        # 创建图形和坐标轴
        fig, ax = plt.subplots(figsize=figsize)
        
        # 设置坐标轴
        self._setup_axes(ax)
        
        # 绘制网格和各个元素
        self._draw_grid_background(ax)
        self._draw_walls(ax)
        self._draw_start_position(ax)
        self._draw_goal_position(ax)
        self._draw_traps(ax)
        self._draw_blue_portal(ax)
        self._draw_purple_portal(ax)
        self._draw_agent(ax)
        self._draw_grid_coordinates(ax)
        
        # 添加标题和图例
        self._add_title_and_legend(ax, show_legend)
        
        plt.tight_layout()
        plt.show()
        
        # 打印文本版本
        self._print_text_version()
    
    def _setup_matplotlib_fonts(self) -> None:
        """设置matplotlib字体"""
        try:
            plt.rcParams['font.sans-serif'] = ['SimHei', 'Microsoft YaHei', 'DejaVu Sans']
        except Exception:
            pass
        plt.rcParams['axes.unicode_minus'] = False
    
    def _setup_axes(self, ax) -> None:
        """设置坐标轴"""
        ax.set_xlim(-0.5, self.cols - 0.5)
        ax.set_ylim(-0.5, self.rows - 0.5)
        ax.set_aspect('equal')
        ax.set_xticks([])
        ax.set_yticks([])
        ax.invert_yaxis()  # 让y轴向下为正
    
    def _draw_grid_background(self, ax) -> None:
        """绘制网格背景"""
        for x in range(self.rows):
            for y in range(self.cols):
                rect = patches.Rectangle(
                    (y - 0.5, x - 0.5), 1, 1,
                    linewidth=1,
                    edgecolor=(0.78, 0.78, 0.78),  # 浅灰色
                    facecolor=(0.94, 0.94, 0.94),  # 更浅的灰色
                )
                ax.add_patch(rect)
    
    def _draw_walls(self, ax) -> None:
        """绘制墙壁"""
        for wall_x, wall_y in self.wall_positions:
            rect = patches.Rectangle(
                (wall_y - 0.5, wall_x - 0.5), 1, 1,
                linewidth=1,
                edgecolor=self._normalize_color(COLORS['wall']),
                facecolor=self._normalize_color(COLORS['wall']),
                hatch='////'
            )
            ax.add_patch(rect)
            ax.text(wall_y, wall_x, '█', ha='center', va='center', 
                   fontsize=14, color='white', fontweight='bold')
    
    def _draw_start_position(self, ax) -> None:
        """绘制开始位置"""
        start_x, start_y = self.start_position
        rect = patches.Rectangle(
            (start_y - 0.5, start_x - 0.5), 1, 1,
            linewidth=2,
            edgecolor='black',
            facecolor=self._normalize_color(COLORS['start'])
        )
        ax.add_patch(rect)
        ax.text(start_y, start_x, 'S', ha='center', va='center', 
               fontsize=12, color='black', fontweight='bold')
    
    def _draw_goal_position(self, ax) -> None:
        """绘制目标位置"""
        goal_x, goal_y = self.goal_position
        rect = patches.Rectangle(
            (goal_y - 0.5, goal_x - 0.5), 1, 1,
            linewidth=2,
            edgecolor='black',
            facecolor=self._normalize_color(COLORS['goal'])
        )
        ax.add_patch(rect)
        ax.text(goal_y, goal_x, 'G', ha='center', va='center', 
               fontsize=12, color='white', fontweight='bold')
    
    def _draw_traps(self, ax) -> None:
        """绘制陷阱位置"""
        for trap_x, trap_y in self.trap_positions:
            rect = patches.Rectangle(
                (trap_y - 0.5, trap_x - 0.5), 1, 1,
                linewidth=1,
                edgecolor=self._normalize_color(COLORS['trap']),
                facecolor=self._normalize_color(COLORS['trap'])
            )
            ax.add_patch(rect)
            ax.text(trap_y, trap_x, 'T', ha='center', va='center', 
                   fontsize=10, color='white', fontweight='bold')
    
    def _draw_blue_portal(self, ax) -> None:
        """绘制蓝色传送门"""
        # 入口
        blue_x, blue_y = self.blue_portal_entrance
        rect = patches.Rectangle(
            (blue_y - 0.5, blue_x - 0.5), 1, 1,
            linewidth=2,
            edgecolor='black',
            facecolor=self._normalize_color(COLORS['blue_portal']),
            alpha=0.8
        )
        ax.add_patch(rect)
        ax.text(blue_y, blue_x, 'B', ha='center', va='center', 
               fontsize=12, color='white', fontweight='bold')
        
        # 出口
        blue_exit_x, blue_exit_y = self.blue_portal_exit
        rect = patches.Rectangle(
            (blue_exit_y - 0.5, blue_exit_x - 0.5), 1, 1,
            linewidth=2,
            edgecolor=self._normalize_color(COLORS['blue_portal']),
            facecolor='white',
            linestyle='--',
            alpha=0.6
        )
        ax.add_patch(rect)
        ax.text(blue_exit_y, blue_exit_x, 'BE', ha='center', va='center', 
               fontsize=10, color=self._normalize_color(COLORS['blue_portal']), 
               fontweight='bold')
    
    def _draw_purple_portal(self, ax) -> None:
        """绘制紫色传送门"""
        if not self.is_purple_portal_active:
            return
            
        purple_x, purple_y = self.purple_portal_entrance
        rect = patches.Rectangle(
            (purple_y - 0.5, purple_x - 0.5), 1, 1,
            linewidth=2,
            edgecolor='black',
            facecolor=self._normalize_color(COLORS['purple_portal']),
            alpha=0.8
        )
        ax.add_patch(rect)
        ax.text(purple_y, purple_x, 'P', ha='center', va='center', 
               fontsize=12, color='white', fontweight='bold')
        
        # 出口区域
        for exit_x, exit_y in self.purple_portal_exit:
            rect = patches.Rectangle(
                (exit_y - 0.5, exit_x - 0.5), 1, 1,
                linewidth=1,
                edgecolor=self._normalize_color(COLORS['purple_portal']),
                facecolor='white',
                linestyle='--',
                alpha=0.4
            )
            ax.add_patch(rect)
    
    def _draw_agent(self, ax) -> None:
        """绘制智能体"""
        agent_x, agent_y = self.current_state
        circle = patches.Circle(
            (agent_y, agent_x), 0.3,
            linewidth=2,
            edgecolor='black',
            facecolor=(1.0, 0.65, 0.0)  # 橙色
        )
        ax.add_patch(circle)
        ax.text(agent_y, agent_x, 'A', ha='center', va='center', 
               fontsize=10, color='black', fontweight='bold')
    
    def _draw_grid_coordinates(self, ax) -> None:
        """绘制网格坐标"""
        for x in range(self.rows):
            for y in range(self.cols):
                cell = (x, y)
                if self._is_empty_cell(cell):
                    ax.text(y, x, f'({x},{y})', ha='center', va='center', 
                           fontsize=6, color='black', alpha=0.5)
    
    def _add_title_and_legend(self, ax, show_legend: bool) -> None:
        """添加标题和图例"""
        title = f"网格世界环境\n轮数: {self.total_episodes}, 步数: {self.step_count}"
        ax.set_title(title, fontsize=14, fontweight='bold', pad=20)
        
        if show_legend:
            legend_elements = [
                patches.Patch(facecolor=self._normalize_color(COLORS['start']), 
                            edgecolor='black', label='起点 (Start)'),
                patches.Patch(facecolor=self._normalize_color(COLORS['goal']), 
                            edgecolor='black', label='目标 (Goal)'),
                patches.Patch(facecolor=self._normalize_color(COLORS['trap']), 
                            edgecolor='black', label='陷阱 (Trap)'),
                patches.Patch(facecolor=self._normalize_color(COLORS['blue_portal']), 
                            edgecolor='black', label='蓝色传送门'),
                patches.Patch(facecolor=self._normalize_color(COLORS['purple_portal']), 
                            edgecolor='black', label='紫色传送门'),
                patches.Patch(facecolor=self._normalize_color(COLORS['wall']), 
                            edgecolor='black', label='墙壁 (Wall)'),
                patches.Patch(facecolor=(1.0, 0.65, 0.0), 
                            edgecolor='black', label='智能体 (Agent)'),
            ]
            
            ax.legend(handles=legend_elements, 
                     loc='upper right', 
                     bbox_to_anchor=(1.15, 1),
                     fontsize=9,
                     title='图例说明',
                     title_fontsize=10)
        
        # 添加边框
        for spine in ax.spines.values():
            spine.set_linewidth(2)
            spine.set_color('black')
    
    def _print_text_version(self) -> None:
        """打印文本版本的环境状态"""
        print(f"\n当前环境状态 (轮数: {self.total_episodes}, 步数: {self.step_count})")
        
        if self.total_episodes >= 100:
            print("紫色传送门已激活")
        else:
            print("紫色传送门未激活")
    
    def _is_empty_cell(self, cell: Tuple[int, int]) -> bool:
        """判断是否为空白单元格"""
        x, y = cell
        return (cell not in self.wall_positions and 
                cell != self.start_position and 
                cell != self.goal_position and 
                cell not in self.trap_positions and
                cell != self.blue_portal_entrance and
                not (self.is_purple_portal_active and cell == self.purple_portal_entrance) and
                cell != self.current_state)
    
    @staticmethod
    def _normalize_color(rgb_tuple: Tuple[int, int, int]) -> Tuple[float, float, float]:
        """将RGB颜色从0-255范围归一化到0-1范围"""
        return tuple(c / 255.0 for c in rgb_tuple)
    
    def get_environment_summary(self) -> Dict[str, Any]:
        """获取环境摘要信息"""
        return {
            'dimensions': (self.rows, self.cols),
            'start_position': self.start_position,
            'goal_position': self.goal_position,
            'trap_count': len(self.trap_positions),
            'wall_count': len(self.wall_positions),
            'has_blue_portal': True,
            'has_purple_portal': True,
            'is_purple_portal_active': self.is_purple_portal_active
        }


class Policy:
    """策略基类"""
    
    def __init__(self, env: GridWorldEnvironment):
        self.env = env
    
    def get_action_probabilities(self, state: Tuple[int, int]) -> List[float]:
        """获取动作概率分布"""
        raise NotImplementedError
    
    def _normalize_probabilities(self, probs: List[float]) -> List[float]:
        """归一化概率分布,确保和为1"""
        if not probs:
            return [0.25, 0.25, 0.25, 0.25]
        
        total = sum(probs)
        
        # 如果概率和为0,返回均匀分布
        if total == 0:
            return [0.25, 0.25, 0.25, 0.25]
        
        # 归一化
        normalized = [p / total for p in probs]
        
        # 处理浮点误差
      
        
        return normalized
    
    def select_action(self, state: Tuple[int, int]) -> Action:
        """根据策略选择动作"""
        action_probs = self.get_action_probabilities(state)
        
        # 归一化概率
        normalized_probs = self._normalize_probabilities(action_probs)
        
        # 验证概率和是否为1(允许小的浮点误差)
        prob_sum = sum(normalized_probs)
        if abs(prob_sum - 1.0) > 1e-10:
    
            normalized_probs = [p / prob_sum for p in normalized_probs]
        
        return Action(np.random.choice(len(Action), p=normalized_probs))





class RightPreferentialPolicy(Policy):
    """向右偏好策略:向右动作概率最高"""
    
    def __init__(self, env: GridWorldEnvironment, right_prob: float = 0.7):
        super().__init__(env)
        if not 0 <= right_prob <= 1:
            raise ValueError(f"right_prob必须在0和1之间,当前为{right_prob}")
        self.right_prob = right_prob
        self.other_prob = (1.0 - right_prob) / 3.0
    
    def get_action_probabilities(self, state: Tuple[int, int]) -> List[float]:
        # 检查是否为终止状态
        if (state in self.env.wall_positions or
            state in self.env.trap_positions or
            state == self.env.goal_position):
            return [0.25, 0.25, 0.25, 0.25]
        
        # 检查每个动作是否可用
        action_availability = []
        for action in Action:
            next_state, _, _ = self.env.step_simulate(state, action)
            is_available = next_state not in self.env.wall_positions
            action_availability.append(is_available)
        
        # 如果没有可用动作
        if not any(action_availability):
            return [0.25, 0.25, 0.25, 0.25]
        
        # 构建概率分布
        probs = []
        for i, (action, is_available) in enumerate(zip(Action, action_availability)):
            if not is_available:
                probs.append(0.0)
            elif action == Action.RIGHT:
                probs.append(self.right_prob)
            else:
                probs.append(self.other_prob)
        
        return probs


class TD0PolicyEvaluator:
    """
    TD(0) 策略评估器
    
    评估给定策略π的状态价值函数 V(s)
    核心公式: V(s) ← V(s) + α [r + γV(s') - V(s)]
    """
    
    def __init__(self, 
                 env: GridWorldEnvironment,
                 policy: Policy,
                 learning_rate: float = 0.1,
                 discount_factor: float = 0.9) -> None:
        """
        初始化TD(0)策略评估器
        
        Args:
            env: 环境实例
            policy: 要评估的策略
            learning_rate: 学习率 (α)
            discount_factor: 折扣因子 (γ)
        """
        self.env = env
        self.policy = policy
        self.learning_rate = learning_rate
        self.discount_factor = discount_factor
        
        # 状态价值函数 V(s)
        self.state_values = np.zeros((env.rows, env.cols))
        
        # 初始化特殊位置的价值
        self._initialize_special_state_values()
        
        # 训练统计
        self.episode_count = 0
        self.total_steps = 0
        self.value_update_history: List[Dict[str, Any]] = []
        
        # 性能指标
        self.episode_rewards: List[float] = []
        self.episode_steps: List[int] = []
    
    def _initialize_special_state_values(self) -> None:
        """初始化特殊状态的价值"""
        # 墙壁位置为负值(避免前往)
        for wall_x, wall_y in self.env.wall_positions:
            self.state_values[wall_x, wall_y] = -1.0
        
        # 陷阱位置为负值
        for trap_x, trap_y in self.env.trap_positions:
            self.state_values[trap_x, trap_y] = -1.0
        
        # 目标位置为正值
        goal_x, goal_y = self.env.goal_position
        self.state_values[goal_x, goal_y] = 1.0
    
    def evaluate_episode(self, max_steps: int = 500) -> Tuple[float, int]:
        """
        执行一轮策略评估
        
        Args:
            max_steps: 最大步数
            
        Returns:
            Tuple[累计奖励, 步数]
        """
        # 重置环境
        state = self.env.reset()
        self.env.update_portal_activation(self.episode_count)
        
        total_reward = 0.0
        steps = 0
        
        while steps < max_steps:
            # 根据策略选择动作
            action = self.policy.select_action(state)
            
            # 执行动作
            next_state, reward, is_done = self.env.step(action)
            
            # 使用TD(0)更新状态价值函数
            self._update_state_value(state, reward, next_state)
            
            # 更新统计
            total_reward += reward
            steps += 1
            self.total_steps += 1
            
            # 转移到下一状态
            state = next_state
            
            if is_done:
                break
        
        # 更新训练轮数
        self.episode_count += 1
        self.episode_rewards.append(total_reward)
        self.episode_steps.append(steps)
        
        return total_reward, steps
    
    def _update_state_value(self, state: Tuple[int, int], 
                           reward: float, 
                           next_state: Tuple[int, int]) -> float:
        """
        使用TD(0)更新状态价值函数
        
        Args:
            state: 当前状态
            reward: 即时奖励
            next_state: 下一状态
            
        Returns:
            float: TD误差
        """
        x, y = state
        nx, ny = next_state
        
        # 计算TD误差
        td_error = reward + self.discount_factor * self.state_values[nx, ny] - self.state_values[x, y]
        
        # 更新状态价值(墙壁不更新)
        if (x, y) not in self.env.wall_positions:
            self.state_values[x, y] += self.learning_rate * td_error
        
        # 记录更新历史
        self.value_update_history.append({
            'episode': self.episode_count,
            'state': state,
            'td_error': td_error,
            'new_value': self.state_values[x, y]
        })
        
        return td_error
    
    def get_evaluation_results(self) -> Dict[str, Any]:
        """
        获取策略评估结果
        
        Returns:
            Dict: 包含状态价值函数和评估统计
        """
        return {
            'state_values': self.state_values.copy(),
            'episode_count': self.episode_count,
            'total_steps': self.total_steps,
            'avg_reward': np.mean(self.episode_rewards) if self.episode_rewards else 0.0,
            'avg_steps': np.mean(self.episode_steps) if self.episode_steps else 0.0,
            'success_rate': self._calculate_success_rate(),
        }
    
    def _calculate_success_rate(self) -> float:
        """计算成功率"""
        if not self.episode_rewards:
            return 0.0
        successful_episodes = sum(1 for reward in self.episode_rewards if reward > 0)
        return successful_episodes / len(self.episode_rewards)
    
    def visualize_state_values(self, 
                              cmap: str = 'RdYlGn', 
                              figsize: Tuple[int, int] = (12, 8)) -> None:
        """
        可视化状态价值函数
        
        Args:
            cmap: 颜色映射
            figsize: 图形大小
        """
        fig, (ax1, ax2) = plt.subplots(1, 2, figsize=figsize)
        
        # 左侧:热图
        self._plot_value_heatmap(ax1, cmap)
        
        # 右侧:数值表格
        self._plot_value_table(ax2)
        
        plt.suptitle(f'TD(0) 策略评估结果 (共{self.episode_count}轮)', fontsize=16, fontweight='bold')
        plt.tight_layout()
        plt.show()
    
    def _plot_value_heatmap(self, ax, cmap: str) -> None:
        """绘制价值热图"""
        # 创建掩码(隐藏墙壁)
        mask = np.zeros_like(self.state_values, dtype=bool)
        for wall_x, wall_y in self.env.wall_positions:
            mask[wall_x, wall_y] = True
        
        # 创建带掩码的数组
        display_values = np.ma.array(self.state_values, mask=mask)
        
        # 绘制热图
        im = ax.imshow(display_values, cmap=cmap, interpolation='nearest', aspect='auto')
        plt.colorbar(im, ax=ax, label='状态价值 V(s)')
        
        # 标记特殊位置
        self._mark_special_positions(ax)
        
        ax.set_title('状态价值热图')
        ax.set_xlabel('列 (y)')
        ax.set_ylabel('行 (x)')
        ax.set_xticks(range(self.env.cols))
        ax.set_yticks(range(self.env.rows))
    
    def _mark_special_positions(self, ax) -> None:
        """标记特殊位置"""
        # 开始位置
        start_x, start_y = self.env.start_position
        ax.scatter(start_y, start_x, color='gold', s=200, marker='*', 
                  edgecolor='black', linewidth=1, label='起点')
        
        # 目标位置
        goal_x, goal_y = self.env.goal_position
        ax.scatter(goal_y, goal_x, color='green', s=200, marker='*',
                  edgecolor='black', linewidth=1, label='目标')
        
        # 陷阱位置
        for trap_x, trap_y in self.env.trap_positions:
            ax.scatter(trap_y, trap_x, color='red', s=100, marker='x',
                      linewidth=2, label='陷阱' if trap_x == 0 and trap_y == 8 else "")
        
        ax.legend(loc='upper right')
    
    def _plot_value_table(self, ax) -> None:
        """绘制价值数值表格"""
        # 创建表格数据
        table_data = []
        for x in range(self.env.rows):
            row_data = []
            for y in range(self.env.cols):
                if (x, y) in self.env.wall_positions:
                    row_data.append('██')
                else:
                    row_data.append(f'{self.state_values[x, y]:.2f}')
            table_data.append(row_data)
        
        # 创建表格
        table = ax.table(cellText=table_data,
                        cellLoc='center',
                        loc='center',
                        colWidths=[0.1] * self.env.cols)
        
        # 设置表格样式
        table.auto_set_font_size(False)
        table.set_fontsize(9)
        table.scale(1, 1.5)
        
        # 使用颜色常量(使用matplotlib颜色名称或十六进制代码,避免除法运算)
        COLORS = {
            'start': 'gold',           # 金色 - 开始位置
            'goal': 'limegreen',       # 安全绿 - 目标位置  
            'trap': 'crimson',         # 警示红 - 陷阱位置
            'blue_portal': 'deepskyblue',  # 科技蓝 - 蓝色传送门
            'purple_portal': 'mediumpurple',  # 神秘紫 - 紫色传送门
            'wall': 'dimgray',         # 暗灰石墙 - 墙壁
            'positive': 'lightgreen',  # 浅绿 - 正值状态
            'negative': 'lightcoral',  # 浅红 - 负值状态
        }
        
        # 设置特殊单元格颜色
        for (x, y), cell in table.get_celld().items():
            if x == 0:  # 标题行
                continue
            
            actual_x, actual_y = x, y
            pos = (actual_x, actual_y)
            
            # 按照优先级设置颜色(后设置的会覆盖先设置的)
            if pos in self.env.wall_positions:
                cell.set_facecolor(COLORS['wall'])
                cell.set_text_props(weight='bold')
            elif pos == self.env.start_position:
                cell.set_facecolor(COLORS['start'])
            elif pos == self.env.goal_position:
                cell.set_facecolor(COLORS['goal'])
            elif pos in self.env.trap_positions:
                cell.set_facecolor(COLORS['trap'])
            elif self.state_values[actual_x, actual_y] > 0:
                # 对于有正价值的状态使用浅绿色
                cell.set_facecolor(COLORS['positive'])
            elif self.state_values[actual_x, actual_y] < 0:
                # 对于有负价值的状态使用浅红色
                cell.set_facecolor(COLORS['negative'])
        
        ax.set_title('状态价值表格')
        ax.axis('off')


class GridWorldExperiment:
    """
    网格世界实验管理器
    
    用于管理和执行策略评估实验
    """
    
    def __init__(self, 
                 env: GridWorldEnvironment,
                 policy: Optional[Policy] = None) -> None:
        """
        初始化实验
        
        Args:
            env: 环境实例
            policy: 策略实例,如果为None则使用随机策略
        """
        self.env = env
        
        # 使用默认策略或传入的策略
        
        self.policy = policy
        
        # 创建策略评估器
        self.evaluator = TD0PolicyEvaluator(env, self.policy)
        
        # 性能统计
        self.moving_avg_rewards: List[float] = []
        self.moving_avg_steps: List[int] = []
        self.success_rates: List[float] = []
    
    def run_evaluation(self, 
                      num_episodes: int = 1000,
                      progress_interval: int = 100) -> Dict[str, Any]:
        """
        运行策略评估实验
        
        Args:
            num_episodes: 评估轮数
            progress_interval: 进度打印间隔
            
        Returns:
            Dict: 评估结果
        """
        print(f"开始策略评估,总轮数: {num_episodes}")
        print("=" * 60)
        
        for episode in range(num_episodes):
            # 单轮评估
            reward, steps = self.evaluator.evaluate_episode()
            
            # 更新性能统计
            self._update_performance_metrics(episode, reward, steps)
            
            # 定期输出进度
            if (episode + 1) % progress_interval == 0:
                self._print_progress(episode + 1)
        
        print("\n策略评估完成!")
        print("=" * 60)
        
        return self._compile_results()
    
    def _update_performance_metrics(self, 
                                   episode: int, 
                                   reward: float, 
                                   steps: int) -> None:
        """更新性能指标"""
        window_size = 50
        
        # 滑动平均奖励
        if episode >= window_size - 1:
            recent_rewards = self.evaluator.episode_rewards[episode-window_size+1:episode+1]
            self.moving_avg_rewards.append(np.mean(recent_rewards))
        
       
        
        # 成功率(整个历史)
        if episode >= 10:  # 至少有10轮数据
            success_rate = self.evaluator._calculate_success_rate()
            self.success_rates.append(success_rate)
    
    def _print_progress(self, episode: int) -> None:
        """打印训练进度"""
        # 计算最近100轮的表现
        start_idx = max(0, episode - 100)
        recent_rewards = self.evaluator.episode_rewards[start_idx:episode]
        recent_steps = self.evaluator.episode_steps[start_idx:episode]
        
        avg_reward = np.mean(recent_rewards) if recent_rewards else 0.0
        avg_steps = np.mean(recent_steps) if recent_steps else 0.0
        success_rate = np.mean([r > 0 for r in recent_rewards]) * 100 if recent_rewards else 0.0
        
        print(f"轮数 {episode:4d} | "
              f"平均奖励: {avg_reward:6.3f} | "
              f"平均步数: {avg_steps:5.1f} | "
              f"成功率: {success_rate:5.1f}%")
    
    def _compile_results(self) -> Dict[str, Any]:
        """编译评估结果"""
        results = self.evaluator.get_evaluation_results()
        
        # 添加更多统计信息
        results.update({
            'moving_avg_rewards': self.moving_avg_rewards,
            'moving_avg_steps': self.moving_avg_steps,
            'success_rates': self.success_rates,
            'policy_type': self.policy.__class__.__name__,
        })
        
        return results
    
    def visualize_results(self, 
                         save_path: Optional[str] = None,
                         figsize: Tuple[int, int] = (15, 10)) -> None:
        """
        可视化评估结果
        
        Args:
            save_path: 保存图片的路径
            figsize: 图形大小
        """
        plt.figure(figsize=figsize)
        
        # 1. 奖励曲线
        ax1 = plt.subplot(2, 2, 1)
        self._plot_reward_curve(ax1)
        
        # 2. 步数曲线
        ax2 = plt.subplot(2, 2, 2)
        self._plot_step_curve(ax2)
        
        # 3. 成功率曲线
        ax3 = plt.subplot(2, 2, 3)
        self._plot_success_rate_curve(ax3)
        
        # 4. 状态价值函数
        ax4 = plt.subplot(2, 2, 4)
        self._plot_state_values(ax4)
        
        plt.suptitle(f'TD(0) 策略评估实验结果 ({self.policy.__class__.__name__})', 
                    fontsize=16, fontweight='bold')
        plt.tight_layout()
        
        if save_path:
            plt.savefig(save_path, dpi=300, bbox_inches='tight')
            print(f"结果图表已保存至: {save_path}")
        
        plt.show()
    
    def _plot_reward_curve(self, ax) -> None:
        """绘制奖励曲线"""
        episodes = range(1, len(self.evaluator.episode_rewards) + 1)
        
        # 原始奖励
        ax.plot(episodes, self.evaluator.episode_rewards, 'b-', 
               alpha=0.3, label='每轮奖励', linewidth=0.5)
        
        # 滑动平均奖励
        if self.moving_avg_rewards:
            moving_episodes = range(50, len(self.moving_avg_rewards) + 50)
            ax.plot(moving_episodes, self.moving_avg_rewards, 'r-', 
                   linewidth=2, label='滑动平均 (窗口=50)')
        
        # 标记紫色传送门激活点
        if len(self.evaluator.episode_rewards) > 100:
            ax.axvline(x=100, color='purple', linestyle='--', 
                      label='紫色传送门激活', alpha=0.7)
        
        ax.set_xlabel('训练轮数')
        ax.set_ylabel('累计奖励')
        ax.set_title('奖励曲线')
        ax.grid(True, alpha=0.3)
        ax.legend(loc='upper left')
    
    def _plot_step_curve(self, ax) -> None:
        """绘制步数曲线"""
        episodes = range(1, len(self.evaluator.episode_steps) + 1)
        
        ax.plot(episodes, self.evaluator.episode_steps, 'g-', alpha=0.6)
        
        if self.moving_avg_steps:
            moving_episodes = range(50, len(self.moving_avg_steps) + 50)
            ax.plot(moving_episodes, self.moving_avg_steps, 'orange', 
                   linewidth=2, label='平均步数 (窗口=50)')
            ax.legend()
        
        ax.set_xlabel('训练轮数')
        ax.set_ylabel('每轮步数')
        ax.set_title('步数曲线')
        ax.grid(True, alpha=0.3)
    
    def _plot_success_rate_curve(self, ax) -> None:
        """绘制成功率曲线"""
        if self.success_rates:
            episodes = range(10, len(self.success_rates) + 10)
            ax.plot(episodes, self.success_rates, 'm-', linewidth=2)
            
            if len(self.success_rates) > 100:
                ax.axvline(x=100, color='purple', linestyle='--', alpha=0.7)
        
        ax.set_xlabel('训练轮数')
        ax.set_ylabel('成功率')
        ax.set_title('成功率曲线')
        ax.set_ylim([0, 1.05])
        ax.grid(True, alpha=0.3)
    
    def _plot_state_values(self, ax) -> None:
        """绘制状态价值函数"""
        # 使用策略评估器的可视化方法
        self.evaluator._plot_value_heatmap(ax, 'RdYlGn')
        ax.set_title('最终状态价值')


def run_experiment(policy_type: str = 'random', 
                  num_episodes: int = 500) -> None:
    """
    运行完整的策略评估实验
    
    Args:
        policy_type: 策略类型 ('random' 或 'right_preferential')
        num_episodes: 评估轮数
    """
    print("=" * 60)
    print("网格世界强化学习实验 - TD(0)策略评估")
    print("=" * 60)
    
    # 1. 创建环境
    print("\n1. 创建网格世界环境...")
    env = GridWorldEnvironment()
    env.render(figsize=(12, 8))
    
    # 2. 创建策略
    print(f"\n2. 创建{policy_type}策略...")
    
    policy = RightPreferentialPolicy(env, right_prob=0.7)
   
    
    # 3. 创建并运行实验
    print(f"\n3. 创建实验并运行{num_episodes}轮评估...")
    experiment = GridWorldExperiment(env, policy)
    results = experiment.run_evaluation(num_episodes=num_episodes)
    
    # 4. 显示评估结果摘要
    print("\n4. 评估结果摘要:")
    print("-" * 40)
    
    summary_items = [
        ('总轮数', results['episode_count']),
        ('总步数', results['total_steps']),
        ('平均奖励', f"{results['avg_reward']:.3f}"),
        ('平均步数', f"{results['avg_steps']:.1f}"),
        ('成功率', f"{results['success_rate']:.1%}"),
        ('策略类型', results['policy_type']),
    ]
    
    for label, value in summary_items:
        print(f"  {label:<10}: {value}")
    
    print("-" * 40)
    
    # 5. 可视化结果
    print("\n5. 生成可视化图表...")
    experiment.visualize_results(save_path='td0_policy_evaluation.png')
    
    # 6. 可视化状态价值函数
    print("\n6. 显示状态价值函数...")
    experiment.evaluator.visualize_state_values()
    
    print("\n实验完成!")
    print("=" * 60)


def main():
    """主函数"""
    # 设置随机种子以保证可复现性
    np.random.seed(42)
    random.seed(42)
    
    # 运行实验
    run_experiment(policy_type='random', num_episodes=500)


if __name__ == "__main__":
    main()
相关推荐
我只会写Bug啊2 小时前
B站/爱奇艺防录屏防截屏原理及Vue3实战实现
前端·软件开发
蜗牛攻城狮2 小时前
前端构建工具详解:Vite 与 Webpack 深度对比与实战指南
前端·webpack·vite·构建工具
IT_陈寒2 小时前
Redis 性能翻倍的 5 个冷门技巧,90%开发者都不知道的底层优化!
前端·人工智能·后端
亿牛云爬虫专家2 小时前
当数据开始“感知页面”
javascript·html·爬虫代理·代理ip·playwright·页面渲染·dom结构
Umi·2 小时前
shell 条件测试
linux·前端·javascript
第二只羽毛2 小时前
基于Deep Web爬虫的当当网图书信息采集
大数据·开发语言·前端·爬虫·算法
北极象2 小时前
CEF 与 Electron简单对比
前端·javascript·electron
小天博客2 小时前
向后端发起POST请求
开发语言·前端·javascript