详细笔记:蒙特卡洛树搜索 (MCTS) 数学基础与实现
蒙特卡洛树搜索 (MCTS) 是一种在各种决策过程中寻找最优决策的启发式搜索算法。它特别适用于那些状态空间巨大、难以进行穷举搜索的问题,例如围棋 (Go)、国际象棋、以及许多其他复杂游戏和规划任务。MCTS 的巧妙之处在于它将传统树搜索的系统性与蒙特卡洛模拟的随机评估能力结合起来,以有限的计算资源高效地探索最有希望的决策路径。
一、MCTS 概述:像聪明人一样"试错"
想象一下你在玩一个非常复杂的游戏,比如围棋。每一步都有很多种选择,而且要看很多步之后才能知道某个选择好不好。如果你想把所有可能性都算一遍,那几乎是不可能的。
MCTS 提供了一种更聪明的方法:
- "试几手" (模拟):它不会去计算所有可能性,而是随机地(或者用一些简单的规则)"快进"游戏,看看如果这么走,最后可能会是什么结果。
- "吸取教训" (评估):根据这些"快进"的结果(赢了还是输了),MCTS 会更新对当前局势下各个选择的"好感度"。
- "重点关注" (选择):它会更倾向于探索那些之前表现好,或者还不太了解但有潜力的选择。
MCTS 是一个 "随时可用算法" (anytime algorithm)。这意味着你给它越多的计算时间(让它"试错"的次数越多),它找到的决策就越可能接近最优。即使时间有限,它也能给出一个当前看来最好的答案。
二、MCTS 核心概念与数学基础:构建决策智慧树
MCTS 的核心是在内存中逐步构建和维护一棵搜索树 。树的节点代表游戏中的某个状态 (比如棋盘的某个局面),而连接节点的边则代表从一个状态到另一个状态所执行的行动(比如在棋盘上落子)。
算法的每一次迭代(可以理解为一次"思考周期")都包含四个核心步骤,它们不断循环,直到计算资源耗尽(例如,达到预设的迭代次数或运行时间):
- Selection (选择): 挑选最有潜力的分支。
- Expansion (扩展): 在选中的分支上探索新的可能性。
- Simulation (模拟 / Rollout): 对新可能性进行快速评估。
- Backpropagation (回传): 将评估结果反馈给相关的历史决策。
让我们深入了解这些步骤和它们背后的数学原理。
1. 搜索树节点 (Search Tree Node):决策点的档案
MCTS 搜索树中的每个节点都像是一个决策点的档案,记录着关键信息:
- 状态 (State, s s s): 该节点代表的游戏局面。
- 父节点 (Parent): 指向上一层决策的节点。
- 子节点 (Children) : 一个列表或字典,存储从当前状态 s s s 执行各个合法行动 a a a 后能到达的子状态节点。
- 访问次数 (Visit Count, N N N) : 该节点(或从父节点到该节点的行动)被探索路径经过的总次数。可以记为 N ( s ) N(s) N(s) (节点的访问次数) 或 N ( s p , a ) N(s_p, a) N(sp,a) (从父状态 s p s_p sp 执行行动 a a a 的次数)。
- 总价值 (Total Value, Q Q Q) : 通过该节点(或行动)进行的模拟所获得的累计奖励总和。可以记为 Q ( s ) Q(s) Q(s) 或 Q ( s p , a ) Q(s_p, a) Q(sp,a)。这里的"价值"通常是游戏结果的数字表示,例如:赢为 +1,输为 -1,平局为 0。
重要区分:
- N ( s ) N(s) N(s) 和 Q ( s ) Q(s) Q(s) 通常指与状态 s s s 关联的统计量,即所有以状态 s s s 作为起点的模拟的统计。
- N ( s p , a ) N(s_p, a) N(sp,a) 和 Q ( s p , a ) Q(s_p, a) Q(sp,a) 通常指与行动 a a a (在父状态 s p s_p sp 下执行) 关联的统计量。
在实现中,这些信息可以存储在子节点上(代表通过该子节点状态的统计),或者存储在父节点指向子节点的边上。本笔记的实现将 N N N 和 Q Q Q 主要存储在子节点上,代表该子节点状态的整体访问和价值。
2. 探索与利用的平衡 (Exploration vs. Exploitation):智慧的权衡
MCTS 的一个核心挑战是在以下两者之间取得平衡:
- 利用 (Exploitation): 选择当前已知能带来最好平均回报的行动。
- 探索 (Exploration): 选择那些访问次数较少、潜力未知但可能更好的行动。
在 Selection (选择) 步骤中,MCTS 通常使用一种名为 UCB1 (Upper Confidence Bound 1) 的策略(或其变种)来做这个权衡。对于一个父节点 s p s_p sp,选择其子节点 s c s_c sc(通过行动 a a a 到达)的 UCB1 得分计算如下:
U ( s c ) = Q ( s c ) N ( s c ) ⏟ 利用项 (Exploitation) + c ⋅ ln N ( s p ) N ( s c ) ⏟ 探索项 (Exploration) U(s_c) = \underbrace{\frac{Q(s_c)}{N(s_c)}}{\text{利用项 (Exploitation)}} + \underbrace{c \cdot \sqrt{\frac{\ln N(s_p)}{N(s_c)}}}{\text{探索项 (Exploration)}} U(sc)=利用项 (Exploitation) N(sc)Q(sc)+探索项 (Exploration) c⋅N(sc)lnN(sp)
让我们拆解这个公式:
-
Q ( s c ) N ( s c ) \frac{Q(s_c)}{N(s_c)} N(sc)Q(sc) (利用项) : 这是子节点 s c s_c sc 的平均价值或平均回报。它代表了根据已有经验,选择这个子节点有多好。
- Q ( s c ) Q(s_c) Q(sc) 是通过子节点 s c s_c sc 的所有模拟获得的累计价值。
- N ( s c ) N(s_c) N(sc) 是子节点 s c s_c sc 被访问的总次数。
- 注意 : 这里的 Q ( s c ) Q(s_c) Q(sc) 通常是从子节点 s c s_c sc 的对手 角度看的价值的相反数,或者更准确地说,是父节点 s p s_p sp 的当前玩家选择行动 a a a 达到子节点 s c s_c sc 后,期望获得的回报。如果 Q ( s c ) Q(s_c) Q(sc) 存储的是子节点 s c s_c sc 状态下当前玩家的累计价值,那么父节点在选择时,如果父子节点的玩家不同,就需要对 Q ( s c ) Q(s_c) Q(sc) 取负号。我们的实现中,
value
存储的是从该节点玩家视角的累计回报,所以在select_child
中会用-child.value / child.visits
。
-
c ⋅ ln N ( s p ) N ( s c ) c \cdot \sqrt{\frac{\ln N(s_p)}{N(s_c)}} c⋅N(sc)lnN(sp) (探索项):
- N ( s p ) N(s_p) N(sp) 是父节点 s p s_p sp 的总访问次数。父节点被访问得越多,我们对它的子节点的评估就越依赖于它们的具体表现,同时也更有信心去探索那些访问次数少的子节点。 ln N ( s p ) \ln N(s_p) lnN(sp) 使得探索的倾向随着父节点访问次数的增加而缓慢增长。
- N ( s c ) N(s_c) N(sc) 是子节点 s c s_c sc 的总访问次数。如果一个子节点 N ( s c ) N(s_c) N(sc) 很小(访问很少),那么分母小,这一项就大,从而鼓励选择这个子节点进行探索。
- 随着 N ( s c ) N(s_c) N(sc) 增加,探索项会减小,探索的优先级降低。
-
c c c (探索常数): 这是一个超参数,用来平衡利用项和探索项的权重。
- c > 0 c > 0 c>0。如果 c = 0 c=0 c=0,则 MCTS 只会利用,不进行探索。
- c c c 越大,算法越倾向于探索未知的子节点。
- 理论上 c = 2 c = \sqrt{2} c=2 是一个常见的选择,但在实践中可能需要根据具体问题进行调整。
特殊情况 : 如果一个子节点从未被访问过 ( N ( s c ) = 0 N(s_c) = 0 N(sc)=0),它的探索项理论上是无穷大。在实现时,这通常意味着未被访问过的子节点会被优先选择(在 Expansion 步骤中被创建和首次访问)。
3. 四个核心步骤详解
MCTS 的一次迭代(或一次"思考")由以下四个步骤构成:
-
Selection (选择):
- 目标: 从根节点(当前真实游戏状态)开始,沿着树向下走,直到找到一个"有潜力"的叶子节点。
- 过程 :
- 从当前节点开始,如果它不是叶子节点(即有子节点),则使用 UCB1 公式(或其他选择策略)计算其所有子节点的得分。
- 选择得分最高的子节点,移动到该子节点,并重复此过程。
- 结束条件 : 直到到达以下任一类型的节点:
- 未完全扩展的节点: 该节点代表一个游戏状态,但它的一些合法行动还没有在搜索树中生成对应的子节点。
- 叶子节点 (Terminal Node): 该节点代表游戏结束的状态(分出胜负或平局)。
- 叶子节点 (Non-Terminal Leaf): 该节点尚未扩展任何子节点(在树的边界上,但游戏未结束)。
-
Expansion (扩展):
- 目标: 如果 Selection 步骤停止在一个未完全扩展且游戏未结束的节点,就在这里为树添加一个新的子节点。
- 过程 :
- 从该节点选择一个之前未被探索过的合法行动 (例如,随机选择一个未生成子节点的行动)。
- 执行这个行动,得到一个新的游戏状态。
- 为这个新状态创建一个新的子节点,并将其添加到搜索树中,作为当前节点的子节点。
- 新创建的节点通常会被初始化: N = 0 , Q = 0 N=0, Q=0 N=0,Q=0。
-
Simulation (模拟 / Rollout):
- 目标: 快速估计新扩展节点(或 Selection 选中的叶子节点)的价值。
- 过程 :
- 从 Expansion 步骤新创建的节点(或者,如果 Selection 到达的是一个已完全扩展但非终止的叶子节点,则从该节点)所代表的状态开始。
- 进行一次快速、通常是随机的模拟游戏,直到游戏结束。这意味着在这个阶段,选择行动通常不依赖于复杂的评估,而是随机选择合法行动,或者使用一个非常轻量级的启发式策略(例如,在棋类游戏中尝试吃子或避免被吃)。
- 结果: 记录模拟游戏的最终结果(例如,发起模拟的玩家是赢、输还是平局)。这个结果会被转换成一个数值奖励(例如,赢为+1,输为-1,平局为0)。
-
Backpropagation (回传 / 更新):
- 目标: 将 Simulation 阶段获得的模拟结果(奖励)反馈给从模拟开始节点一直到根节点的路径上的所有节点。
- 过程 :
- 从模拟开始的那个节点(即 Expansion 步骤创建的新节点,或 Selection 选中的叶子节点)开始,向上回溯到根节点。
- 路径上的每一个节点 (包括它自己):
- 其访问次数 N N N 增加 1。
- 其总价值 Q Q Q 累加 Simulation 得到的奖励 。
- 重要 : Q Q Q 值的更新需要考虑玩家视角。如果模拟结果是针对发起模拟的玩家 P 来说的(例如 P 赢了,奖励为+1),那么在回传时:
- 如果路径上的某个节点代表轮到玩家 P 行动,那么该节点的 Q Q Q 值增加奖励 (+1)。
- 如果路径上的某个节点代表轮到玩家 P 的对手行动,那么该节点的 Q Q Q 值增加负奖励 (-1),因为对手的收益就是我方的损失。
- 平局则对双方都是 0。
- 重要 : Q Q Q 值的更新需要考虑玩家视角。如果模拟结果是针对发起模拟的玩家 P 来说的(例如 P 赢了,奖励为+1),那么在回传时:
这四个步骤 (Selection, Expansion, Simulation, Backpropagation) 构成一次完整的 MCTS 迭代。算法会重复执行成千上万次这样的迭代,不断地扩展和优化搜索树。
4. MCTS 算法流程 (伪代码)
Function MCTS_Search(root_state, num_iterations, exploration_constant_C):
// 创建根节点,代表游戏的初始状态
root_node = Create_Node(state=root_state, parent=null)
For i from 1 to num_iterations:
current_node = root_node
// 1. Selection (选择)
// 持续向下选择,直到遇到叶子节点或未完全扩展的节点
While current_node is fully_expanded AND current_node is not terminal:
current_node = Select_Best_Child_UCB1(current_node, exploration_constant_C)
// 2. Expansion (扩展)
// 如果当前节点不是终止节点且没有被完全扩展,则扩展一个新的子节点
If current_node is not terminal AND current_node is not fully_expanded:
unexplored_action = Choose_Unexplored_Action(current_node)
new_state = Apply_Action(current_node.state, unexplored_action)
current_node = Add_Child(parent=current_node, action=unexplored_action, state=new_state)
// 'current_node' 现在是新扩展的子节点
// 3. Simulation (模拟)
// 从 'current_node' (新扩展的节点或Selection选中的叶子节点) 开始进行随机模拟
// simulation_result 是从 current_node.state 的当前玩家视角看的游戏结果 (+1, -1, 0)
simulation_result = Simulate_Random_Playout(current_node.state)
// 4. Backpropagation (回传)
// 从 'current_node' 开始,向上回传结果到根节点
node_for_backprop = current_node
While node_for_backprop is not null:
node_for_backprop.visits += 1
// 更新价值时,要考虑回传到的节点的当前玩家与模拟结果的玩家是否一致
// 如果 simulation_result 是针对发起模拟的玩家的,
// 那么对于路径上与发起模拟玩家相同的节点,value += simulation_result
// 对于路径上对手的节点,value -= simulation_result (或 += -simulation_result)
// 或者,更简单地,如果simulation_result是游戏最终赢家 (1, 2, 或 3表示平局)
// 那么在每个节点,根据该节点的玩家判断是+1, -1, 还是0
Update_Value(node_for_backprop, simulation_result)
node_for_backprop = node_for_backprop.parent
// 所有迭代完成后,从根节点的子节点中选择最终的行动
// 通常选择访问次数最多的子节点对应的行动,这被认为是最稳健的选择
best_action = Get_Best_Action_From_Root(root_node, criteria="most_visits") // 或 "highest_value_ratio"
Return best_action
三、MCTS 代码实现 (Python):以井字棋为例
为了更好地理解 MCTS,我们将用一个简单的游戏------井字棋 (Tic-Tac-Toe)------来实现它。井字棋虽然简单,但足以展示 MCTS 的核心机制。
1. 井字棋环境 (Tic-Tac-Toe Environment)
首先,我们需要一个类来表示井字棋的游戏状态和规则。
python
import numpy as np
import math
import random
import time # 用于控制MCTS运行时间
class TicTacToeBoard:
def __init__(self):
# 棋盘: 3x3, 0: 空, 1: 玩家1 (X), 2: 玩家2 (O)
self.board = np.zeros((3, 3), dtype=int)
self.current_player = 1 # 玩家1先手
self.winner = 0 # 0: 游戏进行中, 1: 玩家1赢, 2: 玩家2赢, 3: 平局
def get_board_tuple(self):
# 返回棋盘状态的元组形式,用于哈希
return tuple(self.board.flatten())
def is_game_over(self):
if self.winner != 0: # 如果已经有结果,直接返回
return True
# 检查行、列、对角线是否有玩家获胜
for player in [1, 2]:
# 检查行
for r in range(3):
if np.all(self.board[r, :] == player):
self.winner = player
return True
# 检查列
for c in range(3):
if np.all(self.board[:, c] == player):
self.winner = player
return True
# 检查对角线
if np.all(np.diag(self.board) == player) or \
np.all(np.diag(np.fliplr(self.board)) == player):
self.winner = player
return True
# 检查是否平局 (棋盘已满且无胜者)
if np.all(self.board != 0) and self.winner == 0:
self.winner = 3 # 3 代表平局
return True
return False # 游戏未结束
def get_legal_actions(self):
if self.is_game_over():
return [] # 游戏结束则无合法行动
# 合法行动是棋盘上所有空位 (row, col)
return [(r, c) for r in range(3) for c in range(3) if self.board[r, c] == 0]
def apply_action(self, action):
# 创建一个新的棋盘状态副本,以保持原状态不变 (MCTS中常用)
new_board_state = TicTacToeBoard()
new_board_state.board = np.copy(self.board)
new_board_state.current_player = self.current_player # 之后会切换
r, c = action
if new_board_state.board[r, c] != 0:
raise ValueError(f"Invalid action: cell {action} is not empty.")
new_board_state.board[r, c] = self.current_player # 当前玩家落子
# 切换到下一个玩家
new_board_state.current_player = 3 - self.current_player # (1 -> 2, 2 -> 1)
# 检查新状态下游戏是否结束 (这将更新 new_board_state.winner)
new_board_state.is_game_over()
return new_board_state
def get_reward_for_player(self, player_perspective):
# 从 player_perspective 的视角获取游戏结果的奖励
if not self.is_game_over():
return 0 # 游戏未结束,通常奖励为0 (除非有中间奖励)
if self.winner == player_perspective:
return 1 # 赢了
elif self.winner == 3: # 平局
return 0
elif self.winner != 0: # 输了 (另一个玩家赢了)
return -1
return 0 # 理论上不应到达这里
def __eq__(self, other):
# 用于比较两个棋盘状态是否相同 (MCTS节点查找需要)
return isinstance(other, TicTacToeBoard) and \
np.array_equal(self.board, other.board) and \
self.current_player == other.current_player
def __hash__(self):
# 用于将棋盘状态作为字典的键 (MCTS节点查找需要)
# 状态不仅包括棋盘,还包括当前轮到谁
return hash((self.get_board_tuple(), self.current_player))
def __str__(self):
chars = {0: ' ', 1: 'X', 2: 'O'}
s = " 0 1 2\n"
for r in range(3):
s += f"{r} " + "|".join([chars[self.board[r, c]] for c in range(3)]) + "\n"
if r < 2:
s += " -+-+-\n"
s += f"Current Player: {chars[self.current_player] if not self.is_game_over() else 'None (Game Over)'}\n"
if self.is_game_over():
if self.winner == 1: s += "Winner: X\n"
elif self.winner == 2: s += "Winner: O\n"
elif self.winner == 3: s += "Result: Draw\n"
return s
def __repr__(self):
return self.__str__()
2. MCTS 节点类 (MCTS Node Class)
这个类代表搜索树中的一个节点。
python
class MCTSNode:
def __init__(self, state: TicTacToeBoard, parent=None, action_that_led_to_this_state=None):
self.state: TicTacToeBoard = state # 该节点代表的游戏状态
self.parent: MCTSNode = parent # 父节点
self.action_that_led_to_this_state = action_that_led_to_this_state # 到达此状态的行动
self.children: list[MCTSNode] = [] # 子节点列表
# _unexplored_actions 存储的是从当前 state 出发,尚未在 self.children 中创建节点的合法行动
self._unexplored_actions = self.state.get_legal_actions() # 初始化为所有合法行动
self.visits: int = 0 # N: 该节点被访问的次数
self.value: float = 0.0 # Q: 该节点累计获得的价值 (从该节点玩家的视角)
def is_fully_expanded(self) -> bool:
# 如果所有合法行动都已经被探索(即都有对应的子节点了)
return len(self._unexplored_actions) == 0
def is_terminal_node(self) -> bool:
# 检查该节点代表的状态是否是游戏结束状态
return self.state.is_game_over()
def select_best_child(self, exploration_constant: float) -> 'MCTSNode':
# 使用 UCB1 公式选择子节点
# 仅当节点已完全扩展且非终止时调用
if not self.children:
#raise ValueError("Cannot select child from a node with no children.") # 或者返回None,由调用者处理
return None
best_score = -float('inf')
best_child = None
for child in self.children:
if child.visits == 0:
# 如果一个子节点从未被访问过,理论上其探索项是无穷大
# 实际中,这通常意味着它会在Expansion阶段被选中,而不是在这里
# 但为安全起见,给它一个非常高的分数,或者在主循环中优先处理
# 在标准MCTS中,select_child 通常在节点完全扩展后调用,
# 而节点完全扩展意味着每个子节点至少被访问过一次(在Expansion和Backprop中)
# 如果仍有visits=0的子节点,说明逻辑可能有问题或迭代次数太少
# 为简单起见,这里假设visits > 0,或在实践中优先选择visits=0的节点
# 本实现中,Expansion会确保新节点被访问并回传,所以visits至少为1
# 如果一个孩子真的visits=0(例如,选择策略允许跳过某些孩子),那么它的UCB值应该是无限大
# 这里直接返回这个孩子,因为它有最高的探索优先级
return child # 优先探索未访问的(或极少访问的)
# 利用项: Q(child) / N(child)
# Q(child) 是从 child 节点的当前玩家视角看的价值。
# 但我们是从 parent 节点(当前节点 self)的视角做选择。
# parent 和 child 的当前玩家是不同的(轮流下棋)。
# 所以,child 节点的价值对于 parent 来说是负的。
exploitation_score = -child.value / child.visits # 注意这个负号!
# 探索项: c * sqrt(log(N(parent)) / N(child))
exploration_score = exploration_constant * math.sqrt(math.log(self.visits) / child.visits)
ucb_score = exploitation_score + exploration_score
if ucb_score > best_score:
best_score = ucb_score
best_child = child
if best_child is None and self.children: # 以防万一,例如所有孩子visits=0
return random.choice(self.children)
return best_child
def expand(self) -> 'MCTSNode':
# 从未探索的行动中选择一个,创建新的子节点
if not self._unexplored_actions:
raise RuntimeError("Cannot expand a fully expanded node.")
action = self._unexplored_actions.pop() # 随机选择并移除一个未探索的行动
next_state = self.state.apply_action(action)
new_child_node = MCTSNode(next_state, parent=self, action_that_led_to_this_state=action)
self.children.append(new_child_node)
return new_child_node
def update_stats(self, simulation_result_winner: int):
# 回传时更新节点的访问次数和价值
# simulation_result_winner: 模拟游戏的获胜方 (1, 2, or 3 for draw)
self.visits += 1
# 价值更新基于当前节点的玩家视角
# 如果当前节点的玩家 (self.state.current_player) 是发起模拟的玩家
# 那么模拟结果的输赢对它的影响是直接的
# 注意:这里的 self.state.current_player 是 *轮到谁下下一步* 的玩家。
# 而 value 应该是累积的、对于 *导致了这个状态的那个行动的玩家* 的价值。
# 一个更清晰的方式是:value 累积的是从该节点状态开始的模拟中,
# *该节点状态的当前玩家* 的平均回报。
# 假设 simulation_result_winner 是游戏绝对的赢家 (1, 2, 或 3=平局)
# 我们需要将这个结果转换为当前节点 (self.state) 的当前玩家 (self.state.current_player) 的回报
# 例如,如果 self.state.current_player 是玩家1:
# - 如果 winner 是 1, 玩家1的回报是 +1
# - 如果 winner 是 2, 玩家1的回报是 -1
# - 如果 winner 是 3 (平局), 玩家1的回报是 0
reward = 0
if simulation_result_winner == self.state.current_player:
reward = 1 # 当前节点的玩家赢了模拟
elif simulation_result_winner != 0 and simulation_result_winner != 3: # 对手赢了
reward = -1 # 当前节点的玩家输了模拟
# 平局 (winner=3) 或游戏未结束 (winner=0, 不应在模拟结束时发生) reward = 0
self.value += reward
3. MCTS 类 (MCTS Class)
这是实现 MCTS 算法主要逻辑的类。
python
class MCTS:
def __init__(self, exploration_constant=math.sqrt(2)):
self.exploration_constant = exploration_constant # UCB1的探索常数 C
def search(self, initial_state: TicTacToeBoard, num_iterations=None, time_limit_seconds=None):
"""
执行MCTS搜索。
可以基于迭代次数或时间限制。
"""
self.root = MCTSNode(initial_state)
if num_iterations is None and time_limit_seconds is None:
raise ValueError("Must provide either num_iterations or time_limit_seconds.")
if num_iterations is not None:
for _ in range(num_iterations):
self._perform_one_iteration()
else: # time_limit_seconds is not None
start_time = time.time()
while time.time() - start_time < time_limit_seconds:
self._perform_one_iteration()
# 搜索结束后,选择最佳行动
return self._get_best_action_from_root()
def _perform_one_iteration(self):
# 1. Selection
node = self._select_node(self.root)
# 2. Expansion
# 如果选中的节点不是终止节点,并且可以扩展,则扩展它
if not node.is_terminal_node():
if not node.is_fully_expanded():
node = node.expand() # node 现在是新扩展的子节点
# else: node is fully expanded but not terminal, simulate from here
# 3. Simulation
# simulation_result_winner 是绝对的赢家 (1, 2, 或 3=平局)
simulation_result_winner = self._simulate_random_playout(node.state)
# 4. Backpropagation
self._backpropagate(node, simulation_result_winner)
def _select_node(self, node: MCTSNode) -> MCTSNode:
# 从根节点开始,递归选择或直到叶节点
current_node = node
while not current_node.is_terminal_node():
if not current_node.is_fully_expanded():
return current_node # 返回未完全扩展的节点,下一步进行Expansion
else:
# 如果完全扩展,使用UCB1选择最佳子节点
selected_child = current_node.select_best_child(self.exploration_constant)
if selected_child is None: # 可能发生在迭代初期,根节点还没有子节点
return current_node # 或者如果select_best_child在没有孩子时返回None
current_node = selected_child
return current_node # 到达终止节点或一个无法选择子节点的叶节点
def _simulate_random_playout(self, state: TicTacToeBoard) -> int:
# 从给定状态开始,进行随机模拟直到游戏结束
# 返回获胜方 (1, 2, or 3 for draw)
# 创建一个临时状态用于模拟,不修改原始节点状态
current_simulation_state = TicTacToeBoard()
current_simulation_state.board = np.copy(state.board)
current_simulation_state.current_player = state.current_player
current_simulation_state.winner = state.winner # 继承当前赢家状态
# 如果传入的状态本身就是结束状态,直接返回结果
if current_simulation_state.is_game_over():
return current_simulation_state.winner
while not current_simulation_state.is_game_over():
legal_actions = current_simulation_state.get_legal_actions()
if not legal_actions: # 万一出现意外情况
# 这理论上不应该发生,因为 is_game_over 会先捕获
# 如果棋盘满了但没判断出胜负,is_game_over会设为平局
# 此处可以认为是一种特殊平局或错误状态
return 3 # 假设为平局
random_action = random.choice(legal_actions)
current_simulation_state = current_simulation_state.apply_action(random_action)
return current_simulation_state.winner # 返回最终的赢家
def _backpropagate(self, node: MCTSNode, simulation_result_winner: int):
# 从模拟开始的节点向上回传结果
temp_node = node
while temp_node is not None:
temp_node.update_stats(simulation_result_winner)
temp_node = temp_node.parent
def _get_best_action_from_root(self):
# 从根节点的子节点中选择最佳行动
# 通常选择访问次数最多的子节点 (最稳健 robust child)
# 或者选择价值比例最高的子节点 (可能更激进 max child)
if not self.root.children:
# 如果根节点没有子节点(例如,游戏一开始就结束了,或者迭代次数为0)
# 或者合法动作很少,迭代很少的情况下
legal_actions = self.root.state.get_legal_actions()
if legal_actions: return random.choice(legal_actions) # 没有探索就随机选一个
return None # 没有合法动作
most_visited_child = None
max_visits = -1
for child in self.root.children:
if child.visits > max_visits:
max_visits = child.visits
most_visited_child = child
if most_visited_child:
return most_visited_child.action_that_led_to_this_state
else: # 如果所有孩子访问次数都是0 (例如迭代次数很少)
if self.root.children:
# 随机选一个被扩展的,或者干脆从合法行动里随机选
return random.choice(self.root.children).action_that_led_to_this_state
else: # 没有孩子被扩展
legal_actions = self.root.state.get_legal_actions()
if legal_actions: return random.choice(legal_actions)
return None
4. 使用示例:MCTS vs 人类玩家
python
def play_game():
board = TicTacToeBoard()
mcts_player = MCTS(exploration_constant=math.sqrt(2)) # 可以调整探索常数
player_map = {1: "X (MCTS)", 2: "O (Human)"}
while not board.is_game_over():
print("\nCurrent board:")
print(board)
current_player_name = player_map[board.current_player]
print(f"Turn for {current_player_name}")
if board.current_player == 1: # MCTS's turn
print("MCTS is thinking...")
# MCTS 根据迭代次数或时间来决定思考深度
# action = mcts_player.search(board, num_iterations=10000)
action = mcts_player.search(board, time_limit_seconds=1) # 例如思考1秒
if action is None:
print("MCTS found no action (should not happen in TicTacToe unless game over).")
break
print(f"MCTS (X) chose: {action}")
else: # Human's turn
legal_actions = board.get_legal_actions()
print(f"Legal moves: {legal_actions}")
while True:
try:
move_str = input("Enter your move as 'row,col' (e.g., '1,1' for center): ")
r_str, c_str = move_str.split(',')
action = (int(r_str), int(c_str))
if action in legal_actions:
break
else:
print("Invalid move. Try again.")
except ValueError:
print("Invalid input format. Use 'row,col'.")
board = board.apply_action(action)
print("\n--- Game Over ---")
print(board)
if board.winner == 1:
print("MCTS (X) wins!")
elif board.winner == 2:
print("Human (O) wins!")
elif board.winner == 3:
print("It's a draw!")
else:
print("Game ended with an unexpected winner code.")
if __name__ == "__main__":
play_game()
代码实现的关键点回顾:
TicTacToeBoard
: 提供了游戏的基本框架:状态表示、合法行动、状态转移、结束判断、奖励。注意apply_action
返回新状态,保持不可变性。MCTSNode
:value
: 存储的是从当前节点状态的当前玩家视角看的累计回报。select_best_child
: UCB1公式中,子节点的平均价值child.value / child.visits
前面加了负号,因为父节点和子节点的当前玩家是不同的。父节点希望最大化自己的回报,这对应于子节点(对手回合)回报的最小化。update_stats
: 根据模拟的绝对赢家,计算出对当前节点玩家的奖励 (+1, -1, 0),并更新value
。
MCTS
:_select_node
: 沿着树向下,优先扩展未完全扩展的节点,否则用UCB1选择。_simulate_random_playout
: 纯随机模拟,返回绝对赢家。_backpropagate
: 将绝对赢家信息逐层向上传递,每个节点根据自己的当前玩家更新统计。_get_best_action_from_root
: 通常选择访问次数最多的子动作,这被认为是更稳健的选择。
四、MCTS 与 TRPO/PPO 等策略优化算法的简要对比
虽然 MCTS 和 TRPO (Trust Region Policy Optimization) / PPO (Proximal Policy Optimization) 等算法都用于决策制定,但它们的性质和适用场景有显著差异:
特性 | MCTS (蒙特卡洛树搜索) | TRPO/PPO (策略梯度方法) |
---|---|---|
核心思想 | 在线规划,通过模拟和树搜索评估当前状态下的行动价值。 | 学习一个策略函数 (通常是神经网络),直接输出行动或行动概率。 |
模型需求 | 需要一个前向模型 (Forward Model) :即能够模拟游戏如何从状态 s s s 执行行动 a a a 后到达状态 s ′ s' s′,并判断游戏是否结束、谁赢了。不需要可微分模型。 | 通常是无模型 (Model-Free) 的 (不直接学习环境模型),但需要与环境交互收集经验。策略和价值函数通常是可微分的。 |
学习/规划 | 主要是规划 (Planning) 算法,在给定当前状态时,通过搜索来决定最佳行动。它本身不"学习"一个通用的策略。 | 主要是学习 (Learning) 算法,通过与环境交互和优化目标函数来改进策略参数,目标是学习一个在各种状态下都能良好表现的通用策略。 |
在线/离线 | 在线决策。每次需要做决策时,都从当前状态开始构建/扩展搜索树。 | TRPO/PPO 及其变种可以是在线策略 (On-policy) 或离线策略 (Off-policy),但通常需要大量样本进行训练。 |
数据使用 | 模拟产生的数据用于即时更新树中节点的统计值,指导当前搜索。 | 从与环境交互中收集的轨迹 (episodes) 用于计算策略梯度,更新策略网络。 |
适用场景 | 特别擅长具有明确规则、离散行动空间、可以进行快速模拟的游戏(如棋类、牌类)。AlphaGo 系列是著名成功案例。 | 广泛用于连续控制任务(如机器人)、复杂游戏(作为AI训练算法,学习通用策略),以及其他需要学习行为策略的强化学习问题。 |
探索机制 | 通过 UCB1 等机制在选择阶段平衡探索与利用。 | 策略本身可以是随机的(例如输出行动的概率分布),或者在训练中加入噪声、熵正则化等手段鼓励探索。 |
计算成本 | 决策时的计算成本取决于允许的迭代次数/时间。迭代越多,决策越好但越慢。 | 训练阶段计算成本高。一旦训练完成,决策(推理)通常很快。 |
简单来说:
- MCTS 更像是一个"深思熟虑的棋手",在每一步棋前都会模拟很多可能性来决定当前最好的一步。它需要知道游戏的规则才能进行模拟。
- TRPO/PPO 更像是一个"经验丰富的运动员",通过大量的练习(与环境交互)学习到在各种情况下应该如何行动的直觉(策略网络)。它不一定需要知道规则的细节,只需要知道行动的结果。
在一些高级系统中(如 AlphaZero),MCTS 和深度学习(类似 PPO 中的策略/价值网络)被结合起来:神经网络指导 MCTS 的选择和扩展阶段(代替纯随机模拟或简单启发式),而 MCTS 的搜索结果反过来又用于生成更高质量的数据来训练神经网络,形成一个强大的自我提升循环。
五、总结与展望
MCTS 是一种非常强大且灵活的决策算法,其核心优势在于:
- 无需领域知识启发式: 基本的 MCTS 只需要游戏规则(用于模拟)即可工作,不需要复杂的评估函数。
- 非对称树增长: 它会集中计算资源探索更有希望的分支,而不是均匀扩展整个搜索树。
- 随时可用性: 可以在任何时候停止并返回当前最佳决策,计算时间越长,决策质量通常越高。
- 易于并行化: MCTS 的多次模拟在很大程度上是独立的,这使得它适合并行计算以加速搜索。
MCTS 的成功(尤其是在围棋领域的突破)证明了它处理复杂决策问题的能力。通过与深度学习等技术结合,MCTS 的潜力被进一步放大,使其成为现代人工智能研究中一个持续活跃且富有成果的领域。对于想要解决规划、游戏AI或其他序贯决策问题的开发者和研究者来说,理解和掌握 MCTS 无疑是一项宝贵的技能。