Reinforcement Learning with Code 【Code 2. Tabular Sarsa】

This note records how the author began learning RL. Both theoretical understanding and code practice are presented. The material draws on several references, such as Zhao Shiyu's Mathematical Foundation of Reinforcement Learning.

The code is adapted from Mofan's reinforcement learning course.

Contents

  • Reinforcement Learning with Code 【Code 2. Tabular Sarsa】
    • 2.1 Problem and result
    • 2.2 Environment
    • 2.3 Tabular Sarsa Algorithm
    • 2.4 Run this main
    • 2.5 Check the Q table
    • Reference

2.1 Problem and result

Consider the following problem: a little mouse (the red block) wants to reach the cheese (the yellow circle) while avoiding the traps (the black blocks), as the figure shows.

This chapter implements the tabular Sarsa algorithm to solve this problem.

2.2 Environment

We use Python's tkinter package to build the environment that the agent interacts with.

import numpy as np
import time
import sys
import tkinter as tk
# if sys.version_info.major == 2: # check whether this is Python 2
#     import Tkinter as tk
# else:
#     import tkinter as tk


UNIT = 40   # pixels
MAZE_H = 4  # grid height
MAZE_W = 4  # grid width


class Maze(tk.Tk, object):
    def __init__(self):
        super(Maze, self).__init__()
        # Action Space
        self.action_space = ['up', 'down', 'right', 'left'] # action space 
        self.n_actions = len(self.action_space)

        # draw the GUI
        self.title('Maze env')
        self.geometry('{0}x{1}'.format(MAZE_W * UNIT, MAZE_H * UNIT))   # window size "width x height"
        self._build_maze()

    def _build_maze(self):
        self.canvas = tk.Canvas(self, bg='white',
                           height=MAZE_H * UNIT,
                           width=MAZE_W * UNIT)     # create the background canvas

        # create grids
        for c in range(UNIT, MAZE_W * UNIT, UNIT): # draw the column separators
            x0, y0, x1, y1 = c, 0, c, MAZE_H * UNIT
            self.canvas.create_line(x0, y0, x1, y1)
        for r in range(UNIT, MAZE_H * UNIT, UNIT): # draw the row separators
            x0, y0, x1, y1 = 0, r, MAZE_W * UNIT, r
            self.canvas.create_line(x0, y0, x1, y1)

        # create origin: the center of the first (top-left) cell
        origin = np.array([UNIT/2, UNIT/2]) 

        # hell1
        hell1_center = origin + np.array([UNIT * 2, UNIT])
        self.hell1 = self.canvas.create_rectangle(
            hell1_center[0] - (UNIT/2 - 5), hell1_center[1] - (UNIT/2 - 5),
            hell1_center[0] + (UNIT/2 - 5), hell1_center[1] + (UNIT/2 - 5),
            fill='black')
        # hell2
        hell2_center = origin + np.array([UNIT, UNIT * 2])
        self.hell2 = self.canvas.create_rectangle(
            hell2_center[0] - (UNIT/2 - 5), hell2_center[1] - (UNIT/2 - 5),
            hell2_center[0] + (UNIT/2 - 5), hell2_center[1] + (UNIT/2 - 5),
            fill='black')

        # create oval: the yellow terminal circle (the cheese)
        oval_center = origin + np.array([UNIT*2, UNIT*2])
        self.oval = self.canvas.create_oval(
            oval_center[0] - (UNIT/2 - 5), oval_center[1] - (UNIT/2 - 5),
            oval_center[0] + (UNIT/2 - 5), oval_center[1] + (UNIT/2 - 5),
            fill='yellow')

        # create red rect: the agent, initially in the top-left cell
        self.rect = self.canvas.create_rectangle(
            origin[0] - (UNIT/2 - 5), origin[1] - (UNIT/2 - 5),
            origin[0] + (UNIT/2 - 5), origin[1] + (UNIT/2 - 5),
            fill='red')

        # pack all: display the canvas
        self.canvas.pack()


    def get_state(self, rect):
        # convert the coordinate observation to a state tuple,
        # using the normalized center of the cell as the state:
        # |(1,1)|(2,1)|(3,1)|...
        # |(1,2)|(2,2)|(3,2)|...
        # |(1,3)|(2,3)|(3,3)|...
        # |....
        x0, y0, x1, y1 = self.canvas.coords(rect)
        x_center = (x0 + x1) / 2
        y_center = (y0 + y1) / 2
        state = ((x_center - (UNIT/2)) / UNIT + 1, (y_center - (UNIT/2)) / UNIT + 1)
        return state


    def reset(self):
        self.update()
        self.after(500) # delay 500ms
        self.canvas.delete(self.rect)   # delete origin rectangle
        origin = np.array([UNIT/2, UNIT/2])
        self.rect = self.canvas.create_rectangle(
            origin[0] - (UNIT/2 - 5), origin[1] - (UNIT/2 - 5),
            origin[0] + (UNIT/2 - 5), origin[1] + (UNIT/2 - 5),
            fill='red')
        # return observation 
        return self.get_state(self.rect)   

    

    def step(self, action):
        # one interaction step between the agent and the environment
        s = self.get_state(self.rect)   # current grid state of the agent
        base_action = np.array([0, 0])
        reach_boundary = False
        if action == self.action_space[0]:   # up
            if s[1] > 1:
                base_action[1] -= UNIT
            else: # hit the boundary: reward -1 and stay in place
                reach_boundary = True

        elif action == self.action_space[1]:   # down
            if s[1] < MAZE_H:
                base_action[1] += UNIT
            else:
                reach_boundary = True   

        elif action == self.action_space[2]:   # right
            if s[0] < MAZE_W:
                base_action[0] += UNIT
            else:
                reach_boundary = True

        elif action == self.action_space[3]:   # left
            if s[0] > 1:
                base_action[0] -= UNIT
            else:
                reach_boundary = True

        self.canvas.move(self.rect, base_action[0], base_action[1])  # move agent

        s_ = self.get_state(self.rect)  # next state

        # reward function
        if s_ == self.get_state(self.oval):     # reach the terminal
            reward = 1
            done = True
            s_ = 'success'
        elif s_ == self.get_state(self.hell1): # reach the block
            reward = -1
            s_ = 'block_1'
            done = False
        elif s_ == self.get_state(self.hell2):
            reward = -1
            s_ = 'block_2'
            done = False
        else:
            reward = 0
            done = False
            if reach_boundary:
                reward = -1

        return s_, reward, done

    def render(self):
        time.sleep(0.15)
        self.update()




if __name__ == '__main__':
    def test():
        for t in range(10):
            s = env.reset()
            print(s)
            while True:
                env.render()
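                # random action: always moving 'right' would stall at the right wall and never finish
                a = np.random.choice(env.action_space)
                s, r, done = env.step(a)
                print(s)
                if done:
                    break
    env = Maze()
    env.after(100, test)      # call test after a 100 ms delay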
    env.mainloop()

This part is important because it contains the design of the reward function:

$$
\text{reward} = \begin{cases} 1, & \text{if reach the cheese} \\ -1, & \text{if reach the trap or reach the boundary} \\ 0, & \text{others} \end{cases}
$$
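To make the rule concrete, here is a minimal GUI-free sketch of the same reward logic. The function name maze_reward and the hard-coded cells (traps at (3, 2) and (2, 3), cheese at (3, 3), read off the centers computed in _build_maze) are assumptions for illustration, not part of the original code. As in the original step, landing on a trap gives -1 but does not end the episode.

```python
# Hypothetical helper mirroring the reward logic of Maze.step (no tkinter needed).
CHEESE = (3, 3)             # (column, row) of the yellow circle
TRAPS = {(3, 2), (2, 3)}    # (column, row) of the two black blocks


def maze_reward(next_cell, hit_boundary):
    """Return (reward, done) for the cell the mouse lands in."""
    if next_cell == CHEESE:
        return 1, True          # reached the cheese: the episode ends
    if next_cell in TRAPS:
        return -1, False        # stepped on a trap: penalised, but the episode continues
    if hit_boundary:
        return -1, False        # bumped into the wall and stayed in place
    return 0, False             # ordinary move


print(maze_reward((3, 3), False))  # (1, True)
print(maze_reward((3, 2), False))  # (-1, False)
print(maze_reward((1, 1), True))   # (-1, False)
print(maze_reward((2, 2), False))  # (0, False)
```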

Let us explain the methods of the class Maze.

  • First, _build_maze draws the initial maze: the grid lines, the two traps, the cheese, and the mouse in the top-left cell.
    In this example, the state is derived from the canvas coordinates of the cell the mouse occupies.
  • Second, get_state converts the canvas coordinates of a rectangle into a grid index such as $(1,1), (1,2), \cdots$, where the first entry is the column and the second is the row, both counted from 1.
  • Third, reset restarts an episode by placing the mouse back in the starting (top-left) cell and returns its state.
  • Then, step lets the agent interact with the environment for one step: it moves the mouse (unless the move would leave the grid) and returns the next state, the reward, and whether the episode is done. Trap cells are reported as 'block_1' / 'block_2' and the cheese as 'success'.
  • Finally, render refreshes the window after a short pause of 0.15 s.
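The conversion in get_state is plain arithmetic on the center of the bounding box returned by canvas.coords. The sketch below repeats the formula outside tkinter; the function name coords_to_state is hypothetical. Because of the divisions the result is a pair of floats, which is why the Q-table rows in Section 2.5 read (1.0, 1.0) rather than (1, 1).

```python
UNIT = 40  # pixels per cell, as in the environment code


def coords_to_state(x0, y0, x1, y1, unit=UNIT):
    """Convert a canvas bounding box to the 1-based (column, row) index used by get_state."""
    x_center = (x0 + x1) / 2
    y_center = (y0 + y1) / 2
    return ((x_center - unit / 2) / unit + 1, (y_center - unit / 2) / unit + 1)


# The mouse starts with bounding box (5, 5, 35, 35): center (20, 20)
print(coords_to_state(5, 5, 35, 35))      # (1.0, 1.0)
# One step right shifts the box by UNIT pixels: center (60, 20)
print(coords_to_state(45, 5, 75, 35))     # (2.0, 1.0)
# The first trap is drawn around center (100, 60)
print(coords_to_state(85, 45, 115, 75))   # (3.0, 2.0)
```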

2.3 Tabular Sarsa Algorithm

import numpy as np
import pandas as pd


class RL():
    def __init__(self, actions, learning_rate=0.01, reward_decay=0.9, e_greedy=0.9):
        self.actions = actions  # action list
        self.lr = learning_rate
        self.gamma = reward_decay
        self.epsilon = e_greedy # epsilon greedy update policy
        self.q_table = pd.DataFrame(columns=self.actions, dtype=np.float64)

    def check_state_exist(self, state):
        if state not in self.q_table.index:
            # append the new state as an all-zero row; the (stringified) coordinate is the row label

            # self.q_table = self.q_table.append(       # DataFrame.append was removed in pandas 2.0
            #     pd.Series(
            #         [0]*len(self.actions),
            #         index=self.q_table.columns,
            #         name=state,
            #     )
            # )

            self.q_table = pd.concat(
                [
                self.q_table,
                pd.DataFrame(
                        data=np.zeros((1,len(self.actions))),
                        columns = self.q_table.columns,
                        index = [state]
                    )
                ]
            )

    def choose_action(self, observation):
        """
            Use the epsilon-greedy method to update policy
        """
        self.check_state_exist(observation)
        # action selection
            # epsilon greedy algorithm
        if np.random.uniform() < self.epsilon:
            
            state_action = self.q_table.loc[observation, :]
            # some actions may have the same value, randomly choose on in these actions
            # state_action == np.max(state_action) generate bool mask
            # choose best action
            action = np.random.choice(state_action[state_action == np.max(state_action)].index)
        else:
            # choose random action
            action = np.random.choice(self.actions)
        return action

    def learn(self, s, a, r, s_):
        pass



class SarsaTable(RL):
    """
        Implement Sarsa algorithm which is on-policy
    """
    def __init__(self, actions, learning_rate=0.01, reward_decay=0.9, e_greedy=0.9):
        super(SarsaTable,self).__init__(actions, learning_rate, reward_decay, e_greedy)

    def learn(self, s, a, r, s_, a_):
        self.check_state_exist(s_)
        q_predict = self.q_table.loc[s, a]
        if s_ != 'success' :
            q_target = r + self.gamma * self.q_table.loc[s_, a_]  # next state is not terminal
        else:
            q_target = r  # next state is terminal
        self.q_table.loc[s, a] += self.lr * (q_target - q_predict)  # update
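Since DataFrame.append no longer exists in pandas 2.0, check_state_exist grows the table with pd.concat. The standalone sketch below exercises the same pattern; the helper name add_state is hypothetical. Note that the main script passes states as strings (str(observation)); one likely reason is that a list of tuples used as an index is turned into a MultiIndex by pandas, which would complicate the lookups.

```python
import numpy as np
import pandas as pd

ACTIONS = ['up', 'down', 'right', 'left']


def add_state(q_table, state):
    """Hypothetical helper: return q_table with an all-zero row for `state` if it is missing."""
    if state in q_table.index:
        return q_table                      # already present: leave the table unchanged
    new_row = pd.DataFrame(np.zeros((1, len(q_table.columns))),
                           columns=q_table.columns, index=[state])
    return pd.concat([q_table, new_row])    # replacement for the removed DataFrame.append


q = pd.DataFrame(columns=ACTIONS, dtype=np.float64)
q = add_state(q, '(1.0, 1.0)')
q = add_state(q, '(1.0, 1.0)')   # second call adds no duplicate row
q = add_state(q, 'block_1')
print(q)
```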

We store the Q-table as a pandas DataFrame whose rows are states and whose columns are actions. The methods work as follows.

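  • First, check_state_exist checks whether a state is already in the Q-table and, if not, appends an all-zero row for it. Rows are therefore added lazily: a state enters the Q-table the first time it is visited.
  • Second, choose_action follows the $\epsilon$-greedy policy. Note that in the code e_greedy is the probability of acting greedily (ties are broken at random); otherwise an action is drawn uniformly from all actions. Hence the default e_greedy=0.9 corresponds to $\epsilon = 0.1$ in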

$$
\pi(a|s) = \begin{cases} 1 - \dfrac{\epsilon}{|\mathcal{A}(s)|}\big(|\mathcal{A}(s)| - 1\big), & \text{for the greedy action} \\ \dfrac{\epsilon}{|\mathcal{A}(s)|}, & \text{for the other } |\mathcal{A}(s)| - 1 \text{ actions} \end{cases}
$$

  • Third, learn updates the Q-value following the Sarsa rule (not Q-learning), which relies on the sample $(s_t, a_t, r_{t+1}, s_{t+1}, a_{t+1})$: the current state, current action, immediate reward, next state, and next action respectively. The code uses a constant learning rate $\alpha_t = $ lr, and when $s_{t+1}$ is the terminal state 'success' the target is just $r_{t+1}$.

$$
\text{Sarsa}: \begin{cases} q_{t+1}(s_t,a_t) = q_t(s_t,a_t) - \alpha_t(s_t,a_t)\Big[q_t(s_t,a_t) - \big(r_{t+1} + \gamma\, q_t(s_{t+1},a_{t+1})\big)\Big] \\ q_{t+1}(s,a) = q_t(s,a), \quad \text{for all } (s,a) \ne (s_t,a_t) \end{cases}
$$
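The ε-greedy formula from the second bullet can be checked numerically against the sampling rule of choose_action. The sketch below is an illustration, not the original code: epsilon_greedy_probs evaluates the formula (assuming a single greedy action), sample_like_choose_action is a hypothetical re-implementation of choose_action's logic on a plain NumPy array, and the Q-values are the rounded entries of the start cell (1.0, 1.0) from Section 2.5.

```python
import numpy as np


def epsilon_greedy_probs(q_values, epsilon):
    """Action probabilities of the epsilon-greedy formula (assumes one greedy action).

    `epsilon` is the exploration rate of the formula, i.e. 1 - e_greedy in RL.__init__.
    """
    n = len(q_values)
    probs = np.full(n, epsilon / n)                         # every action gets epsilon/|A|
    probs[np.argmax(q_values)] = 1 - epsilon / n * (n - 1)  # the greedy action gets the rest
    return probs


def sample_like_choose_action(q_values, e_greedy, rng):
    """Mimic RL.choose_action: greedy with probability e_greedy, else uniform over all actions."""
    if rng.uniform() < e_greedy:
        return int(np.argmax(q_values))
    return int(rng.integers(len(q_values)))


# rounded Q-values of the start cell for up / down / right / left
q = np.array([-6.837e-02, -1.35e-04, -2.66e-04, -2.970e-02])
probs = epsilon_greedy_probs(q, epsilon=0.1)
print(probs)                       # 'down' is greedy: about 0.925, the others 0.025 each

rng = np.random.default_rng(0)
samples = [sample_like_choose_action(q, 0.9, rng) for _ in range(20000)]
counts = np.bincount(samples, minlength=4)
print(counts / counts.sum())       # empirical frequencies, close to the probabilities above
```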

2.4 Run this main

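Run the following main script to put all the pieces together. It assumes the environment code is saved as maze_env_custom.py and the Sarsa code as RL_brain.py.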

from maze_env_custom import Maze
from RL_brain import SarsaTable

MAX_EPISODE = 30


def update():
    for episode in range(MAX_EPISODE):
        # initial observation: the grid index of the mouse, e.g. (1.0, 1.0)
        # (get_state converts the rectangle's [x0, y0, x1, y1] coordinates into this index)
        observation = env.reset()   

        # RL choose action based on observation ['up', 'down', 'right', 'left']
        action = RL.choose_action(str(observation))

        while True:
            # fresh env
            env.render()

            # RL take action and get next observation and reward
            observation_, reward, done = env.step(action)
            

            action_ = RL.choose_action(str(observation_))


            # RL learn from this transition
            RL.learn(str(observation), action, reward, str(observation_), action_)

            # swap observation
            observation = observation_
            action = action_

            # break while loop when end of this episode
            if done:
                break

        # show q_table
        print(RL.q_table)
        print('\n')

    # end of game
    print('game over')
    env.destroy()

if __name__ == "__main__":
    env = Maze()
    RL = SarsaTable(env.action_space)

    env.after(100, update)
    env.mainloop()
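To try the learning loop without opening a window, the same Sarsa update can be paired with a GUI-free copy of the grid. The sketch below is a hypothetical re-implementation under stated assumptions, not the original files: HeadlessMaze, the step cap max_steps, and the dict-based Q-table are stand-ins for Maze and SarsaTable, while the layout, rewards, 30 episodes, and hyperparameters (lr = 0.01, gamma = 0.9, e_greedy = 0.9) follow the code above. As in update(), the next action is chosen before learning, which is what makes Sarsa on-policy.

```python
import random
from collections import defaultdict

ACTIONS = ['up', 'down', 'right', 'left']
MOVES = {'up': (0, -1), 'down': (0, 1), 'right': (1, 0), 'left': (-1, 0)}
SIZE = 4                                    # 4 x 4 grid, cells indexed (column, row) from 1
CHEESE = (3, 3)                             # yellow circle
TRAPS = {(3, 2): 'block_1', (2, 3): 'block_2'}


class HeadlessMaze:
    """Hypothetical GUI-free stand-in for Maze: same layout, same rewards."""

    def reset(self):
        self.pos = (1, 1)                   # the mouse starts in the top-left cell
        return self.pos

    def step(self, action):
        dx, dy = MOVES[action]
        x, y = self.pos[0] + dx, self.pos[1] + dy
        if not (1 <= x <= SIZE and 1 <= y <= SIZE):
            return self.pos, -1, False      # boundary: stay in place, reward -1
        self.pos = (x, y)
        if self.pos == CHEESE:
            return 'success', 1, True       # terminal state
        if self.pos in TRAPS:
            return TRAPS[self.pos], -1, False   # trap: penalised, episode continues
        return self.pos, 0, False


def choose_action(q, state, e_greedy=0.9):
    """Epsilon-greedy with random tie-breaking, mirroring RL.choose_action."""
    if random.random() < e_greedy:
        best = max(q[(state, a)] for a in ACTIONS)
        return random.choice([a for a in ACTIONS if q[(state, a)] == best])
    return random.choice(ACTIONS)


def sarsa_learn(q, s, a, r, s_, a_, lr=0.01, gamma=0.9):
    """Tabular Sarsa update; 'success' is terminal, as in SarsaTable.learn."""
    q_target = r if s_ == 'success' else r + gamma * q[(s_, a_)]
    q[(s, a)] += lr * (q_target - q[(s, a)])


def train(episodes=30, max_steps=1000, seed=0):
    random.seed(seed)
    env = HeadlessMaze()
    q = defaultdict(float)                  # unseen (state, action) pairs start at 0
    steps_per_episode = []
    for _ in range(episodes):
        s = env.reset()
        a = choose_action(q, s)
        for t in range(1, max_steps + 1):
            s_, r, done = env.step(a)
            a_ = choose_action(q, s_)       # choose the next action BEFORE learning (on-policy)
            sarsa_learn(q, s, a, r, s_, a_)
            s, a = s_, a_
            if done:
                break
        steps_per_episode.append(t)         # episode length (capped at max_steps)
    return q, steps_per_episode


q, steps = train()
print(steps)
print(q[((1, 1), 'up')], q[((1, 1), 'left')])   # boundary moves from the start cell: never positive
```

Using a defaultdict keyed by (state, action) plays the role of check_state_exist: missing entries simply read as 0.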

2.5 Check the Q table

After training, we can inspect the Q-table to judge whether the learning is reasonable. The Q-table is as follows:

                      up      down     right          left
(1.0, 1.0) -6.837352e-02 -0.000135 -0.000266 -2.970185e-02
(2.0, 1.0) -4.901299e-02 -0.000334 -0.000484 -6.039572e-04
(2.0, 2.0) -3.988164e-04 -0.049010 -0.038785 -2.737623e-04
block_1     0.000000e+00  0.049010  0.000000  0.000000e+00
(4.0, 2.0) -2.646359e-04  0.001314 -0.019900 -1.000000e-02
(4.0, 1.0) -4.900994e-02  0.000014 -0.010000 -3.128178e-06
(3.0, 1.0) -2.970450e-02 -0.029433 -0.000516 -2.078845e-04
(1.0, 2.0) -4.933690e-04 -0.000374 -0.000951 -3.940947e-02
block_2    -1.979099e-07  0.000000  0.010000 -1.531800e-07
(1.0, 3.0) -3.525635e-04 -0.000056 -0.010000 -3.940439e-02
(1.0, 4.0) -7.194310e-07 -0.010000  0.000591 -1.990000e-02
(2.0, 4.0) -1.000000e-02 -0.019900  0.012381  0.000000e+00
(3.0, 4.0)  1.654862e-01  0.000000  0.000000  0.000000e+00
(4.0, 4.0)  0.000000e+00  0.000000 -0.010000  0.000000e+00
(4.0, 3.0)  0.000000e+00  0.000000  0.000000  5.851985e-02
success     0.000000e+00  0.000000  0.000000  0.000000e+00

For example, in the starting cell (1.0, 1.0), moving up or left hits the boundary and yields reward $-1$. Hence the corresponding action values in the Q-table (-6.837352e-02 and -2.970185e-02) are negative, and they are the most negative entries in that row.
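Several entries in the table are exactly -1.000000e-02 or -1.990000e-02 (for instance (4.0, 1.0) right and (4.0, 2.0) right, both boundary moves). With the default learning rate 0.01, these are what the Sarsa update from Section 2.3 produces after one and two rewards of -1, assuming the value of the next state-action pair is still 0 at update time (an assumption, but plausible early in training). The short computation below repeats the update by hand:

```python
alpha, gamma = 0.01, 0.9   # defaults of SarsaTable
q = 0.0                    # a fresh Q-table entry
q_next = 0.0               # assumed value of q(s_{t+1}, a_{t+1}) at each update
history = []
for _ in range(3):
    target = -1 + gamma * q_next      # boundary hit or trap: reward -1
    q = q - alpha * (q - target)      # same form as the Sarsa equation
    history.append(round(q, 6))
print(history)  # [-0.01, -0.0199, -0.029701]
```

The third value, about -0.0297, is close to entries such as -2.970185e-02; the small differences come from the next-state values no longer being exactly 0.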


Reference

Prof. Zhao Shiyu's course, Mathematical Foundation of Reinforcement Learning
Mofan's Reinforcement Learning course
