Pytorch深度强化学习案例:基于Q-Learning的机器人走迷宫

目录

  • [0 专栏介绍](#0 专栏介绍)
  • [1 Q-Learning算法原理](#1 Q-Learning算法原理)
  • [2 强化学习基本框架](#2 强化学习基本框架)
  • [3 机器人走迷宫算法](#3 机器人走迷宫算法)
    • [3.1 迷宫环境](#3.1 迷宫环境)
    • [3.2 状态、动作和奖励](#3.2 状态、动作和奖励)
    • [3.3 Q-Learning算法实现](#3.3 Q-Learning算法实现)
    • [3.4 完成训练](#3.4 完成训练)
  • [4 算法分析](#4 算法分析)
    • [4.1 Q-Table](#4.1 Q-Table)
    • [4.2 奖励曲线](#4.2 奖励曲线)

0 专栏介绍

本专栏重点介绍强化学习技术的数学原理,并且采用Pytorch框架对常见的强化学习算法、案例进行实现,帮助读者理解并快速上手开发。同时,辅以各种机器学习、数据处理技术,扩充人工智能的底层知识。

🚀详情:《Pytorch深度强化学习》


1 Q-Learning算法原理

Pytorch深度强化学习1-6:详解时序差分强化学习(SARSA、Q-Learning算法)介绍到时序差分强化学习是动态规划与蒙特卡洛的折中

Q π ( s t , a t ) = n 次增量 Q π ( s t , a t ) + α ( R t − Q π ( s t , a t ) )    = n 次增量 Q π ( s t , a t ) + α ( r t + 1 + γ R t + 1 − Q π ( s t , a t ) )    = n 次增量 Q π ( s t , a t ) + α ( r t + 1 + γ Q π ( s t + 1 , a t + 1 ) − Q π ( s t , a t ) ) ⏟ 采样 \begin{aligned}Q^{\pi}\left( s_t,a_t \right) &\xlongequal{n\text{次增量}}Q^{\pi}\left( s_t,a_t \right) +\alpha \left( R_t-Q^{\pi}\left( s_t,a_t \right) \right) \\\,\, &\xlongequal{n\text{次增量}}Q^{\pi}\left( s_t,a_t \right) +\alpha \left( r_{t+1}+\gamma R_{t+1}-Q^{\pi}\left( s_t,a_t \right) \right) \\\,\, &\xlongequal{n\text{次增量}}{ \underset{\text{采样}}{\underbrace{Q^{\pi}\left( s_t,a_t \right) +\alpha \left( r_{t+1}+{ \gamma Q^{\pi}\left( s_{t+1},a_{t+1} \right) }-Q^{\pi}\left( s_t,a_t \right) \right) }}}\end{aligned} Qπ(st,at)n次增量 Qπ(st,at)+α(Rt−Qπ(st,at))n次增量 Qπ(st,at)+α(rt+1+γRt+1−Qπ(st,at))n次增量 采样 Qπ(st,at)+α(rt+1+γQπ(st+1,at+1)−Qπ(st,at))

其中 r t + 1 + γ Q π ( s t + 1 , a t + 1 ) − Q π ( s t , a t ) r_{t+1}+\gamma Q^{\pi}\left( s_{t+1},a_{t+1} \right) -Q^{\pi}\left( s_t,a_t \right) rt+1+γQπ(st+1,at+1)−Qπ(st,at)称为时序差分误差。基于离轨策略的时序差分强化学习的代表性算法是Q-learning算法,其算法流程如下所示。具体的策略改进算法推导请见之前的文章,本文重点在于应用Q-learning算法解决实际问题

我们先来看看最终实现的效果

训练前

训练后

接下来详细讲解如何一步步实现这个智能体

2 强化学习基本框架

强化学习(Reinforcement Learning, RL)在潜在的不确定复杂环境中,训练一个最优决策 π \pi π指导一系列行动实现目标最优化的机器学习方法。在初始情况下,没有训练数据告诉强化学习智能体并不知道在环境中应该针对何种状态采取什么行动,而是通过不断试错得到最终结果,再反馈修正之前采取的策略,因此强化学习某种意义上可以视为具有"延迟标记信息"的监督学习问题。

强化学习的基本过程是:智能体对环境采取某种行动 a a a,观察到环境状态发生转移 s 0 → s s_0\rightarrow s s0→s,反馈给智能体转移后的状态 s s s和对这种转移的奖赏 r r r。综上所述,一个强化学习任务可以用四元组 E = < S , A , P , R > E=\left< S,A,P,R \right> E=⟨S,A,P,R⟩表征

  • 状态空间 S S S:每个状态 s ∈ S s \in S s∈S是智能体对感知环境的描述;
  • 动作空间 A A A:每个动作 a ∈ A a \in A a∈A是智能体能够采取的行动;
  • 状态转移概率 P P P:某个动作 a ∈ A a \in A a∈A作用于处在某个状态 s ∈ S s \in S s∈S的环境中,使环境按某种概率分布 P P P转换到另一个状态;
  • 奖赏函数 R R R:表示智能体对状态 s ∈ S s \in S s∈S下采取动作 a ∈ A a \in A a∈A导致状态转移的期望度,通常 r > 0 r>0 r>0为期望行动, r < 0 r<0 r<0为非期望行动。

所以,程序上也需要依次实现四元组 E = < S , A , P , R > E=\left< S,A,P,R \right> E=⟨S,A,P,R⟩

3 机器人走迷宫算法

3.1 迷宫环境

我们创建的迷宫包含障碍物、起点和终点

python 复制代码
class Maze(tk.Tk, object):
    '''
    * @breif: 迷宫环境类
    * @param[in]: None
    '''    
    def __init__(self):
        super(Maze, self).__init__()
        self.action_space = ['u', 'd', 'l', 'r']
        self.n_actions = len(self.action_space)
        self.title('maze game')
        self.geometry('{0}x{1}'.format(MAZE_H * UNIT, MAZE_H * UNIT))
        self.buildMaze()

    '''
    * @breif: 创建迷宫
    '''
    def buildMaze(self):
        self.canvas = tk.Canvas(self, bg='white', height=MAZE_H * UNIT, width=MAZE_W * UNIT)
        # 网格地图
        for c in range(0, MAZE_W * UNIT, UNIT):
            x0, y0, x1, y1 = c, 0, c, MAZE_H * UNIT
            self.canvas.create_line(x0, y0, x1, y1)
        for r in range(0, MAZE_H * UNIT, UNIT):
            x0, y0, x1, y1 = 0, r, MAZE_W * UNIT, r
            self.canvas.create_line(x0, y0, x1, y1)

        # 创建原点坐标
        origin = np.array([20, 20])

        # 创建障碍
        barrier_list = [(0, 0), (1, 0), (2, 0), (3, 0), (4, 0), (5, 0), (6, 0),
                        (0, 6), (1, 6), (2, 6), (3, 6), (4, 6), (5, 6), (6, 6),
                        (0, 1), (0, 2), (0, 3), (0, 4), (0, 5), (6, 1), (6, 2),
                        (6, 3), (6, 4), (6, 5), (1, 2), (2, 2), (4, 1), (5, 4),
                        (1, 4), (3, 3)]
        self.barriers = [self.creatObject(origin, *index) for index in barrier_list]

        # 创建终点
        self.terminus = self.creatObject(origin, 5, 5, 'blue')

3.2 状态、动作和奖励

机器人的状态可以设置为当前的位置坐标

python 复制代码
s = self.canvas.coords(self.agent)

机器人的动作可以设为上、下左、右

python 复制代码
if action == 0:   # up
      if s[1] > UNIT:
          base_action[1] -= UNIT
  elif action == 1:   # down
      if s[1] < (MAZE_H - 1) * UNIT:
          base_action[1] += UNIT
  elif action == 2:   # right
      if s[0] < (MAZE_W - 1) * UNIT:
          base_action[0] += UNIT
  elif action == 3:   # left
      if s[0] > UNIT:
          base_action[0] -= UNIT

机器人的奖励设置为以下几种:

  • 碰到障碍物:-10分,并进入终止状态
  • 成功到达终点: +50分,并进入终止状态
  • 未到达终点:-1分,能量耗散惩罚,防止机器人原地振荡
python 复制代码
if s_ in [self.canvas.coords(barrier) for barrier in self.barriers]:
   reward = -10
   done = True
   s_ = 'terminal'
elif s_ == self.canvas.coords(self.terminus):
   reward = 50
   done = True
   s_ = 'terminal'
else:
   reward = -1
   done = False

3.3 Q-Learning算法实现

根据算法流程,实现下面的Q-Learning训练函数

python 复制代码
def train(self, env, episodes=1000, reward_curve=[], file=None):
	with tqdm(range(episodes)) as bar:
	    for _ in bar:
	        # 初始化环境和该幕累计奖赏
	        state = env.reset()
	        acc_reward = 0
	        while True:
	            # 刷新环境
	            env.render()
	            # 采样一个动作并进行状态转移
	            action = self.policySample(str(state))
	            next_state, reward, done = env.step(action)
	            acc_reward += reward
	            # 智能体学习策略
	            self.learn(str(state), action, reward, str(next_state))
	            state = next_state
	            if done:
	                reward_curve.append(acc_reward)
	                break
	# 保存策略
	if not file:
	    self.q_table.to_csv(file)
	env.destroy()

3.4 完成训练

训练过程如下所示,完成后保存权重文件

python 复制代码
if __name__ == "__main__":
    env = Maze()
    agent = Agent(actions=list(range(env.n_actions)))
    reward_curve = []

    # 训练智能体
    env.after(100, agent.train, env, 50, reward_curve, './weight/csv')

    # 主循环
    env.mainloop()

4 算法分析

4.1 Q-Table

在Q-Learning算法中,我们需要维护一个Q-Table,用来记录各种状态和动作的价值。Q-Table是一个二维表格,其中每一行表示一个状态,每一列表示一个动作。Q-Table中的值表示某个状态下执行某个动作所获得的回报(或者预期回报)。Q-Table的更新是Q-Learning算法的核心。在每次执行动作后,我们会根据当前状态、执行的动作、获得的奖励和下一个状态,来更新Q-Table中对应的值,更新方式是

Q π ( s t , a t ) = Q π ( s t , a t ) + α ( r t + 1 + γ Q π ( s t + 1 , a t + 1 ) − Q π ( s t , a t ) ) Q^{\pi}\left( s_t,a_t \right) ={ {Q^{\pi}\left( s_t,a_t \right) +\alpha \left( r_{t+1}+{ \gamma Q^{\pi}\left( s_{t+1},a_{t+1} \right) }-Q^{\pi}\left( s_t,a_t \right) \right) }} Qπ(st,at)=Qπ(st,at)+α(rt+1+γQπ(st+1,at+1)−Qπ(st,at))

对应代码

python 复制代码
self.q_table.loc[state, action] += self.lr * (q_target - q_predict)

保存的权重文件正是Q-Table,我们可以直观地看一下,其中0-3指的是上下左右四个动作,每行行首则是状态值,其余数是Q-Value

txt 复制代码
,0,1,2,3
"[45.0, 45.0, 75.0, 75.0]",-3.764746051087998,-4.129632180625153,2.070923999854885,-4.129632180625153
terminal,0.0,0.0,0.0,0.0
"[85.0, 45.0, 115.0, 75.0]",-3.7017636879676745,-3.2427095093971663,6.341493354722148,-2.4376270354451357
"[125.0, 45.0, 155.0, 75.0]",-2.822694674017249,12.009385340227768,-3.10550914130922,-1.7370066390489591
"[125.0, 85.0, 155.0, 115.0]",-1.018256983413196,-2.3765728565289628,19.23732307528551,-2.602996266117196
"[165.0, 85.0, 195.0, 115.0]",-2.063857163563445,27.370237164958994,-0.7307141976318489,0.14330394709222574
"[205.0, 85.0, 235.0, 115.0]",-0.4546075907459214,-0.45498153729692925,-0.490099501,0.3662096391980347
"[165.0, 125.0, 195.0, 155.0]",0.9791630128216775,35.427315495348594,-0.28782126600827374,-1.7383137616441329
"[205.0, 45.0, 235.0, 75.0]",-0.3940399,-0.38288265597631166,-0.3940399,-0.3940399
"[205.0, 125.0, 235.0, 155.0]",-0.31765122402993484,-0.3940399,-0.3940399,1.5298899806741253
...

4.2 奖励曲线

训练过程的奖励曲线如下所示

完整代码联系下方博主名片获取


🔥 更多精彩专栏

👇源码获取 · 技术交流 · 抱团学习 · 咨询分享 请联系👇

相关推荐
Robot2513 分钟前
浅谈,华为切入具身智能赛道
人工智能
只怕自己不够好8 分钟前
OpenCV 图像运算全解析:加法、位运算(与、异或)在图像处理中的奇妙应用
图像处理·人工智能·opencv
好看资源平台1 小时前
网络爬虫——综合实战项目:多平台房源信息采集与分析系统
爬虫·python
余生H1 小时前
transformer.js(三):底层架构及性能优化指南
javascript·深度学习·架构·transformer
果冻人工智能1 小时前
2025 年将颠覆商业的 8 大 AI 应用场景
人工智能·ai员工
代码不行的搬运工1 小时前
神经网络12-Time-Series Transformer (TST)模型
人工智能·神经网络·transformer
进击的六角龙1 小时前
深入浅出:使用Python调用API实现智能天气预报
开发语言·python
檀越剑指大厂1 小时前
【Python系列】浅析 Python 中的字典更新与应用场景
开发语言·python
石小石Orz1 小时前
Three.js + AI:AI 算法生成 3D 萤火虫飞舞效果~
javascript·人工智能·算法
罗小罗同学1 小时前
医工交叉入门书籍分享:Transformer模型在机器学习领域的应用|个人观点·24-11-22
深度学习·机器学习·transformer