Temporal-difference (TD) learning combines ideas from Monte Carlo (MC) methods and dynamic programming (DP):
- Like MC methods, TD methods can learn directly from experience, without a model of the environment's dynamics.
- Like DP methods, TD methods update estimates at every step based on other learned estimates, without waiting for a final outcome (this is called bootstrapping).
A distinguishing feature of TD is that it updates its value estimate at every time step, whereas MC must wait until the episode ends.
In fact the two methods use different update targets. MC updates toward the return $G_t$, which only becomes available at the end of an episode. The TD target, in contrast, is
$$R_{t+1} + \gamma V(S_{t+1}),$$
where $V$ is an estimate of the true value function $v_\pi$.
TD therefore combines the sampling of MC (it uses a sampled estimate in place of the true expected value) with the bootstrapping of DP (it updates $V$ using an estimate that itself depends on further estimates).
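As a minimal sketch of the resulting update rule $V(S_t) \leftarrow V(S_t) + \alpha\,[R_{t+1} + \gamma V(S_{t+1}) - V(S_t)]$ (the function name td0_update and the parameter values are illustrative, not taken from the code later in this post):

# Minimal TD(0) update: V(s) <- V(s) + alpha * [r + gamma * V(s') - V(s)]
def td0_update(V, s, r, s_next, alpha=0.1, gamma=0.9):
    td_target = r + gamma * V.get(s_next, 0.0)   # bootstrapped one-step target
    td_error = td_target - V.get(s, 0.0)         # temporal-difference error
    V[s] = V.get(s, 0.0) + alpha * td_error
    return td_error

# One transition (7, 0) -> (6, 0) with reward 0, assuming V((6, 0)) is currently 0.5
V = {(6, 0): 0.5}
print(td0_update(V, s=(7, 0), r=0.0, s_next=(6, 0)), V[(7, 0)])   # 0.45, 0.045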
The post closes with an example in which TD(0) explores a grid-world environment.
Contents
- TD algorithm example
- TD convergence theorem
- Comparing TD and MC
- TD(0) Python example (TD(0) here does policy evaluation only)
1 TD algorithm example
Consider the following example of a robot walking through a maze.
Environment setup: initial state values (figure omitted).
Step 1: when the robot moves from state $s_t$ to $s_{t+1}$ and receives reward $r_{t+1}$, we have
TD target: $\bar v_t = r_{t+1} + \gamma\, v_t(s_{t+1})$
TD error: $\delta_t = \bar v_t - v_t(s_t)$
state value: $v_{t+1}(s_t) = v_t(s_t) + \alpha_t(s_t)\,\delta_t$
Step 2: when the robot then moves from $s_{t+1}$ to $s_{t+2}$ and receives reward $r_{t+2}$, we have
TD target: $\bar v_{t+1} = r_{t+2} + \gamma\, v_{t+1}(s_{t+2})$
TD error: $\delta_{t+1} = \bar v_{t+1} - v_{t+1}(s_{t+1})$
state value: $v_{t+2}(s_{t+1}) = v_{t+1}(s_{t+1}) + \alpha_{t+1}(s_{t+1})\,\delta_{t+1}$
Step 3: repeat the updates above to iteratively refine the state values.
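To make the arithmetic concrete, here is a small hypothetical walk-through of two such updates; the states s1, s2, s3, the rewards, and the initial values are made up for illustration and are not taken from the original figure:

# Hypothetical two-step TD(0) walk-through with alpha = 0.1, gamma = 0.9
alpha, gamma = 0.1, 0.9
V = {'s1': 0.0, 's2': 0.5, 's3': 1.0}        # made-up initial state values

# Step 1: the robot moves s1 -> s2 and receives reward r = 0
td_target = 0 + gamma * V['s2']              # 0.45
V['s1'] += alpha * (td_target - V['s1'])     # V(s1): 0.0 -> 0.045

# Step 2: the robot moves s2 -> s3 and receives reward r = 1
td_target = 1 + gamma * V['s3']              # 1.9
V['s2'] += alpha * (td_target - V['s2'])     # V(s2): 0.5 -> 0.64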
2 TD convergence theorem
Theorem (convergence of TD learning). For a given policy $\pi$, if for every state $s$ the learning rates satisfy $\sum_t \alpha_t(s) = \infty$ and $\sum_t \alpha_t^2(s) < \infty$, then the TD(0) estimate $v_t(s)$ converges almost surely to $v_\pi(s)$ for all $s$ as $t \to \infty$.
2.1 Remarks
The TD algorithm is essentially a stochastic way of solving the Bellman equation of a given policy $\pi$:
$$v_\pi(s) = \mathbb{E}[R + \gamma G \mid S = s], \qquad s \in \mathcal{S}, \tag{4}$$
where $G$ is the discounted return. Since $\mathbb{E}[G \mid S = s] = \mathbb{E}[v_\pi(S') \mid S = s]$, where $S'$ denotes the next state, equation (4) can be rewritten as
$$v_\pi(s) = \mathbb{E}[R + \gamma v_\pi(S') \mid S = s]. \tag{5}$$
Equation (5) is the Bellman equation of $v_\pi$, also known as the Bellman expectation equation.
As shown in the previous chapter, an equation of this form can be solved with the Robbins-Monro (RM) algorithm.
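To make that connection explicit, here is a sketch of the standard Robbins-Monro view (a reconstruction, not a reproduction of the omitted formulas). For each state $s$, define the function whose root is the solution of equation (5),
$$g(v(s)) = v(s) - \mathbb{E}\big[R + \gamma v_\pi(S') \mid S = s\big],$$
so that solving the Bellman equation means solving $g(v_\pi(s)) = 0$. Given a sampled transition $(s_t, r_{t+1}, s_{t+1})$, the noisy observation is $\tilde g(v(s_t)) = v(s_t) - \big(r_{t+1} + \gamma v_\pi(s_{t+1})\big)$, and the RM iteration
$$v_{t+1}(s_t) = v_t(s_t) - \alpha_t(s_t)\,\tilde g\big(v_t(s_t)\big) = v_t(s_t) + \alpha_t(s_t)\big[r_{t+1} + \gamma v_\pi(s_{t+1}) - v_t(s_t)\big]$$
recovers exactly the TD(0) update once the unknown $v_\pi(s_{t+1})$ is replaced by its current estimate $v_t(s_{t+1})$.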
2.2 Interpreting the theorem
- Conclusion: under the conditions below, the TD algorithm (equation (7.1), which is essentially TD(0)) converges almost surely to the true state-value function $v_\pi(s)$. In other words, as time goes on the estimate $v_t(s)$ approaches the true value with probability 1. (Proof omitted.)
- Key conditions: convergence requires the two classic conditions on the learning rate $\alpha_t(s)$:
  - $\sum_t \alpha_t(s) = \infty$ guarantees a "long enough" learning process, so the algorithm can eventually overcome the initial error and the randomness in the samples.
  - $\sum_t \alpha_t^2(s) < \infty$ guarantees the step size eventually becomes small enough to damp out the random noise in the updates, so the estimate can settle down.
- Dependence on the state: note that the learning rate is $\alpha_t(s)$, i.e., it is state-dependent, and the conditions must hold for every state $s$. This leads to the next key point, which connects the theoretical conditions to practice.
- The "visited infinitely often" requirement: because $\alpha_t(s) > 0$ only when state $s$ is visited at time $t$, the condition $\sum_t \alpha_t(s) = \infty$ is equivalent to requiring that state $s$ be visited infinitely often over an infinite number of time steps.
- Theoretical backing for exploration: this explains why reinforcement learning needs sufficient exploration. Whether via exploring starts or a stochastic policy (such as $\varepsilon$-greedy), the goal is to ensure every state is visited infinitely often, so that no state remains unlearned simply because it was "never seen".
- The constant-learning-rate compromise: in practice a small constant learning rate $\alpha$ is often used, which violates $\sum_t \alpha_t^2(s) < \infty$ (the sum of a squared constant diverges). A schedule that does satisfy both conditions is sketched right after this list.
- Bridging theory and practice: even with a constant learning rate the algorithm still converges, but in a weaker sense, namely convergence in expectation. The estimate then fluctuates around the true value, with fluctuations of zero mean. This is a weaker but more practical convergence guarantee.
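For example, a per-state step size based on visit counts, $\alpha_t(s) = 1/N_t(s)$, satisfies both conditions ($\sum_n 1/n = \infty$ while $\sum_n 1/n^2 < \infty$). A minimal sketch (the names visit_count and step_size are illustrative; the code later in this post simply uses a constant $\alpha = 0.1$):

from collections import defaultdict

visit_count = defaultdict(int)   # N_t(s): number of times state s has been visited so far

def step_size(s):
    """alpha_t(s) = 1 / N_t(s): the harmonic series diverges, its squares sum to a finite value."""
    visit_count[s] += 1
    return 1.0 / visit_count[s]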
3 Comparing TD and MC
1. Incremental vs. non-incremental:
- TD (incremental): it can update the value estimate immediately after every step. This makes it suitable for online learning, i.e., learning while interacting, as in real-time learning for a self-driving car.
- MC (non-incremental): it must wait until an entire episode (one complete trajectory from the start to a terminal state) has finished and the cumulative return $G_t$ is known before it can update. Learning is therefore offline and batch-style.
2. Continuing tasks vs. episodic tasks:
- TD (works for continuing tasks): because it is incremental, TD does not rely on the notion of an episode ending. For continuing tasks with no clear terminal state (such as a long-running robot controller), TD can still learn through continuous updates.
- MC (episodic tasks only): its core quantity is the total return from a state to the end of the episode. Without an episode end, $G_t$ is undefined, so MC cannot be applied directly to continuing tasks.
3. Bootstrapping vs. non-bootstrapping:
- TD (bootstrapping): the TD target (such as $r + \gamma v(s')$) depends on the current value estimates. It is like "pulling yourself up by your own bootstraps": existing (possibly inaccurate) estimates are used to update themselves.
  - Advantage: existing knowledge is exploited sooner, so information propagates efficiently.
  - Disadvantage: it is sensitive to the initial values (a poor initial guess takes time to correct), and the resulting estimate is biased.
- MC (non-bootstrapping): the MC target $G_t$ is obtained entirely from actually sampled returns and does not depend on any initial value estimate.
  - Advantage: the estimate is unbiased, and the solution it converges to does not depend on the initialization.
  - Disadvantage: it must wait for the episode to end, and information propagates slowly in the early stages.
4. Low estimation variance vs. high estimation variance:
This is one of the most important practical differences (see the sketch after this list).
- TD (low estimation variance): a TD update involves only a few random variables. Taking Sarsa as an example, updating $q(S_t, A_t)$ depends only on the next immediate reward $R_{t+1}$, the next state $S_{t+1}$, and the next action $A_{t+1}$. With so few sources of randomness, each update fluctuates little and learning is smoother.
- MC (high estimation variance): the MC target $G_t$ is the sum of a long chain of random variables (rewards, state transitions, action choices), and the randomness of every link accumulates in the final return. As noted in the text, an episode of length $L$ under a soft policy has on the order of $|\mathcal{A}|^L$ possible trajectories; estimating a value from the returns of only a handful of them necessarily has very high variance, which makes learning unstable and convergence slow.
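The contrast between the two targets fits in a few lines of Python (a sketch; mc_target, td_target, and the dictionary V are illustrative names, not part of the code below):

gamma = 0.9

def mc_target(rewards):
    """MC target: the full sampled return G_t, available only after the episode ends."""
    G = 0.0
    for r in reversed(rewards):          # every future reward contributes its own randomness
        G = r + gamma * G
    return G

def td_target(r, V, s_next):
    """TD target: one sampled reward plus a bootstrapped estimate of the rest."""
    return r + gamma * V.get(s_next, 0.0)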
4 TD(0) Python example (TD(0) here does policy evaluation only)
This experiment uses a grid-world environment with the following features:
- The grid has 12×8 cells (8 rows × 12 columns).
- The agent starts in the bottom-left corner of the grid; the goal is to reach the treasure in the upper-right region (a terminal state with reward 1).
- The blue portals are connected: stepping onto the portal at cell (6, 10) teleports the agent to cell (0, 11). After the first teleport, the agent cannot pass through this portal again.
- The purple portal only appears after 100 episodes, but it lets the agent reach the treasure faster. This encourages continued exploration of the environment.
- The red cells are traps (terminal states with reward -1) that end the episode.
- Bumping into a wall leaves the agent where it is.
# Color scheme:
Start position - gold (precious starting point), RGB (255, 215, 0)
Goal position - lime green (safe arrival), RGB (50, 205, 50)
Trap position - crimson (danger warning), RGB (220, 20, 60)
Blue portal - deep sky blue (stable channel), RGB (0, 191, 255)
Purple portal - medium purple (random channel), RGB (147, 112, 219)
Walls - dim gray (environment barrier), RGB (105, 105, 105)

Action definitions (figure omitted; see the Action enum in the code below).
Environment layout (figure omitted; see the GridWorldEnvironment class and its render() output below).
# -*- coding: utf-8 -*-
"""
基于TD(0)算法的网格世界强化学习实现
Grid World Reinforcement Learning using TD(0) Algorithm
作者: chengxf2
日期: 2025年12月
描述: 实现TD(0)算法在复杂网格世界环境中的策略评估
"""
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.patches as patches
from typing import Tuple, List, Dict, Any, Optional
import random
from enum import IntEnum
class Action(IntEnum):
"""动作枚举类"""
UP = 0
DOWN = 1
LEFT = 2
RIGHT = 3
# 颜色定义常量
COLORS = {
'start': (255, 215, 0), # 金色 - 开始位置
'goal': (50, 205, 50), # 安全绿 - 目标位置
'trap': (220, 20, 60), # 警示红 - 陷阱位置
'blue_portal': (0, 191, 255), # 科技蓝 - 蓝色传送门
'purple_portal': (147, 112, 219), # 神秘紫 - 紫色传送门
'wall': (105, 105, 105), # 暗灰石墙 - 墙壁
}
class GridWorldEnvironment:
"""
网格世界环境类
8×12网格,包含传送门、陷阱和动态元素
坐标系: (行, 列) 或 (x, y),其中x从上到下,y从左到右
"""
def __init__(self) -> None:
"""初始化网格世界环境参数"""
# 网格尺寸
self.rows: int = 8
self.cols: int = 12
# 特殊位置定义
self.start_position: Tuple[int, int] = (7, 0) # 开始位置
self.goal_position: Tuple[int, int] = (2, 9) # 目标位置
# 陷阱位置
self.trap_positions: List[Tuple[int, int]] = [
(0, 8), (1, 8), (2, 8), (3, 8),
(3, 9), (3, 10), (3, 11)
]
# 蓝色传送门
self.blue_portal_entrance: Tuple[int, int] = (6, 10)
self.blue_portal_exit: Tuple[int, int] = (0, 11)
self.is_blue_portal_used: bool = False
# 紫色传送门(100轮后激活)
self.is_purple_portal_active: bool = False
self.purple_portal_entrance: Tuple[int, int] = (2, 1)
self.purple_portal_exit: List[Tuple[int, int]] = [
(0, 9), (0, 10), (0, 11),
(1, 9), (1, 10), (1, 11),
(2, 9), (2, 10), (2, 11)
]
# 墙壁位置
self.wall_positions: List[Tuple[int, int]] = [
(5, 3), (6, 3), (7, 3), # 第一组障碍物
(3, 0), (3, 1), (3, 2), # 第二组障碍物
]
# 当前状态和环境统计
self.current_state: Tuple[int, int] = self.start_position
self.total_episodes: int = 0
self.step_count: int = 0
# 动作映射
self.action_mapping = {
Action.UP: self._move_up,
Action.DOWN: self._move_down,
Action.LEFT: self._move_left,
Action.RIGHT: self._move_right,
}
def reset(self) -> Tuple[int, int]:
"""
重置环境到初始状态
Returns:
Tuple[int, int]: 初始状态坐标
"""
self.current_state = self.start_position
self.is_blue_portal_used = False
self.step_count = 0
return self.current_state
def step(self, action: Action) -> Tuple[Tuple[int, int], float, bool]:
"""
执行动作并返回新的状态、奖励和终止标志
Args:
action: 动作枚举
Returns:
Tuple[状态, 奖励, 是否终止]
"""
current_x, current_y = self.current_state
self.step_count += 1
# 计算移动后的新位置
new_x, new_y = self._calculate_new_position(current_x, current_y, action)
# 检查墙壁碰撞
if (new_x, new_y) in self.wall_positions:
new_x, new_y = current_x, current_y
        # Process portals first, then check termination, so that teleporting
        # onto the goal or a trap still ends the episode correctly
        new_x, new_y = self._process_portals(new_x, new_y)
        # Check traps
        if (new_x, new_y) in self.trap_positions:
            self.current_state = (new_x, new_y)
            return self.current_state, -1.0, True
        # Check goal
        if (new_x, new_y) == self.goal_position:
            self.current_state = (new_x, new_y)
            return self.current_state, 1.0, True
# 更新当前状态
self.current_state = (new_x, new_y)
# 默认奖励为0(稀疏奖励设置)
return self.current_state, 0.0, False
def _calculate_new_position(self, x: int, y: int, action: Action) -> Tuple[int, int]:
"""
计算移动后的新位置
Args:
x: 当前x坐标
y: 当前y坐标
action: 动作
Returns:
Tuple[int, int]: 新位置坐标
"""
move_func = self.action_mapping.get(action)
if move_func is None:
raise ValueError(f"无效动作: {action}")
return move_func(x, y)
def _move_up(self, x: int, y: int) -> Tuple[int, int]:
"""向上移动"""
new_x = max(x - 1, 0)
return new_x, y
def _move_down(self, x: int, y: int) -> Tuple[int, int]:
"""向下移动"""
new_x = min(x + 1, self.rows - 1)
return new_x, y
def _move_left(self, x: int, y: int) -> Tuple[int, int]:
"""向左移动"""
new_y = max(y - 1, 0)
return x, new_y
def _move_right(self, x: int, y: int) -> Tuple[int, int]:
"""向右移动"""
new_y = min(y + 1, self.cols - 1)
return x, new_y
def _process_portals(self, x: int, y: int) -> Tuple[int, int]:
"""
检查并处理传送门
Args:
x: 当前x坐标
y: 当前y坐标
Returns:
Tuple[int, int]: 处理后的位置
"""
# 蓝色传送门(仅可使用一次)
if (not self.is_blue_portal_used) and ((x, y) == self.blue_portal_entrance):
self.is_blue_portal_used = True
return self.blue_portal_exit
# 紫色传送门(100轮后激活)
elif self.is_purple_portal_active and ((x, y) == self.purple_portal_entrance):
return random.choice(self.purple_portal_exit)
# 不是传送门
return (x, y)
def step_simulate(self, state: Tuple[int, int], action: Action) -> Tuple[Tuple[int, int], float, bool]:
"""
模拟执行动作而不改变环境状态(用于规划)
Args:
state: 当前状态
action: 动作
Returns:
Tuple[下一状态, 奖励, 是否终止]
"""
x, y = state
# 计算新位置
new_x, new_y = self._calculate_new_position(x, y, action)
# 检查墙壁
if (new_x, new_y) in self.wall_positions:
new_x, new_y = x, y
# 检查陷阱
if (new_x, new_y) in self.trap_positions:
return (new_x, new_y), -1.0, True
# 检查目标
if (new_x, new_y) == self.goal_position:
return (new_x, new_y), 1.0, True
# 检查传送门
if (not self.is_blue_portal_used) and ((new_x, new_y) == self.blue_portal_entrance):
new_x, new_y = self.blue_portal_exit
if self.is_purple_portal_active and ((new_x, new_y) == self.purple_portal_entrance):
new_x, new_y = random.choice(self.purple_portal_exit)
# 再次检查是否撞墙(传送后可能撞墙)
if (new_x, new_y) in self.wall_positions:
new_x, new_y = x, y
return (new_x, new_y), 0.0, False
    def update_portal_activation(self, episode_count: int) -> None:
        """
        Update portal activation (and the episode counter) from the training episode count
        Args:
            episode_count: current training episode index
        """
        self.total_episodes = episode_count  # keep render()/text output in sync with training
        self.is_purple_portal_active = (episode_count >= 100)
def render(self, figsize: Tuple[int, int] = (10, 8), show_legend: bool = True) -> None:
"""
可视化环境状态
Args:
figsize: 图形大小
show_legend: 是否显示图例
"""
self._setup_matplotlib_fonts()
# 创建图形和坐标轴
fig, ax = plt.subplots(figsize=figsize)
# 设置坐标轴
self._setup_axes(ax)
# 绘制网格和各个元素
self._draw_grid_background(ax)
self._draw_walls(ax)
self._draw_start_position(ax)
self._draw_goal_position(ax)
self._draw_traps(ax)
self._draw_blue_portal(ax)
self._draw_purple_portal(ax)
self._draw_agent(ax)
self._draw_grid_coordinates(ax)
# 添加标题和图例
self._add_title_and_legend(ax, show_legend)
plt.tight_layout()
plt.show()
# 打印文本版本
self._print_text_version()
def _setup_matplotlib_fonts(self) -> None:
"""设置matplotlib字体"""
try:
plt.rcParams['font.sans-serif'] = ['SimHei', 'Microsoft YaHei', 'DejaVu Sans']
except Exception:
pass
plt.rcParams['axes.unicode_minus'] = False
def _setup_axes(self, ax) -> None:
"""设置坐标轴"""
ax.set_xlim(-0.5, self.cols - 0.5)
ax.set_ylim(-0.5, self.rows - 0.5)
ax.set_aspect('equal')
ax.set_xticks([])
ax.set_yticks([])
ax.invert_yaxis() # 让y轴向下为正
def _draw_grid_background(self, ax) -> None:
"""绘制网格背景"""
for x in range(self.rows):
for y in range(self.cols):
rect = patches.Rectangle(
(y - 0.5, x - 0.5), 1, 1,
linewidth=1,
edgecolor=(0.78, 0.78, 0.78), # 浅灰色
facecolor=(0.94, 0.94, 0.94), # 更浅的灰色
)
ax.add_patch(rect)
def _draw_walls(self, ax) -> None:
"""绘制墙壁"""
for wall_x, wall_y in self.wall_positions:
rect = patches.Rectangle(
(wall_y - 0.5, wall_x - 0.5), 1, 1,
linewidth=1,
edgecolor=self._normalize_color(COLORS['wall']),
facecolor=self._normalize_color(COLORS['wall']),
hatch='////'
)
ax.add_patch(rect)
ax.text(wall_y, wall_x, '█', ha='center', va='center',
fontsize=14, color='white', fontweight='bold')
def _draw_start_position(self, ax) -> None:
"""绘制开始位置"""
start_x, start_y = self.start_position
rect = patches.Rectangle(
(start_y - 0.5, start_x - 0.5), 1, 1,
linewidth=2,
edgecolor='black',
facecolor=self._normalize_color(COLORS['start'])
)
ax.add_patch(rect)
ax.text(start_y, start_x, 'S', ha='center', va='center',
fontsize=12, color='black', fontweight='bold')
def _draw_goal_position(self, ax) -> None:
"""绘制目标位置"""
goal_x, goal_y = self.goal_position
rect = patches.Rectangle(
(goal_y - 0.5, goal_x - 0.5), 1, 1,
linewidth=2,
edgecolor='black',
facecolor=self._normalize_color(COLORS['goal'])
)
ax.add_patch(rect)
ax.text(goal_y, goal_x, 'G', ha='center', va='center',
fontsize=12, color='white', fontweight='bold')
def _draw_traps(self, ax) -> None:
"""绘制陷阱位置"""
for trap_x, trap_y in self.trap_positions:
rect = patches.Rectangle(
(trap_y - 0.5, trap_x - 0.5), 1, 1,
linewidth=1,
edgecolor=self._normalize_color(COLORS['trap']),
facecolor=self._normalize_color(COLORS['trap'])
)
ax.add_patch(rect)
ax.text(trap_y, trap_x, 'T', ha='center', va='center',
fontsize=10, color='white', fontweight='bold')
def _draw_blue_portal(self, ax) -> None:
"""绘制蓝色传送门"""
# 入口
blue_x, blue_y = self.blue_portal_entrance
rect = patches.Rectangle(
(blue_y - 0.5, blue_x - 0.5), 1, 1,
linewidth=2,
edgecolor='black',
facecolor=self._normalize_color(COLORS['blue_portal']),
alpha=0.8
)
ax.add_patch(rect)
ax.text(blue_y, blue_x, 'B', ha='center', va='center',
fontsize=12, color='white', fontweight='bold')
# 出口
blue_exit_x, blue_exit_y = self.blue_portal_exit
rect = patches.Rectangle(
(blue_exit_y - 0.5, blue_exit_x - 0.5), 1, 1,
linewidth=2,
edgecolor=self._normalize_color(COLORS['blue_portal']),
facecolor='white',
linestyle='--',
alpha=0.6
)
ax.add_patch(rect)
ax.text(blue_exit_y, blue_exit_x, 'BE', ha='center', va='center',
fontsize=10, color=self._normalize_color(COLORS['blue_portal']),
fontweight='bold')
def _draw_purple_portal(self, ax) -> None:
"""绘制紫色传送门"""
if not self.is_purple_portal_active:
return
purple_x, purple_y = self.purple_portal_entrance
rect = patches.Rectangle(
(purple_y - 0.5, purple_x - 0.5), 1, 1,
linewidth=2,
edgecolor='black',
facecolor=self._normalize_color(COLORS['purple_portal']),
alpha=0.8
)
ax.add_patch(rect)
ax.text(purple_y, purple_x, 'P', ha='center', va='center',
fontsize=12, color='white', fontweight='bold')
# 出口区域
for exit_x, exit_y in self.purple_portal_exit:
rect = patches.Rectangle(
(exit_y - 0.5, exit_x - 0.5), 1, 1,
linewidth=1,
edgecolor=self._normalize_color(COLORS['purple_portal']),
facecolor='white',
linestyle='--',
alpha=0.4
)
ax.add_patch(rect)
def _draw_agent(self, ax) -> None:
"""绘制智能体"""
agent_x, agent_y = self.current_state
circle = patches.Circle(
(agent_y, agent_x), 0.3,
linewidth=2,
edgecolor='black',
facecolor=(1.0, 0.65, 0.0) # 橙色
)
ax.add_patch(circle)
ax.text(agent_y, agent_x, 'A', ha='center', va='center',
fontsize=10, color='black', fontweight='bold')
def _draw_grid_coordinates(self, ax) -> None:
"""绘制网格坐标"""
for x in range(self.rows):
for y in range(self.cols):
cell = (x, y)
if self._is_empty_cell(cell):
ax.text(y, x, f'({x},{y})', ha='center', va='center',
fontsize=6, color='black', alpha=0.5)
def _add_title_and_legend(self, ax, show_legend: bool) -> None:
"""添加标题和图例"""
title = f"网格世界环境\n轮数: {self.total_episodes}, 步数: {self.step_count}"
ax.set_title(title, fontsize=14, fontweight='bold', pad=20)
if show_legend:
legend_elements = [
patches.Patch(facecolor=self._normalize_color(COLORS['start']),
edgecolor='black', label='起点 (Start)'),
patches.Patch(facecolor=self._normalize_color(COLORS['goal']),
edgecolor='black', label='目标 (Goal)'),
patches.Patch(facecolor=self._normalize_color(COLORS['trap']),
edgecolor='black', label='陷阱 (Trap)'),
patches.Patch(facecolor=self._normalize_color(COLORS['blue_portal']),
edgecolor='black', label='蓝色传送门'),
patches.Patch(facecolor=self._normalize_color(COLORS['purple_portal']),
edgecolor='black', label='紫色传送门'),
patches.Patch(facecolor=self._normalize_color(COLORS['wall']),
edgecolor='black', label='墙壁 (Wall)'),
patches.Patch(facecolor=(1.0, 0.65, 0.0),
edgecolor='black', label='智能体 (Agent)'),
]
ax.legend(handles=legend_elements,
loc='upper right',
bbox_to_anchor=(1.15, 1),
fontsize=9,
title='图例说明',
title_fontsize=10)
# 添加边框
for spine in ax.spines.values():
spine.set_linewidth(2)
spine.set_color('black')
def _print_text_version(self) -> None:
"""打印文本版本的环境状态"""
print(f"\n当前环境状态 (轮数: {self.total_episodes}, 步数: {self.step_count})")
if self.total_episodes >= 100:
print("紫色传送门已激活")
else:
print("紫色传送门未激活")
def _is_empty_cell(self, cell: Tuple[int, int]) -> bool:
"""判断是否为空白单元格"""
x, y = cell
return (cell not in self.wall_positions and
cell != self.start_position and
cell != self.goal_position and
cell not in self.trap_positions and
cell != self.blue_portal_entrance and
not (self.is_purple_portal_active and cell == self.purple_portal_entrance) and
cell != self.current_state)
@staticmethod
def _normalize_color(rgb_tuple: Tuple[int, int, int]) -> Tuple[float, float, float]:
"""将RGB颜色从0-255范围归一化到0-1范围"""
return tuple(c / 255.0 for c in rgb_tuple)
def get_environment_summary(self) -> Dict[str, Any]:
"""获取环境摘要信息"""
return {
'dimensions': (self.rows, self.cols),
'start_position': self.start_position,
'goal_position': self.goal_position,
'trap_count': len(self.trap_positions),
'wall_count': len(self.wall_positions),
'has_blue_portal': True,
'has_purple_portal': True,
'is_purple_portal_active': self.is_purple_portal_active
}
class Policy:
"""策略基类"""
def __init__(self, env: GridWorldEnvironment):
self.env = env
def get_action_probabilities(self, state: Tuple[int, int]) -> List[float]:
"""获取动作概率分布"""
raise NotImplementedError
def _normalize_probabilities(self, probs: List[float]) -> List[float]:
"""归一化概率分布,确保和为1"""
if not probs:
return [0.25, 0.25, 0.25, 0.25]
total = sum(probs)
# 如果概率和为0,返回均匀分布
if total == 0:
return [0.25, 0.25, 0.25, 0.25]
# 归一化
normalized = [p / total for p in probs]
# 处理浮点误差
return normalized
def select_action(self, state: Tuple[int, int]) -> Action:
"""根据策略选择动作"""
action_probs = self.get_action_probabilities(state)
# 归一化概率
normalized_probs = self._normalize_probabilities(action_probs)
# 验证概率和是否为1(允许小的浮点误差)
prob_sum = sum(normalized_probs)
if abs(prob_sum - 1.0) > 1e-10:
normalized_probs = [p / prob_sum for p in normalized_probs]
return Action(np.random.choice(len(Action), p=normalized_probs))
class RightPreferentialPolicy(Policy):
"""向右偏好策略:向右动作概率最高"""
def __init__(self, env: GridWorldEnvironment, right_prob: float = 0.7):
super().__init__(env)
if not 0 <= right_prob <= 1:
raise ValueError(f"right_prob必须在0和1之间,当前为{right_prob}")
self.right_prob = right_prob
self.other_prob = (1.0 - right_prob) / 3.0
def get_action_probabilities(self, state: Tuple[int, int]) -> List[float]:
# 检查是否为终止状态
if (state in self.env.wall_positions or
state in self.env.trap_positions or
state == self.env.goal_position):
return [0.25, 0.25, 0.25, 0.25]
# 检查每个动作是否可用
action_availability = []
for action in Action:
next_state, _, _ = self.env.step_simulate(state, action)
is_available = next_state not in self.env.wall_positions
action_availability.append(is_available)
# 如果没有可用动作
if not any(action_availability):
return [0.25, 0.25, 0.25, 0.25]
# 构建概率分布
probs = []
for i, (action, is_available) in enumerate(zip(Action, action_availability)):
if not is_available:
probs.append(0.0)
elif action == Action.RIGHT:
probs.append(self.right_prob)
else:
probs.append(self.other_prob)
return probs
class TD0PolicyEvaluator:
"""
TD(0) 策略评估器
评估给定策略π的状态价值函数 V(s)
核心公式: V(s) ← V(s) + α [r + γV(s') - V(s)]
"""
def __init__(self,
env: GridWorldEnvironment,
policy: Policy,
learning_rate: float = 0.1,
discount_factor: float = 0.9) -> None:
"""
初始化TD(0)策略评估器
Args:
env: 环境实例
policy: 要评估的策略
learning_rate: 学习率 (α)
discount_factor: 折扣因子 (γ)
"""
self.env = env
self.policy = policy
self.learning_rate = learning_rate
self.discount_factor = discount_factor
# 状态价值函数 V(s)
self.state_values = np.zeros((env.rows, env.cols))
# 初始化特殊位置的价值
self._initialize_special_state_values()
# 训练统计
self.episode_count = 0
self.total_steps = 0
self.value_update_history: List[Dict[str, Any]] = []
# 性能指标
self.episode_rewards: List[float] = []
self.episode_steps: List[int] = []
def _initialize_special_state_values(self) -> None:
"""初始化特殊状态的价值"""
# 墙壁位置为负值(避免前往)
for wall_x, wall_y in self.env.wall_positions:
self.state_values[wall_x, wall_y] = -1.0
# 陷阱位置为负值
for trap_x, trap_y in self.env.trap_positions:
self.state_values[trap_x, trap_y] = -1.0
# 目标位置为正值
goal_x, goal_y = self.env.goal_position
self.state_values[goal_x, goal_y] = 1.0
def evaluate_episode(self, max_steps: int = 500) -> Tuple[float, int]:
"""
执行一轮策略评估
Args:
max_steps: 最大步数
Returns:
Tuple[累计奖励, 步数]
"""
# 重置环境
state = self.env.reset()
self.env.update_portal_activation(self.episode_count)
total_reward = 0.0
steps = 0
while steps < max_steps:
# 根据策略选择动作
action = self.policy.select_action(state)
# 执行动作
next_state, reward, is_done = self.env.step(action)
# 使用TD(0)更新状态价值函数
self._update_state_value(state, reward, next_state)
# 更新统计
total_reward += reward
steps += 1
self.total_steps += 1
# 转移到下一状态
state = next_state
if is_done:
break
# 更新训练轮数
self.episode_count += 1
self.episode_rewards.append(total_reward)
self.episode_steps.append(steps)
return total_reward, steps
def _update_state_value(self, state: Tuple[int, int],
reward: float,
next_state: Tuple[int, int]) -> float:
"""
使用TD(0)更新状态价值函数
Args:
state: 当前状态
reward: 即时奖励
next_state: 下一状态
Returns:
float: TD误差
"""
x, y = state
nx, ny = next_state
# 计算TD误差
td_error = reward + self.discount_factor * self.state_values[nx, ny] - self.state_values[x, y]
# 更新状态价值(墙壁不更新)
if (x, y) not in self.env.wall_positions:
self.state_values[x, y] += self.learning_rate * td_error
# 记录更新历史
self.value_update_history.append({
'episode': self.episode_count,
'state': state,
'td_error': td_error,
'new_value': self.state_values[x, y]
})
return td_error
def get_evaluation_results(self) -> Dict[str, Any]:
"""
获取策略评估结果
Returns:
Dict: 包含状态价值函数和评估统计
"""
return {
'state_values': self.state_values.copy(),
'episode_count': self.episode_count,
'total_steps': self.total_steps,
'avg_reward': np.mean(self.episode_rewards) if self.episode_rewards else 0.0,
'avg_steps': np.mean(self.episode_steps) if self.episode_steps else 0.0,
'success_rate': self._calculate_success_rate(),
}
def _calculate_success_rate(self) -> float:
"""计算成功率"""
if not self.episode_rewards:
return 0.0
successful_episodes = sum(1 for reward in self.episode_rewards if reward > 0)
return successful_episodes / len(self.episode_rewards)
def visualize_state_values(self,
cmap: str = 'RdYlGn',
figsize: Tuple[int, int] = (12, 8)) -> None:
"""
可视化状态价值函数
Args:
cmap: 颜色映射
figsize: 图形大小
"""
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=figsize)
# 左侧:热图
self._plot_value_heatmap(ax1, cmap)
# 右侧:数值表格
self._plot_value_table(ax2)
plt.suptitle(f'TD(0) 策略评估结果 (共{self.episode_count}轮)', fontsize=16, fontweight='bold')
plt.tight_layout()
plt.show()
def _plot_value_heatmap(self, ax, cmap: str) -> None:
"""绘制价值热图"""
# 创建掩码(隐藏墙壁)
mask = np.zeros_like(self.state_values, dtype=bool)
for wall_x, wall_y in self.env.wall_positions:
mask[wall_x, wall_y] = True
# 创建带掩码的数组
display_values = np.ma.array(self.state_values, mask=mask)
# 绘制热图
im = ax.imshow(display_values, cmap=cmap, interpolation='nearest', aspect='auto')
plt.colorbar(im, ax=ax, label='状态价值 V(s)')
# 标记特殊位置
self._mark_special_positions(ax)
ax.set_title('状态价值热图')
ax.set_xlabel('列 (y)')
ax.set_ylabel('行 (x)')
ax.set_xticks(range(self.env.cols))
ax.set_yticks(range(self.env.rows))
def _mark_special_positions(self, ax) -> None:
"""标记特殊位置"""
# 开始位置
start_x, start_y = self.env.start_position
ax.scatter(start_y, start_x, color='gold', s=200, marker='*',
edgecolor='black', linewidth=1, label='起点')
# 目标位置
goal_x, goal_y = self.env.goal_position
ax.scatter(goal_y, goal_x, color='green', s=200, marker='*',
edgecolor='black', linewidth=1, label='目标')
# 陷阱位置
for trap_x, trap_y in self.env.trap_positions:
ax.scatter(trap_y, trap_x, color='red', s=100, marker='x',
linewidth=2, label='陷阱' if trap_x == 0 and trap_y == 8 else "")
ax.legend(loc='upper right')
def _plot_value_table(self, ax) -> None:
"""绘制价值数值表格"""
# 创建表格数据
table_data = []
for x in range(self.env.rows):
row_data = []
for y in range(self.env.cols):
if (x, y) in self.env.wall_positions:
row_data.append('██')
else:
row_data.append(f'{self.state_values[x, y]:.2f}')
table_data.append(row_data)
# 创建表格
table = ax.table(cellText=table_data,
cellLoc='center',
loc='center',
colWidths=[0.1] * self.env.cols)
# 设置表格样式
table.auto_set_font_size(False)
table.set_fontsize(9)
table.scale(1, 1.5)
# 使用颜色常量(使用matplotlib颜色名称或十六进制代码,避免除法运算)
COLORS = {
'start': 'gold', # 金色 - 开始位置
'goal': 'limegreen', # 安全绿 - 目标位置
'trap': 'crimson', # 警示红 - 陷阱位置
'blue_portal': 'deepskyblue', # 科技蓝 - 蓝色传送门
'purple_portal': 'mediumpurple', # 神秘紫 - 紫色传送门
'wall': 'dimgray', # 暗灰石墙 - 墙壁
'positive': 'lightgreen', # 浅绿 - 正值状态
'negative': 'lightcoral', # 浅红 - 负值状态
}
# 设置特殊单元格颜色
for (x, y), cell in table.get_celld().items():
            # the table was created without column labels, so there is no header row:
            # every cell index maps directly to a grid coordinate
            actual_x, actual_y = x, y
pos = (actual_x, actual_y)
# 按照优先级设置颜色(后设置的会覆盖先设置的)
if pos in self.env.wall_positions:
cell.set_facecolor(COLORS['wall'])
cell.set_text_props(weight='bold')
elif pos == self.env.start_position:
cell.set_facecolor(COLORS['start'])
elif pos == self.env.goal_position:
cell.set_facecolor(COLORS['goal'])
elif pos in self.env.trap_positions:
cell.set_facecolor(COLORS['trap'])
elif self.state_values[actual_x, actual_y] > 0:
# 对于有正价值的状态使用浅绿色
cell.set_facecolor(COLORS['positive'])
elif self.state_values[actual_x, actual_y] < 0:
# 对于有负价值的状态使用浅红色
cell.set_facecolor(COLORS['negative'])
ax.set_title('状态价值表格')
ax.axis('off')
class GridWorldExperiment:
"""
网格世界实验管理器
用于管理和执行策略评估实验
"""
def __init__(self,
env: GridWorldEnvironment,
policy: Optional[Policy] = None) -> None:
"""
初始化实验
Args:
env: 环境实例
policy: 策略实例,如果为None则使用随机策略
"""
        self.env = env
        # Use the supplied policy, or fall back to a uniform random policy
        # (RightPreferentialPolicy with right_prob=0.25 gives all four actions equal probability)
        self.policy = policy if policy is not None else RightPreferentialPolicy(env, right_prob=0.25)
# 创建策略评估器
self.evaluator = TD0PolicyEvaluator(env, self.policy)
# 性能统计
self.moving_avg_rewards: List[float] = []
self.moving_avg_steps: List[int] = []
self.success_rates: List[float] = []
def run_evaluation(self,
num_episodes: int = 1000,
progress_interval: int = 100) -> Dict[str, Any]:
"""
运行策略评估实验
Args:
num_episodes: 评估轮数
progress_interval: 进度打印间隔
Returns:
Dict: 评估结果
"""
print(f"开始策略评估,总轮数: {num_episodes}")
print("=" * 60)
for episode in range(num_episodes):
# 单轮评估
reward, steps = self.evaluator.evaluate_episode()
# 更新性能统计
self._update_performance_metrics(episode, reward, steps)
# 定期输出进度
if (episode + 1) % progress_interval == 0:
self._print_progress(episode + 1)
print("\n策略评估完成!")
print("=" * 60)
return self._compile_results()
def _update_performance_metrics(self,
episode: int,
reward: float,
steps: int) -> None:
"""更新性能指标"""
window_size = 50
        # Moving-average reward and step count over the last `window_size` episodes
        if episode >= window_size - 1:
            recent_rewards = self.evaluator.episode_rewards[episode-window_size+1:episode+1]
            self.moving_avg_rewards.append(np.mean(recent_rewards))
            recent_steps = self.evaluator.episode_steps[episode-window_size+1:episode+1]
            self.moving_avg_steps.append(np.mean(recent_steps))
# 成功率(整个历史)
if episode >= 10: # 至少有10轮数据
success_rate = self.evaluator._calculate_success_rate()
self.success_rates.append(success_rate)
def _print_progress(self, episode: int) -> None:
"""打印训练进度"""
# 计算最近100轮的表现
start_idx = max(0, episode - 100)
recent_rewards = self.evaluator.episode_rewards[start_idx:episode]
recent_steps = self.evaluator.episode_steps[start_idx:episode]
avg_reward = np.mean(recent_rewards) if recent_rewards else 0.0
avg_steps = np.mean(recent_steps) if recent_steps else 0.0
success_rate = np.mean([r > 0 for r in recent_rewards]) * 100 if recent_rewards else 0.0
print(f"轮数 {episode:4d} | "
f"平均奖励: {avg_reward:6.3f} | "
f"平均步数: {avg_steps:5.1f} | "
f"成功率: {success_rate:5.1f}%")
def _compile_results(self) -> Dict[str, Any]:
"""编译评估结果"""
results = self.evaluator.get_evaluation_results()
# 添加更多统计信息
results.update({
'moving_avg_rewards': self.moving_avg_rewards,
'moving_avg_steps': self.moving_avg_steps,
'success_rates': self.success_rates,
'policy_type': self.policy.__class__.__name__,
})
return results
def visualize_results(self,
save_path: Optional[str] = None,
figsize: Tuple[int, int] = (15, 10)) -> None:
"""
可视化评估结果
Args:
save_path: 保存图片的路径
figsize: 图形大小
"""
plt.figure(figsize=figsize)
# 1. 奖励曲线
ax1 = plt.subplot(2, 2, 1)
self._plot_reward_curve(ax1)
# 2. 步数曲线
ax2 = plt.subplot(2, 2, 2)
self._plot_step_curve(ax2)
# 3. 成功率曲线
ax3 = plt.subplot(2, 2, 3)
self._plot_success_rate_curve(ax3)
# 4. 状态价值函数
ax4 = plt.subplot(2, 2, 4)
self._plot_state_values(ax4)
plt.suptitle(f'TD(0) 策略评估实验结果 ({self.policy.__class__.__name__})',
fontsize=16, fontweight='bold')
plt.tight_layout()
if save_path:
plt.savefig(save_path, dpi=300, bbox_inches='tight')
print(f"结果图表已保存至: {save_path}")
plt.show()
def _plot_reward_curve(self, ax) -> None:
"""绘制奖励曲线"""
episodes = range(1, len(self.evaluator.episode_rewards) + 1)
# 原始奖励
ax.plot(episodes, self.evaluator.episode_rewards, 'b-',
alpha=0.3, label='每轮奖励', linewidth=0.5)
# 滑动平均奖励
if self.moving_avg_rewards:
moving_episodes = range(50, len(self.moving_avg_rewards) + 50)
ax.plot(moving_episodes, self.moving_avg_rewards, 'r-',
linewidth=2, label='滑动平均 (窗口=50)')
# 标记紫色传送门激活点
if len(self.evaluator.episode_rewards) > 100:
ax.axvline(x=100, color='purple', linestyle='--',
label='紫色传送门激活', alpha=0.7)
ax.set_xlabel('训练轮数')
ax.set_ylabel('累计奖励')
ax.set_title('奖励曲线')
ax.grid(True, alpha=0.3)
ax.legend(loc='upper left')
def _plot_step_curve(self, ax) -> None:
"""绘制步数曲线"""
episodes = range(1, len(self.evaluator.episode_steps) + 1)
ax.plot(episodes, self.evaluator.episode_steps, 'g-', alpha=0.6)
if self.moving_avg_steps:
moving_episodes = range(50, len(self.moving_avg_steps) + 50)
ax.plot(moving_episodes, self.moving_avg_steps, 'orange',
linewidth=2, label='平均步数 (窗口=50)')
ax.legend()
ax.set_xlabel('训练轮数')
ax.set_ylabel('每轮步数')
ax.set_title('步数曲线')
ax.grid(True, alpha=0.3)
def _plot_success_rate_curve(self, ax) -> None:
"""绘制成功率曲线"""
if self.success_rates:
episodes = range(10, len(self.success_rates) + 10)
ax.plot(episodes, self.success_rates, 'm-', linewidth=2)
if len(self.success_rates) > 100:
ax.axvline(x=100, color='purple', linestyle='--', alpha=0.7)
ax.set_xlabel('训练轮数')
ax.set_ylabel('成功率')
ax.set_title('成功率曲线')
ax.set_ylim([0, 1.05])
ax.grid(True, alpha=0.3)
def _plot_state_values(self, ax) -> None:
"""绘制状态价值函数"""
# 使用策略评估器的可视化方法
self.evaluator._plot_value_heatmap(ax, 'RdYlGn')
ax.set_title('最终状态价值')
def run_experiment(policy_type: str = 'random',
num_episodes: int = 500) -> None:
"""
运行完整的策略评估实验
Args:
policy_type: 策略类型 ('random' 或 'right_preferential')
num_episodes: 评估轮数
"""
print("=" * 60)
print("网格世界强化学习实验 - TD(0)策略评估")
print("=" * 60)
# 1. 创建环境
print("\n1. 创建网格世界环境...")
env = GridWorldEnvironment()
env.render(figsize=(12, 8))
    # 2. Create the policy according to policy_type
    print(f"\n2. 创建{policy_type}策略...")
    if policy_type == 'right_preferential':
        policy = RightPreferentialPolicy(env, right_prob=0.7)
    else:
        # 'random': right_prob=0.25 makes all four actions equally likely
        policy = RightPreferentialPolicy(env, right_prob=0.25)
# 3. 创建并运行实验
print(f"\n3. 创建实验并运行{num_episodes}轮评估...")
experiment = GridWorldExperiment(env, policy)
results = experiment.run_evaluation(num_episodes=num_episodes)
# 4. 显示评估结果摘要
print("\n4. 评估结果摘要:")
print("-" * 40)
summary_items = [
('总轮数', results['episode_count']),
('总步数', results['total_steps']),
('平均奖励', f"{results['avg_reward']:.3f}"),
('平均步数', f"{results['avg_steps']:.1f}"),
('成功率', f"{results['success_rate']:.1%}"),
('策略类型', results['policy_type']),
]
for label, value in summary_items:
print(f" {label:<10}: {value}")
print("-" * 40)
# 5. 可视化结果
print("\n5. 生成可视化图表...")
experiment.visualize_results(save_path='td0_policy_evaluation.png')
# 6. 可视化状态价值函数
print("\n6. 显示状态价值函数...")
experiment.evaluator.visualize_state_values()
print("\n实验完成!")
print("=" * 60)
def main():
"""主函数"""
# 设置随机种子以保证可复现性
np.random.seed(42)
random.seed(42)
    # Run the experiment (the right-preferential policy matches the evaluation described above)
    run_experiment(policy_type='right_preferential', num_episodes=500)
if __name__ == "__main__":
main()




