蒙特卡洛树搜索方法实践

一、算法概述

蒙特卡洛树搜索是一种用于决策过程的启发式搜索算法,特别适用于具有巨大状态空间的游戏和优化问题。其主要结合了:蒙特卡洛方法:通过随机采样来估算复杂问题的解;树搜索:将决策问题建模为树结构;UCB(Upper Confidence Bound):平衡探索与利用的选择策略。

蒙特卡洛树基于两个基本概念:一是可以使用随机模拟来近似某个行动的真实价值;二是可以有效地利用这些价值将策略调整为最佳优先策略。该算法在对博弈树先前探索结果的引导下,逐步构建一个部分博弈树。这棵树用于估计走法的价值,随着树的构建,这些估计值会变得更加准确。

基本算法包括迭代地构建搜索树,直到达到某个预定义的计算预算(通常的是时间、内存或迭代次数限制)。此时,搜索停止,并返回性能最优的根动作。搜索树中的每个节点代表领域的一个状态,指向子节点的有向链接代表通向后续状态的动作。每次搜索迭代应用四个步骤。

(1)选择:从根节点开始,递归应用子节点选择策略,沿着树向下搜索,直到找到最急需扩展的节点。如果一个节点代表非终止状态且有未访问(即未扩展)的子节点,则该节点是可扩展的。

(2)扩展:根据可用的操作,添加一个(或多个)子节点来扩展树。

(3)模拟:根据默认策略从新节点开始进行模拟,以产生一个结果。

(4)反向传播:模拟结果通过选定的节点"回溯"(即反向传播),以更新这些节点的统计信息。

这些可以分为两种截然不同的政策。

(1)树策略:从搜索树中已有的节点中选择或创建一个叶节点。

(2)默认策略:从给定的非终止状态开始执行该领域的操作,以生成价值估计(模拟)。

反向传播步骤本身不使用策略,而是更新节点统计信息,这些信息会为未来的树策略决策提供依据。这些步骤在伪代码Algorithms1中总结如下:这里,是根节点,对应初始状态是树策略阶段最后到达的节点,对应状态;是从状态开始,用默认策略模拟到终局后获得的奖励。整个MCTS搜索的结果a是指:在根节点的所有子节点中,选择"最优"的那个动作

请注意,文献中对"模拟"这一术语存在不同的解释。一些作者认为它指的是在树策略和默认策略下每次迭代所选择的完整动作序列,而大多数作者认为它仅指使用默认策略所选择的动作序列。在本文中,我们将把"推演"和"模拟"这两个术语理解为"根据默认策略将任务执行至完成",即树策略的选择和扩展步骤完成后所选择的动作。

二、详细案例分析

level代表蒙特卡洛树决策的层级,即程序会连续做几次决策

这个案例实现了一个数值优化游戏:游戏有10轮,每轮可以从中选择一个数字,目标是让累积值尽可能接近0。一共在10轮里做出10个决策,选择出10个数字。

(1)State类------游戏状态

关键特点:(1)value:当前累积值。(2)turn:剩余轮数。(3)moves:已执行的移动序列

移动机制解析:(1)当前轮数越高,移动的影响越大。(2)第1轮可选:,第十轮可选:,早期错误代价更高。

python 复制代码
class State():
	NUM_TURNS = 10
	GOAL = 0
	MOVES=[2,-2,3,-3]
	MAX_VALUE= (5.0*(NUM_TURNS-1)*NUM_TURNS)/2
	num_moves=len(MOVES)
	def __init__(self, value=0, moves=[], turn=NUM_TURNS):
		self.value=value
		self.turn=turn
		self.moves=moves
	def next_state(self):
		nextmove=random.choice([x*self.turn for x  in self.MOVES])
		next=State(self.value+nextmove, self.moves+[nextmove],self.turn-1)
		return next
	def terminal(self):
		if self.turn == 0:
			return True
		return False
	def reward(self):
		r = 1.0-(abs(self.value-self.GOAL)/self.MAX_VALUE)
		return r
	def __hash__(self):
		return int(hashlib.md5(str(self.moves).encode('utf-8')).hexdigest(),16)
	def __eq__(self,other):
		if hash(self)==hash(other):
			return True
		return False
	def __repr__(self):
		s="Value: %d; Moves: %s"%(self.value,self.moves)
		return s
	def node_id(self):
		return str(hash(self))
	def node_info(self):
		return f"Value:{self.value}\nMoves:{self.moves}"

逐行解析每一步,NUM_TURN定义了游戏一共有10轮,GOAL定义了目标的值,MOVES定义了动作空间,MAX_VALUE定义了最大的值。num_moves为动作空间的大小.

第一个初始化目标的value值,turn为轮次,moves为移动空间。next_state目标是从目前节点随机选择下一个节点,并且更新到下个节点的状态。terminal就是终止条件。reward就是奖励函数。hash__和__eq,Python对象才能放进set、当作dict的key,才能高效查重。例如

s1=State(value=0,moves=[2,-2],turn=8),s2==State(value=0,moves=[2,-2],turn=8)有了__eq__和__hash__能够快速判断这个新状态是不是已经作为子节点存在。

(2)Node类------MCTS节点

介绍一下TreePolicy方法,一共分为四种情况。

python 复制代码
class Node():
	def __init__(self, state, parent=None):
		self.visits=1
		self.reward=0.0
		self.state=state
		self.children=[]
		self.parent=parent
	def add_child(self,child_state):
		child=Node(child_state,self)
		self.children.append(child)
	def update(self,reward):
		self.reward+=reward
		self.visits+=1
	def fully_expanded(self, num_moves_lambda):
		num_moves = self.state.num_moves
		if num_moves_lambda != None:
			num_moves = num_moves_lambda(self)
		if len(self.children)==num_moves:
			return True
		return False
	def __repr__(self):
		s="Node; children: %d; visits: %d; reward: %f"%(len(self.children),self.visits,self.reward)
		return s
	def node_info(self):
		return f"N:{self.visits}\nR:{self.reward:.2f}\n{self.state.moves}"

初始化节点,仿真次数为1,奖励为0,状态为状态,子节点为空,父节点。

add_child为添加子节点,它的状态为输入,父节点为当前节点本身。然后将刚创建的子节点加入children离去。

update为更新该节点的奖励,以及访问次数加一次。

fully_expanded为判断该节点是否展开了子节点。

(3)MCTS关键算法

treepolicy:树策略

python 复制代码
def TREEPOLICY(node, num_moves_lambda):
	#a hack to force 'exploitation' in a game where there are many options, and you may never/not want to fully expand first
	while node.state.terminal()==False:
		if len(node.children)==0:
			return EXPAND(node)
		elif random.uniform(0,1)<.5:
			node=BESTCHILD(node,SCALAR)
		else:
			if node.fully_expanded(num_moves_lambda)==False:
				return EXPAND(node)
			else:
				node=BESTCHILD(node,SCALAR)
	return node
def EXPAND(node):
	tried_children=[c.state for c in node.children]
	new_state=node.state.next_state()
	while new_state in tried_children and new_state.terminal()==False:
		new_state=node.state.next_state()
	node.add_child(new_state)
	return node.children[-1]

#current this uses the most vanilla MCTS formula it is worth experimenting with THRESHOLD ASCENT (TAGS)
def BESTCHILD(node,scalar):
	bestscore=0.0
	bestchildren=[]
	for c in node.children:
		exploit=c.reward/c.visits
		explore=math.sqrt(2.0*math.log(node.visits)/float(c.visits))
		score=exploit+scalar*explore
		if score==bestscore:
			bestchildren.append(c)
		if score>bestscore:
			bestchildren=[c]
			bestscore=score
	if len(bestchildren)==0:
		logger.warn("OOPS: no best child found, probably fatal")
	return random.choice(bestchildren)

1.首次访问节点(没有子节点),其动作为直接展开,创建第一个子节点。例如第一轮选择动作20。

2.50%概率利用已有信息,使用UCB公式选择最有希望的子节点,根据之前的模拟结果,发现选择-20的子节点表现最好,就选择它。

3.50%概率继续探索,场景两种情况:未完全展开:还有动作没尝试过,创建新子节点。已完全展开:所有4个动作都试过了,选择最佳的继续向下。

DefaultPolicy:默认策略

python 复制代码
def DEFAULTPOLICY(state):
	while state.terminal()==False:
		state=state.next_state()
	return state.reward()

即如果选择子节点20,然后后面不是有9轮,通过随机模拟的情况得到该节点某种情况的回报值。

Backup:回溯

python 复制代码
def BACKUP(node,reward):
	while node!=None:
		node.visits+=1
		node.reward+=reward
		node=node.parent
	return

根据扩展计算的结果来更新该节点的历史信息,为Bestchild的选择做准备。

(4)数字案例

例如,根节点初始化状态为[],访问次数为0,奖励值函数为0。根据输入的level来判定做决策到第几层。例如level为1,则就选出第一轮最优的个数即可。例如第一次仿真开始,

通过树策略选择30,动作记为[30],通过DefaultPolicy仿真一次选择30这个动作后面10轮的一个轮次,评估奖励函数为0.44,访问次数加1。然后通过Backup将根节点的奖励值加为0.44。然后开启第二轮仿真,此时根据TreePolicy现在有50%的概率是扩展根节点,另外50%是基于现有节点的情况下选择最优节点,然后通过DefaultPolicy继续评估。经过50%的概率筛选节点扩展到[-30],根据DefaultPolicy评估到0.866奖励函数,此时加到根节点的奖励函数值里。开始第三轮仿真,此时过程扩展到[20],评估奖励为0.733,继续添加到根节点的奖励函数值里。然后开始第四轮仿真,此时根据50%的概率是根据目前已经展开的三个节点如[30],[-30],[20]节点去展开,三个节点均被访问了一次,根据公式

然后选择了目前评估函数最高的[-30],在其下面展开节点[-30,18],然后通过DefaultPolicy计算得到0.676此时将这个奖励值即加到根节点,也加到[-30]这个节点,以此类推。

完整代码

python 复制代码
#!/usr/bin/env python
import random
import math
import hashlib
import logging
import argparse
import queue
from graphviz import Digraph


"""
A quick Monte Carlo Tree Search implementation.  For more details on MCTS see See http://pubs.doc.ic.ac.uk/survey-mcts-methods/survey-mcts-methods.pdf

The State is a game where you have NUM_TURNS and at turn i you can make
a choice from an integeter [-2,2,3,-3]*(NUM_TURNS+1-i).  So for example in a game of 4 turns, on turn for turn 1 you can can choose from [-8,8,12,-12], and on turn 2 you can choose from [-6,6,9,-9].  At each turn the choosen number is accumulated into a aggregation value.  The goal of the game is for the accumulated value to be as close to 0 as possible.

The game is not very interesting but it allows one to study MCTS which is.  Some features 
of the example by design are that moves do not commute and early mistakes are more costly.  

In particular there are two models of best child that one can use 
"""

#MCTS scalar.  Larger scalar will increase exploitation, smaller will increase exploration. 
SCALAR=1/(2*math.sqrt(2.0))

logging.basicConfig(level=logging.WARNING)
logger = logging.getLogger('MyLogger')


class State():
	NUM_TURNS = 10
	GOAL = 0
	MOVES=[2,-2,3,-3]
	MAX_VALUE= (5.0*(NUM_TURNS-1)*NUM_TURNS)/2
	num_moves=len(MOVES)
	def __init__(self, value=0, moves=[], turn=NUM_TURNS):
		self.value=value
		self.turn=turn
		self.moves=moves
	def next_state(self):
		nextmove=random.choice([x*self.turn for x  in self.MOVES])
		next=State(self.value+nextmove, self.moves+[nextmove],self.turn-1)
		return next
	def terminal(self):
		if self.turn == 0:
			return True
		return False
	def reward(self):
		r = 1.0-(abs(self.value-self.GOAL)/self.MAX_VALUE)
		return r
	def __hash__(self):
		return int(hashlib.md5(str(self.moves).encode('utf-8')).hexdigest(),16)
	def __eq__(self,other):
		if hash(self)==hash(other):
			return True
		return False
	def __repr__(self):
		s="Value: %d; Moves: %s"%(self.value,self.moves)
		return s
	def node_id(self):
		return str(hash(self))
	def node_info(self):
		return f"Value:{self.value}\nMoves:{self.moves}"


class Node():
	def __init__(self, state, parent=None):
		self.visits=1
		self.reward=0.0
		self.state=state
		self.children=[]
		self.parent=parent
	def add_child(self,child_state):
		child=Node(child_state,self)
		self.children.append(child)
	def update(self,reward):
		self.reward+=reward
		self.visits+=1
	def fully_expanded(self, num_moves_lambda):
		num_moves = self.state.num_moves
		if num_moves_lambda != None:
			num_moves = num_moves_lambda(self)
		if len(self.children)==num_moves:
			return True
		return False
	def __repr__(self):
		s="Node; children: %d; visits: %d; reward: %f"%(len(self.children),self.visits,self.reward)
		return s
	def node_info(self):
		return f"N:{self.visits}\nR:{self.reward:.2f}\n{self.state.moves}"

def UCTSEARCH(budget,root,num_moves_lambda = None):
	for iter in range(int(budget)):
		if iter%10000==9999:
			logger.info("simulation: %d"%iter)
			logger.info(root)
		front=TREEPOLICY(root, num_moves_lambda)
		reward=DEFAULTPOLICY(front.state)
		BACKUP(front,reward)
	return BESTCHILD(root,0)

def TREEPOLICY(node, num_moves_lambda):
	#a hack to force 'exploitation' in a game where there are many options, and you may never/not want to fully expand first
	while node.state.terminal()==False:
		if len(node.children)==0:
			return EXPAND(node)
		elif random.uniform(0,1)<.5:
			node=BESTCHILD(node,SCALAR)
		else:
			if node.fully_expanded(num_moves_lambda)==False:
				return EXPAND(node)
			else:
				node=BESTCHILD(node,SCALAR)
	return node

def EXPAND(node):
	tried_children=[c.state for c in node.children]
	new_state=node.state.next_state()
	while new_state in tried_children and new_state.terminal()==False:
		new_state=node.state.next_state()
	node.add_child(new_state)
	return node.children[-1]

#current this uses the most vanilla MCTS formula it is worth experimenting with THRESHOLD ASCENT (TAGS)
def BESTCHILD(node,scalar):
	bestscore=0.0
	bestchildren=[]
	for c in node.children:
		exploit=c.reward/c.visits
		explore=math.sqrt(2.0*math.log(node.visits)/float(c.visits))
		score=exploit+scalar*explore
		if score==bestscore:
			bestchildren.append(c)
		if score>bestscore:
			bestchildren=[c]
			bestscore=score
	if len(bestchildren)==0:
		logger.warn("OOPS: no best child found, probably fatal")
	return random.choice(bestchildren)

def DEFAULTPOLICY(state):
	while state.terminal()==False:
		state=state.next_state()
	return state.reward()

def BACKUP(node,reward):
	while node!=None:
		node.visits+=1
		node.reward+=reward
		node=node.parent
	return

def show_search_tree(root):
    dot = Digraph(comment='Game Search Tree')
    visited = set()
    que = queue.Queue()
    que.put(root)

    while not que.empty():
        node = que.get()
        node_id = str(id(node))
        if node_id in visited:
            continue
        visited.add(node_id)

        # 添加节点
        dot.node(node_id, node.node_info())

        # 添加子节点和边
        for child in node.children:
            child_id = str(id(child))
            dot.node(child_id, child.node_info())
            dot.edge(node_id, child_id)
            que.put(child)

    with open("a.dot", "w", encoding="utf-8") as writer:
        writer.write(dot.source)
    dot.render('search_path', view=False)

if __name__=="__main__":
	parser = argparse.ArgumentParser(description='MCTS research code')
	parser.add_argument('--num_sims', action="store", required=True, type=int)
	parser.add_argument('--levels', action="store", required=True, type=int, choices=range(State.NUM_TURNS+1))
	args=parser.parse_args()

	current_node=Node(State())
	for l in range(args.levels):
		current_node=UCTSEARCH(args.num_sims/(l+1),current_node)
		print("level %d"%l)
		print("Num Children: %d"%len(current_node.children))
		for i,c in enumerate(current_node.children):
			print(i,c)
		print("Best Child: %s"%current_node.state)

		print("--------------------------------")

	# 可视化搜索树结构
	show_search_tree(current_node)
	
	

运行方式终端

python 复制代码
python mcts.py --num_sim 100 --levels 2

如果想知道10步决策的每个决策则将终端改为

python 复制代码
python mcts.py --num_sims 100 --levels 10  

得到最终最优的选择策略。

相关推荐
盛寒13 分钟前
向量与向量组的线性相关性 线性代数
线性代数·算法
学不动CV了4 小时前
C语言32个关键字
c语言·开发语言·arm开发·单片机·算法
小屁孩大帅-杨一凡5 小时前
如何解决ThreadLocal内存泄漏问题?
java·开发语言·jvm·算法
Y1nhl6 小时前
力扣_二叉树的BFS_python版本
python·算法·leetcode·职场和发展·宽度优先
向阳逐梦8 小时前
PID控制算法理论学习基础——单级PID控制
人工智能·算法
2zcode8 小时前
基于Matlab多特征融合的可视化指纹识别系统
人工智能·算法·matlab
Owen_Q8 小时前
Leetcode百题斩-二分搜索
算法·leetcode·职场和发展
UnderTheTime9 小时前
2025 XYD Summer Camp 7.10 筛法
算法
zstar-_9 小时前
Claude code在Windows上的配置流程
笔记·算法·leetcode