Reinforcement Learning with Code 【Code 2. Tabular Sarsa】
This note records how the author began to learn RL. Both theoretical understanding and code practice are presented. Many materials are referenced, such as Zhao Shiyu's Mathematical Foundations of Reinforcement Learning.
This code refers to Mofan's reinforcement learning course.
Table of Contents
- Reinforcement Learning with Code 【Code 2. Tabular Sarsa】
  - 2.1 Problem and result
  - 2.2 Environment
  - 2.3 Tabular Sarsa Algorithm
  - 2.4 Run this main
  - 2.5 Check the Q table
  - Reference
2.1 Problem and result
Consider the problem in which a little mouse (denoted by the red block) wants to avoid the traps (denoted by the black blocks) and reach the cheese (denoted by the yellow circle), as the figure shows.
This chapter aims to implement the tabular Sarsa algorithm to solve this problem.
2.2 Environment
We use the tkinter package of Python to build the environment that the agent interacts with.
```python
import numpy as np
import time
import sys
import tkinter as tk
# if sys.version_info.major == 2:  # check whether the Python version is Python 2
#     import Tkinter as tk
# else:
#     import tkinter as tk

UNIT = 40    # pixels
MAZE_H = 4   # grid height
MAZE_W = 4   # grid width


class Maze(tk.Tk, object):
    def __init__(self):
        super(Maze, self).__init__()
        # Action space
        self.action_space = ['up', 'down', 'right', 'left']  # action space
        self.n_actions = len(self.action_space)
        # build the GUI
        self.title('Maze env')
        self.geometry('{0}x{1}'.format(MAZE_W * UNIT, MAZE_H * UNIT))  # window size "width x height"
        self._build_maze()

    def _build_maze(self):
        self.canvas = tk.Canvas(self, bg='white',
                                height=MAZE_H * UNIT,
                                width=MAZE_W * UNIT)  # create the background canvas
        # create grids
        for c in range(UNIT, MAZE_W * UNIT, UNIT):  # draw the column separator lines
            x0, y0, x1, y1 = c, 0, c, MAZE_H * UNIT
            self.canvas.create_line(x0, y0, x1, y1)
        for r in range(UNIT, MAZE_H * UNIT, UNIT):  # draw the row separator lines
            x0, y0, x1, y1 = 0, r, MAZE_W * UNIT, r
            self.canvas.create_line(x0, y0, x1, y1)

        # create origin: the center of the first grid cell
        origin = np.array([UNIT/2, UNIT/2])

        # hell1
        hell1_center = origin + np.array([UNIT * 2, UNIT])
        self.hell1 = self.canvas.create_rectangle(
            hell1_center[0] - (UNIT/2 - 5), hell1_center[1] - (UNIT/2 - 5),
            hell1_center[0] + (UNIT/2 - 5), hell1_center[1] + (UNIT/2 - 5),
            fill='black')
        # hell2
        hell2_center = origin + np.array([UNIT, UNIT * 2])
        self.hell2 = self.canvas.create_rectangle(
            hell2_center[0] - (UNIT/2 - 5), hell2_center[1] - (UNIT/2 - 5),
            hell2_center[0] + (UNIT/2 - 5), hell2_center[1] + (UNIT/2 - 5),
            fill='black')
        # create oval: the goal circle (cheese)
        oval_center = origin + np.array([UNIT*2, UNIT*2])
        self.oval = self.canvas.create_oval(
            oval_center[0] - (UNIT/2 - 5), oval_center[1] - (UNIT/2 - 5),
            oval_center[0] + (UNIT/2 - 5), oval_center[1] + (UNIT/2 - 5),
            fill='yellow')
        # create red rect: the agent, initially in the top-left grid cell
        self.rect = self.canvas.create_rectangle(
            origin[0] - (UNIT/2 - 5), origin[1] - (UNIT/2 - 5),
            origin[0] + (UNIT/2 - 5), origin[1] + (UNIT/2 - 5),
            fill='red')
        # pack all: display the canvas
        self.canvas.pack()

    def get_state(self, rect):
        # convert the coordinate observation to a state tuple
        # use the normalized grid center as the state, such as
        # |(1,1)|(2,1)|(3,1)|...
        # |(1,2)|(2,2)|(3,2)|...
        # |(1,3)|(2,3)|(3,3)|...
        # |....
        x0, y0, x1, y1 = self.canvas.coords(rect)
        x_center = (x0 + x1) / 2
        y_center = (y0 + y1) / 2
        state = ((x_center - (UNIT/2)) / UNIT + 1, (y_center - (UNIT/2)) / UNIT + 1)
        return state

    def reset(self):
        self.update()
        self.after(500)  # delay 500 ms
        self.canvas.delete(self.rect)  # delete the old rectangle
        origin = np.array([UNIT/2, UNIT/2])
        self.rect = self.canvas.create_rectangle(
            origin[0] - (UNIT/2 - 5), origin[1] - (UNIT/2 - 5),
            origin[0] + (UNIT/2 - 5), origin[1] + (UNIT/2 - 5),
            fill='red')
        # return observation
        return self.get_state(self.rect)

    def step(self, action):
        # one interaction between the agent and the environment
        s = self.get_state(self.rect)  # get the agent's current state
        base_action = np.array([0, 0])
        reach_boundary = False
        if action == self.action_space[0]:    # up
            if s[1] > 1:
                base_action[1] -= UNIT
            else:  # hit the boundary: reward = -1 and stay in place
                reach_boundary = True
        elif action == self.action_space[1]:  # down
            if s[1] < MAZE_H:
                base_action[1] += UNIT
            else:
                reach_boundary = True
        elif action == self.action_space[2]:  # right
            if s[0] < MAZE_W:
                base_action[0] += UNIT
            else:
                reach_boundary = True
        elif action == self.action_space[3]:  # left
            if s[0] > 1:
                base_action[0] -= UNIT
            else:
                reach_boundary = True

        self.canvas.move(self.rect, base_action[0], base_action[1])  # move agent
        s_ = self.get_state(self.rect)  # next state

        # reward function
        if s_ == self.get_state(self.oval):     # reach the terminal
            reward = 1
            done = True
            s_ = 'success'
        elif s_ == self.get_state(self.hell1):  # reach the block
            reward = -1
            s_ = 'block_1'
            done = False
        elif s_ == self.get_state(self.hell2):
            reward = -1
            s_ = 'block_2'
            done = False
        else:
            reward = 0
            done = False

        if reach_boundary:
            reward = -1

        return s_, reward, done

    def render(self):
        time.sleep(0.15)
        self.update()


if __name__ == '__main__':
    def test():
        for t in range(10):
            s = env.reset()
            print(s)
            while True:
                env.render()
                a = 'right'
                s, r, done = env.step(a)
                print(s)
                if done:
                    break

    env = Maze()
    env.after(100, test)  # call the function test after a 100 ms delay
    env.mainloop()
```
The important part here is the design of the reward function, which is as follows:
$$
\text{reward} = \left\{
\begin{aligned}
& 1, && \text{if reach the cheese} \\
& -1, && \text{if reach the trap or the boundary} \\
& 0, && \text{otherwise}
\end{aligned}
\right.
$$
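To see the rule in isolation, here is a minimal sketch of the same reward logic as a standalone function; `reward_of` is a hypothetical helper introduced only for illustration, since in the actual environment this logic lives inside `Maze.step`.

```python
# A minimal sketch restating the reward rule outside the Maze class.
# `reward_of` is a hypothetical helper for illustration only.
def reward_of(next_state, reach_boundary):
    if next_state == 'success':          # reached the cheese
        return 1
    if next_state in ('block_1', 'block_2') or reach_boundary:
        return -1                        # trap or boundary
    return 0                             # all other transitions

assert reward_of('success', False) == 1
assert reward_of('block_1', False) == -1
assert reward_of((2.0, 1.0), True) == -1   # bounced off the boundary
assert reward_of((2.0, 1.0), False) == 0
```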
We need to explain some functions of the class `Maze`.

- First, the function `_build_maze` creates the initial maze layout. In this example we ++use the normalized center of each grid cell as the state of each block++.
- Second, the function `get_state` converts the canvas coordinates of a cell into a numerical representation such as $(1,1), (1,2), \cdots$ (see the sketch after this list).
- Third, the function `reset` resets the environment, which means placing the mouse back in the starting cell.
- Then, the function `step` lets the agent interact with the environment for one step and returns the reward of that action.
- Finally, the function `render` controls the refresh of the window.
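To make the state representation concrete, here is a minimal sketch of the coordinate-to-state conversion performed by `get_state`, written as a standalone function; the name `coords_to_state` is introduced here only for illustration and does not appear in the environment code.

```python
UNIT = 40  # pixel size of one grid cell, same constant as in the environment

def coords_to_state(coords):
    """Map a canvas bounding box [x0, y0, x1, y1] to a (column, row) grid index,
    mirroring what Maze.get_state does."""
    x0, y0, x1, y1 = coords
    x_center = (x0 + x1) / 2
    y_center = (y0 + y1) / 2
    return ((x_center - UNIT/2) / UNIT + 1, (y_center - UNIT/2) / UNIT + 1)

# The agent's rectangle is drawn 5 pixels inside its cell, so the starting
# bounding box is [5, 5, 35, 35], whose center is (20, 20):
print(coords_to_state([5, 5, 35, 35]))    # (1.0, 1.0)
print(coords_to_state([45, 5, 75, 35]))   # (2.0, 1.0), one cell to the right
```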
2.3 Tabular Sarsa Algorithm
```python
import numpy as np
import pandas as pd


class RL():
    def __init__(self, actions, learning_rate=0.01, reward_decay=0.9, e_greedy=0.9):
        self.actions = actions  # action list
        self.lr = learning_rate
        self.gamma = reward_decay
        self.epsilon = e_greedy  # epsilon-greedy update policy
        self.q_table = pd.DataFrame(columns=self.actions, dtype=np.float64)

    def check_state_exist(self, state):
        if state not in self.q_table.index:
            # append new state to q table, use the coordinate as the observation
            # self.q_table = self.q_table.append(  # DataFrame.append is removed in newer pandas
            #     pd.Series(
            #         [0]*len(self.actions),
            #         index=self.q_table.columns,
            #         name=state,
            #     )
            # )
            self.q_table = pd.concat(
                [
                    self.q_table,
                    pd.DataFrame(
                        data=np.zeros((1, len(self.actions))),
                        columns=self.q_table.columns,
                        index=[state]
                    )
                ]
            )

    def choose_action(self, observation):
        """
        Use the epsilon-greedy method to choose an action under the current policy.
        """
        self.check_state_exist(observation)
        # action selection: epsilon-greedy algorithm
        if np.random.uniform() < self.epsilon:
            state_action = self.q_table.loc[observation, :]
            # some actions may have the same value, randomly choose one of them
            # state_action == np.max(state_action) generates a bool mask
            # choose the best action
            action = np.random.choice(state_action[state_action == np.max(state_action)].index)
        else:
            # choose a random action
            action = np.random.choice(self.actions)
        return action

    def learn(self, s, a, r, s_):
        pass


class SarsaTable(RL):
    """
    Implement the Sarsa algorithm, which is on-policy.
    """
    def __init__(self, actions, learning_rate=0.01, reward_decay=0.9, e_greedy=0.9):
        super(SarsaTable, self).__init__(actions, learning_rate, reward_decay, e_greedy)

    def learn(self, s, a, r, s_, a_):
        self.check_state_exist(s_)
        q_predict = self.q_table.loc[s, a]
        if s_ != 'success':
            q_target = r + self.gamma * self.q_table.loc[s_, a_]  # next state is not terminal
        else:
            q_target = r  # next state is terminal
        self.q_table.loc[s, a] += self.lr * (q_target - q_predict)  # update
```
We store the Q-table as a pandas `DataFrame`. The explanations of the functions are as follows.
- First, the function `check_state_exist` checks whether a state already exists in the Q-table; if not, we append it. This is because a state-action pair is added to the Q-table only once it has been visited.
- Second, the function `choose_action` follows the $\epsilon$-greedy algorithm

$$
\pi(a|s) = \left\{
\begin{aligned}
1 - \frac{\epsilon}{|\mathcal{A}(s)|}(|\mathcal{A}(s)|-1), & \quad \text{for the greedy action} \\
\frac{\epsilon}{|\mathcal{A}(s)|}, & \quad \text{for the other } |\mathcal{A}(s)|-1 \text{ actions}
\end{aligned}
\right.
$$

- Third, the function `learn` updates the q value as the Sarsa algorithm proposes, which relies on the sample $\textcolor{red}{(s_t,a_t,r_{t+1},s_{t+1},a_{t+1})}$. The sample denotes the current state, current action, immediate reward, next state and next action respectively (a dictionary-based sketch of this update follows the list).

$$
\text{Sarsa} : \left\{
\begin{aligned}
\textcolor{red}{q_{t+1}(s_t,a_t)} & \textcolor{red}{= q_t(s_t,a_t) - \alpha_t(s_t,a_t) \Big[ q_t(s_t,a_t) - \big( r_{t+1} + \gamma \, q_t(s_{t+1},a_{t+1}) \big) \Big]} \\
\textcolor{red}{q_{t+1}(s,a)} & \textcolor{red}{= q_t(s,a)}, \quad \text{for all } (s,a) \ne (s_t,a_t)
\end{aligned}
\right.
$$
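To make the two formulas above concrete, here is a minimal sketch of the $\epsilon$-greedy choice and the Sarsa update using a plain dictionary as the Q-table; the names `choose_action_eps_greedy` and `sarsa_update` are introduced here only for illustration and are not part of the classes above.

```python
import numpy as np

ACTIONS = ['up', 'down', 'right', 'left']

def choose_action_eps_greedy(q_row, epsilon=0.9):
    """With probability `epsilon` pick a greedy action (ties broken randomly),
    otherwise pick a uniformly random action, as in RL.choose_action."""
    if np.random.uniform() < epsilon:
        values = np.array([q_row[a] for a in ACTIONS])
        best = [a for a, v in zip(ACTIONS, values) if v == values.max()]
        return np.random.choice(best)
    return np.random.choice(ACTIONS)

def sarsa_update(q, s, a, r, s_, a_, alpha=0.01, gamma=0.9, terminal=False):
    """q(s,a) <- q(s,a) - alpha * [q(s,a) - (r + gamma * q(s_,a_))]."""
    target = r if terminal else r + gamma * q[s_][a_]
    q[s][a] += alpha * (target - q[s][a])

# toy usage with two states
q = {s: {a: 0.0 for a in ACTIONS} for s in ['(1.0, 1.0)', '(2.0, 1.0)']}
a = choose_action_eps_greedy(q['(1.0, 1.0)'])
sarsa_update(q, '(1.0, 1.0)', a, -1, '(2.0, 1.0)', 'right')
print(q['(1.0, 1.0)'][a])  # the value moved slightly toward the target, here -0.01
```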
2.4 Run this main
Running this main script runs all the code above. Note that the next action `action_` is chosen before `learn` is called and is then actually executed in the next step; using the executed next action in the update is what makes Sarsa on-policy.
```python
from maze_env_custom import Maze
from RL_brain import SarsaTable

MAX_EPISODE = 30


def update():
    for episode in range(MAX_EPISODE):
        # initial observation, which is the state tuple returned by reset(), e.g. (1.0, 1.0)
        observation = env.reset()
        # RL chooses an action based on observation, one of ['up', 'down', 'right', 'left']
        action = RL.choose_action(str(observation))

        while True:
            # fresh env
            env.render()
            # RL takes the action and gets the next observation and reward
            observation_, reward, done = env.step(action)
            action_ = RL.choose_action(str(observation_))
            # RL learns from this transition (s, a, r, s', a')
            RL.learn(str(observation), action, reward, str(observation_), action_)
            # swap observation and action
            observation = observation_
            action = action_
            # break the while loop at the end of this episode
            if done:
                break

        # show q_table
        print(RL.q_table)
        print('\n')

    # end of game
    print('game over')
    env.destroy()


if __name__ == "__main__":
    env = Maze()
    RL = SarsaTable(env.action_space)
    env.after(100, update)
    env.mainloop()
```
2.5 Check the Q table
After a long run we can check the Q-table to judge whether the learning is reasonable. The Q-table is as follows:
```
                      up      down     right          left
(1.0, 1.0) -6.837352e-02 -0.000135 -0.000266 -2.970185e-02
(2.0, 1.0) -4.901299e-02 -0.000334 -0.000484 -6.039572e-04
(2.0, 2.0) -3.988164e-04 -0.049010 -0.038785 -2.737623e-04
block_1     0.000000e+00  0.049010  0.000000  0.000000e+00
(4.0, 2.0) -2.646359e-04  0.001314 -0.019900 -1.000000e-02
(4.0, 1.0) -4.900994e-02  0.000014 -0.010000 -3.128178e-06
(3.0, 1.0) -2.970450e-02 -0.029433 -0.000516 -2.078845e-04
(1.0, 2.0) -4.933690e-04 -0.000374 -0.000951 -3.940947e-02
block_2    -1.979099e-07  0.000000  0.010000 -1.531800e-07
(1.0, 3.0) -3.525635e-04 -0.000056 -0.010000 -3.940439e-02
(1.0, 4.0) -7.194310e-07 -0.010000  0.000591 -1.990000e-02
(2.0, 4.0) -1.000000e-02 -0.019900  0.012381  0.000000e+00
(3.0, 4.0)  1.654862e-01  0.000000  0.000000  0.000000e+00
(4.0, 4.0)  0.000000e+00  0.000000 -0.010000  0.000000e+00
(4.0, 3.0)  0.000000e+00  0.000000  0.000000  5.851985e-02
success     0.000000e+00  0.000000  0.000000  0.000000e+00
```
For example, at the starting cell, if the mouse moves up or left it reaches the boundary and receives reward $-1$. Hence the corresponding entries of state (1.0, 1.0) in the Q-table are negative.
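As a quick sanity check, we can also query the learned row of the starting state and read off the greedy action. The snippet below is a sketch assuming the trained `RL` object from the main script is still in scope:

```python
# Assuming the trained SarsaTable instance `RL` from the main script is available.
start_row = RL.q_table.loc['(1.0, 1.0)']   # Q-values of the starting cell
print(start_row)
print('greedy action at the start:', start_row.idxmax())

# The boundary actions should have clearly negative values:
print(start_row['up'] < 0, start_row['left'] < 0)
```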