时序差分法
前面所说的都是知道环境怎样的,如果完全陌生的环境呢?无模型的强化学习
时序差分方法
它是强化学习中估计 "状态价值" 的方法 ------"状态价值" 可以理解为:"在当前策略下,处于状态 s t s_t st时,未来能拿到的长期回报的期望"(简单说就是 "这个状态好不好")。
- 蒙特卡洛方法:
优点 :不用提前知道环境规则(比如游戏的 "状态转移概率"),直接从实际体验(样本数据)里学;
缺点 :必须等整个回合结束(比如一局游戏打完),才能算出 "总回报 G t G_t Gt",再更新价值。 - 动态规划方法:
优点 :不用等回合结束,每走一步就能用 "后续状态的价值" 更新当前状态的价值;
缺点:必须提前知道环境的完整规则(状态转移、奖励函数),但现实中很多环境(比如复杂游戏)的规则是未知的。
而时序差分方法同时解决了这两个问题:
- 像蒙特卡洛:不用环境规则,从实际体验(样本)里学;
- 像动态规划:不用等回合结束,每走一步就能更新价值。
先看蒙特卡洛的价值更新公式: V ( s t ) ← V ( s t ) + α [ G t − V ( s t ) ] V(s_t) \leftarrow V(s_t) + \alpha \left[ G_t - V(s_t) \right] V(st)←V(st)+α[Gt−V(st)]
- V ( s t ) V(s_t) V(st):当前对状态 s t s_t st的价值估计;
- G t G_t Gt:整个回合结束后得到的 "总回报"(比如从状态 s t s_t st开始,直到游戏结束的所有奖励之和,带折扣);
- α \alpha α:"步长"(学习率,控制每次更新的幅度);
- G t − V ( s t ) G_t - V(s_t) Gt−V(st):当前估计和实际总回报的误差,用这个误差来修正价值。
而时序差分的更新公式是: V ( s t ) ← V ( s t ) + α [ r t + γ V ( s t + 1 ) − V ( s t ) ] V(s_t) \leftarrow V(s_t) + \alpha \left[ r_t + \gamma V(s_{t+1}) - V(s_t) \right] V(st)←V(st)+α[rt+γV(st+1)−V(st)]它的关键变化是:不用等回合结束的 G t G_t Gt,而是用 "当前步奖励 + 下一个状态的价值估计" 来代替 G t G_t Gt。
从 "价值函数的定义" 推导来的:状态 s t s_t st的真实价值 V π ( s t ) V_\pi(s_t) Vπ(st),是 "从 s t s_t st出发,遵循策略 π \pi π得到的期望总回报"。而总回报 G t G_t Gt的定义是: G t = r t + γ r t + 1 + γ 2 r t + 2 + ... G_t = r_t + \gamma r_{t+1} + \gamma^2 r_{t+2} + \dots Gt=rt+γrt+1+γ2rt+2+...( γ \gamma γ是 "折扣因子":未来的奖励不如现在值钱,比如 γ = 0.9 \gamma=0.9 γ=0.9表示 "下一时刻的 1 单位奖励,相当于现在的 0.9 单位")
我们可以把 G t G_t Gt拆成 "当前奖励 + 未来的总回报": G t = r t + γ ( r t + 1 + γ r t + 2 + ... ) ⏟ G t + 1 G_t = r_t + \gamma \underbrace{\left( r_{t+1} + \gamma r_{t+2} + \dots \right)}{G{t+1}} Gt=rt+γGt+1 (rt+1+γrt+2+...)而 G t + 1 G_{t+1} Gt+1的期望,就是 "下一个状态 s t + 1 s_{t+1} st+1的价值 V π ( s t + 1 ) V_\pi(s_{t+1}) Vπ(st+1)"。
因此: V π ( s t ) = E [ r t + γ V π ( s t + 1 ) ∣ s t = s ] V_\pi(s_t) = \mathbb{E}\left[ r_t + \gamma V_\pi(s_{t+1}) \mid s_t = s \right] Vπ(st)=E[rt+γVπ(st+1)∣st=s]所以:蒙特卡洛是用 "实际的总回报 G t G_t Gt"(必须等回合结束)作为更新目标;时序差分是用 " r t + γ V ( s t + 1 ) r_t + \gamma V(s_{t+1}) rt+γV(st+1)"(当前奖励 + 下一个状态的价值估计)作为更新目标 ------这一步走完就能算,不用等回合结束。
时序差分的核心优势
- 在线学习:每走一步就能更新价值,不用等回合结束,效率更高;
- 无模型:不用提前知道环境规则(像蒙特卡洛),能直接从实际体验里学;
- 稳定收敛:虽然用的是 "下一个状态的价值估计"(不是真实值),但理论上能收敛到真实的状态价值。
Sarsa
Sarsa 是在线策略(On-Policy)的时序差分算法:
"在线策略":它用当前正在执行的策略(也就是下文的 "ε- 贪婪策略")来生成样本(选动作、和环境交互),同时用这个样本更新策略 ------ 策略和采样用的是同一个,所以是 "在线" 的。
作用:学习动作价值函数 Q (s,a)(Q (s,a) 表示 "在状态 s 下执行动作 a,未来能拿到的长期回报的期望"),再用 Q (s,a) 指导动作选择。
之前的时序差分(TD)是更新状态价值 V (s),而 Sarsa 是更新状态 - 动作对的价值 Q (s,a),核心公式是: Q ( s t , a t ) ← Q ( s t , a t ) + α [ r t + γ Q ( s t + 1 , a t + 1 ) − Q ( s t , a t ) ] Q(s_t, a_t) \leftarrow Q(s_t, a_t) + \alpha \left[ r_t + \gamma Q(s_{t+1}, a_{t+1}) - Q(s_t, a_t) \right] Q(st,at)←Q(st,at)+α[rt+γQ(st+1,at+1)−Q(st,at)]对比 TD 的 V (s) 更新:TD 更新的是 V ( s t ) V(s_t) V(st)(状态 s 的价值),而 Sarsa 更新的是 Q ( s t , a t ) Q(s_t,a_t) Q(st,at)(状态 s 下选动作 a 的价值);
关键差异:Sarsa 的更新需要下一个状态 s' 对应的下一个动作 a' 的 Q 值(即 Q (s',a')),因为 Q 是 "状态 - 动作对" 的价值,必须明确下一个动作是什么。
直接用 "贪婪策略"(选当前 Q (s,a) 最大的动作)会有问题:
只 "利用" 当前认为最好的动作,不 "探索" 其他动作 ------ 可能错过更优的动作,陷入局部最优。
因此 Sarsa 用ε- 贪婪策略平衡 "探索(Exploration)" 和 "利用(Exploitation)": π ( a ∣ s ) = { 1 − ε + ε ∣ A ∣ 如果 a = a r g m a x a ′ Q ( s , a ′ ) (利用:选当前 Q 最大的动作) ε ∣ A ∣ 其他动作(探索:随机选动作) π(a∣s)=\left\{\begin{matrix}1-ε+ \frac{ε}{|A|} & 如果 a=argmaxa′Q(s,a′)(利用:选当前Q最大的动作)\\ \frac{ε}{|A|} & 其他动作(探索:随机选动作) \end{matrix}\right. π(a∣s)={1−ε+∣A∣ε∣A∣ε如果a=argmaxa′Q(s,a′)(利用:选当前Q最大的动作)其他动作(探索:随机选动作)
∣ A ∣ ∣A∣ ∣A∣:动作空间的大小(有多少个可选动作);
ε ε ε:探索概率(通常是小值,比如 0.1)------1-ε 的概率 "利用",ε 的概率 "探索"。
Sarsa 的名字,是从它更新时用到的5 个核心元素来的:当前状态s、当前动作a、奖励r、下一个状态s′、下一个动作a′这 5 个元素的首字母拼起来就是 S(s)-A(a)-R(r)-S(s')-A(a')→ Sarsa,非常好记!
步骤:
- 初始化 Q (s,a):先给所有 "状态 - 动作对" 的 Q 值赋初始值(比如随机初始化,或全设为 0);
- 每个回合(Episode)循环:从第 1 个回合到第 E 个回合(比如玩 E 局游戏);
- 得到初始状态 s:比如游戏开始时的初始位置;
- 用 ε- 贪婪策略选初始动作 a:根据当前 Q (s,a),用 ε- 贪婪选第一个动作;
- 每个时间步(Step)循环:从第 1 步到第 T 步(直到回合结束,比如游戏通关 / 失败);
- 执行动作 a,得到奖励 r 和下一个状态 s':和环境交互(比如执行 "向右走",得到奖励 + 1,进入新位置 s');
- 用 ε- 贪婪策略选 s' 对应的动作 a':根据新状态 s' 的 Q (s',a'),再用 ε- 贪婪选下一个动作 a';
- 更新 Q (s,a):用 Sarsa 的核心公式更新当前状态 - 动作对的 Q 值;
- 状态和动作转移:把当前状态 s 换成 s',当前动作 a 换成 a'(为下一个时间步做准备);
- 结束时间步循环;
- 结束回合循环。
python
import matplotlib.pyplot as plt
import numpy as np
from tqdm import tqdm # tqdm是显示循环进度条的库
class CliffWalkingEnv:
def __init__(self, ncol, nrow):
self.nrow = nrow
self.ncol = ncol
self.x = 0 # 记录当前智能体位置的横坐标
self.y = self.nrow - 1 # 记录当前智能体位置的纵坐标
def step(self, action): # 外部调用这个函数来改变当前位置
# 4种动作, change[0]:上, change[1]:下, change[2]:左, change[3]:右。坐标系原点(0,0)
# 定义在左上角
change = [[0, -1], [0, 1], [-1, 0], [1, 0]]
self.x = min(self.ncol - 1, max(0, self.x + change[action][0]))
self.y = min(self.nrow - 1, max(0, self.y + change[action][1]))
next_state = self.y * self.ncol + self.x
reward = -1
done = False
if self.y == self.nrow - 1 and self.x > 0: # 下一个位置在悬崖或者目标
done = True
if self.x != self.ncol - 1:
reward = -100
return next_state, reward, done
def reset(self): # 回归初始状态,坐标轴原点在左上角
self.x = 0
self.y = self.nrow - 1
return self.y * self.ncol + self.x
python
class Sarsa:
""" Sarsa算法 """
def __init__(self, ncol, nrow, epsilon, alpha, gamma, n_action=4):
self.Q_table = np.zeros([nrow * ncol, n_action]) # 初始化Q(s,a)表格
self.n_action = n_action # 动作个数
self.alpha = alpha # 学习率
self.gamma = gamma # 折扣因子
self.epsilon = epsilon # epsilon-贪婪策略中的参数
def take_action(self, state): # 选取下一步的操作,具体实现为epsilon-贪婪
if np.random.random() < self.epsilon:
action = np.random.randint(self.n_action)
else:
action = np.argmax(self.Q_table[state])
return action
def best_action(self, state): # 用于打印策略
Q_max = np.max(self.Q_table[state])
a = [0 for _ in range(self.n_action)]
for i in range(self.n_action): # 若两个动作的价值一样,都会记录下来
if self.Q_table[state, i] == Q_max:
a[i] = 1
return a
def update(self, s0, a0, r, s1, a1):
td_error = r + self.gamma * self.Q_table[s1, a1] - self.Q_table[s0, a0]
self.Q_table[s0, a0] += self.alpha * td_error
python
ncol = 12
nrow = 4
env = CliffWalkingEnv(ncol, nrow)
np.random.seed(0)
epsilon = 0.1
alpha = 0.1
gamma = 0.9
agent = Sarsa(ncol, nrow, epsilon, alpha, gamma)
num_episodes = 500 # 智能体在环境中运行的序列的数量
return_list = [] # 记录每一条序列的回报
for i in range(10): # 显示10个进度条
# tqdm的进度条功能
with tqdm(total=int(num_episodes / 10), desc='Iteration %d' % i) as pbar:
for i_episode in range(int(num_episodes / 10)): # 每个进度条的序列数
episode_return = 0
state = env.reset()
action = agent.take_action(state)
done = False
while not done:
next_state, reward, done = env.step(action)
next_action = agent.take_action(next_state)
episode_return += reward # 这里回报的计算不进行折扣因子衰减
agent.update(state, action, reward, next_state, next_action)
state = next_state
action = next_action
return_list.append(episode_return)
if (i_episode + 1) % 10 == 0: # 每10条序列打印一下这10条序列的平均回报
pbar.set_postfix({
'episode':
'%d' % (num_episodes / 10 * i + i_episode + 1),
'return':
'%.3f' % np.mean(return_list[-10:])
})
pbar.update(1)
episodes_list = list(range(len(return_list)))
plt.plot(episodes_list, return_list)
plt.xlabel('Episodes')
plt.ylabel('Returns')
plt.title('Sarsa on {}'.format('Cliff Walking'))
plt.show()
Iteration 0: 100%|██████████████████████████████████████| 50/50 [00:00<00:00, 1113.33it/s, episode=50, return=-119.400]
Iteration 1: 100%|██████████████████████████████████████| 50/50 [00:00<00:00, 1221.50it/s, episode=100, return=-63.000]
Iteration 2: 100%|██████████████████████████████████████| 50/50 [00:00<00:00, 1474.46it/s, episode=150, return=-51.200]
Iteration 3: 100%|██████████████████████████████████████| 50/50 [00:00<00:00, 2089.03it/s, episode=200, return=-48.100]
Iteration 4: 100%|██████████████████████████████████████| 50/50 [00:00<00:00, 1671.08it/s, episode=250, return=-35.700]
Iteration 5: 100%|██████████████████████████████████████| 50/50 [00:00<00:00, 2385.76it/s, episode=300, return=-29.900]
Iteration 6: 100%|██████████████████████████████████████| 50/50 [00:00<00:00, 2506.85it/s, episode=350, return=-28.300]
Iteration 7: 100%|██████████████████████████████████████| 50/50 [00:00<00:00, 2638.92it/s, episode=400, return=-27.700]
Iteration 8: 100%|██████████████████████████████████████| 50/50 [00:00<00:00, 2506.25it/s, episode=450, return=-28.500]
Iteration 9: 100%|██████████████████████████████████████| 50/50 [00:00<00:00, 2638.69it/s, episode=500, return=-18.900]
python
def print_agent(agent, env, action_meaning, disaster=[], end=[]):
for i in range(env.nrow):
for j in range(env.ncol):
if (i * env.ncol + j) in disaster:
print('****', end=' ')
elif (i * env.ncol + j) in end:
print('EEEE', end=' ')
else:
a = agent.best_action(i * env.ncol + j)
pi_str = ''
for k in range(len(action_meaning)):
pi_str += action_meaning[k] if a[k] > 0 else 'o'
print(pi_str, end=' ')
print()
action_meaning = ['^', 'v', '<', '>']
print('Sarsa算法最终收敛得到的策略为:')
print_agent(agent, env, action_meaning, list(range(37, 47)), [47])
Q-learning算法最终收敛得到的策略为:
^ooo ovoo ovoo ^ooo ^ooo ovoo ooo> ^ooo ^ooo ooo> ooo> ovoo
ooo> ooo> ooo> ooo> ooo> ooo> ^ooo ooo> ooo> ooo> ooo> ovoo
ooo> ooo> ooo> ooo> ooo> ooo> ooo> ooo> ooo> ooo> ooo> ovoo
^ooo **** **** **** **** **** **** **** **** **** **** EEEE
多步 Sarsa 算法
- 蒙特卡洛方法:
靠 "完整走一遍流程" 来算价值(比如玩完一整局游戏,把每一步的奖励加起来)。
优点:无偏(因为用的是真实的奖励总和,没猜);
缺点:方差大(每局的结果波动大,比如游戏里每一步的随机事件会让最终奖励差很多)。 - 时序差分(TD)方法:
只走 "一步",靠 "当前奖励 + 下一个状态的预估价值" 来算。
优点:方差小(只看一步,波动小);
缺点:有偏(因为用了 "下一个状态的预估价值",不是真实值)。
思路:不走 "1 步",也不走 "完整流程",走n 步:
原来的 TD(1 步)公式是:
用 "当前奖励 + 下 1 个状态的预估价值" 来更新。
多步 TD 的公式是:
用 "当前到 n 步的奖励总和 + 第 n+1 个状态的预估价值" 来更新。
多步 Sarsa(多步 TD 的具体应用)
Sarsa 是一种 TD 算法,多步 Sarsa 就是把 Sarsa 的 "1 步更新" 改成 "n 步更新":
原来的 Sarsa(1 步):
用 "当前奖励 + 下 1 步状态的动作价值" 来更新当前动作的价值。
多步 Sarsa:
用 "当前到 n 步的奖励总和 + 第 n+1 步状态的动作价值" 来更新当前动作的价值。
简单总结:多步时序差分是 "取中间值"------ 既不像蒙特卡洛那样等完整流程(减少方差),也不像 1 步 TD 那样只看 1 步(减少偏差),靠 "走 n 步" 平衡了无偏和方差的问题。
python
class nstep_Sarsa:
""" n步Sarsa算法 """
def __init__(self, n, ncol, nrow, epsilon, alpha, gamma, n_action=4):
self.Q_table = np.zeros([nrow * ncol, n_action])
self.n_action = n_action
self.alpha = alpha
self.gamma = gamma
self.epsilon = epsilon
self.n = n # 采用n步Sarsa算法
self.state_list = [] # 保存之前的状态
self.action_list = [] # 保存之前的动作
self.reward_list = [] # 保存之前的奖励
def take_action(self, state):
if np.random.random() < self.epsilon:
action = np.random.randint(self.n_action)
else:
action = np.argmax(self.Q_table[state])
return action
def best_action(self, state): # 用于打印策略
Q_max = np.max(self.Q_table[state])
a = [0 for _ in range(self.n_action)]
for i in range(self.n_action):
if self.Q_table[state, i] == Q_max:
a[i] = 1
return a
def update(self, s0, a0, r, s1, a1, done):
self.state_list.append(s0)
self.action_list.append(a0)
self.reward_list.append(r)
if len(self.state_list) == self.n: # 若保存的数据可以进行n步更新
G = self.Q_table[s1, a1] # 得到Q(s_{t+n}, a_{t+n})
for i in reversed(range(self.n)):
G = self.gamma * G + self.reward_list[i] # 不断向前计算每一步的回报
# 如果到达终止状态,最后几步虽然长度不够n步,也将其进行更新
if done and i > 0:
s = self.state_list[i]
a = self.action_list[i]
self.Q_table[s, a] += self.alpha * (G - self.Q_table[s, a])
s = self.state_list.pop(0) # 将需要更新的状态动作从列表中删除,下次不必更新
a = self.action_list.pop(0)
self.reward_list.pop(0)
# n步Sarsa的主要更新步骤
self.Q_table[s, a] += self.alpha * (G - self.Q_table[s, a])
if done: # 如果到达终止状态,即将开始下一条序列,则将列表全清空
self.state_list = []
self.action_list = []
self.reward_list = []
python
np.random.seed(0)
n_step = 5 # 5步Sarsa算法
alpha = 0.1
epsilon = 0.1
gamma = 0.9
agent = nstep_Sarsa(n_step, ncol, nrow, epsilon, alpha, gamma)
num_episodes = 500 # 智能体在环境中运行的序列的数量
return_list = [] # 记录每一条序列的回报
for i in range(10): # 显示10个进度条
#tqdm的进度条功能
with tqdm(total=int(num_episodes / 10), desc='Iteration %d' % i) as pbar:
for i_episode in range(int(num_episodes / 10)): # 每个进度条的序列数
episode_return = 0
state = env.reset()
action = agent.take_action(state)
done = False
while not done:
next_state, reward, done = env.step(action)
next_action = agent.take_action(next_state)
episode_return += reward # 这里回报的计算不进行折扣因子衰减
agent.update(state, action, reward, next_state, next_action,
done)
state = next_state
action = next_action
return_list.append(episode_return)
if (i_episode + 1) % 10 == 0: # 每10条序列打印一下这10条序列的平均回报
pbar.set_postfix({
'episode':
'%d' % (num_episodes / 10 * i + i_episode + 1),
'return':
'%.3f' % np.mean(return_list[-10:])
})
pbar.update(1)
episodes_list = list(range(len(return_list)))
plt.plot(episodes_list, return_list)
plt.xlabel('Episodes')
plt.ylabel('Returns')
plt.title('5-step Sarsa on {}'.format('Cliff Walking'))
plt.show()
Iteration 0: 100%|████████████████████████████████████████| 50/50 [00:00<00:00, 781.85it/s, episode=50, return=-26.500]
Iteration 1: 100%|██████████████████████████████████████| 50/50 [00:00<00:00, 2179.81it/s, episode=100, return=-35.200]
Iteration 2: 100%|██████████████████████████████████████| 50/50 [00:00<00:00, 2179.02it/s, episode=150, return=-20.100]
Iteration 3: 100%|██████████████████████████████████████| 50/50 [00:00<00:00, 2179.88it/s, episode=200, return=-27.200]
Iteration 4: 100%|██████████████████████████████████████| 50/50 [00:00<00:00, 1856.81it/s, episode=250, return=-19.300]
Iteration 5: 100%|██████████████████████████████████████| 50/50 [00:00<00:00, 2088.82it/s, episode=300, return=-27.400]
Iteration 6: 100%|██████████████████████████████████████| 50/50 [00:00<00:00, 2005.60it/s, episode=350, return=-28.000]
Iteration 7: 100%|██████████████████████████████████████| 50/50 [00:00<00:00, 1928.19it/s, episode=400, return=-36.500]
Iteration 8: 100%|██████████████████████████████████████| 50/50 [00:00<00:00, 2179.54it/s, episode=450, return=-27.000]
Iteration 9: 100%|██████████████████████████████████████| 50/50 [00:00<00:00, 2278.89it/s, episode=500, return=-19.100]
python
action_meaning = ['^', 'v', '<', '>']
print('5步Sarsa算法最终收敛得到的策略为:')
print_agent(agent, env, action_meaning, list(range(37, 47)), [47])
5步Sarsa算法最终收敛得到的策略为:
ooo> ooo> ooo> ooo> ooo> ooo> ooo> ooo> ooo> ooo> ooo> ovoo
^ooo ^ooo ^ooo oo<o ^ooo ^ooo ^ooo ^ooo ooo> ooo> ^ooo ovoo
ooo> ^ooo ^ooo ^ooo ^ooo ^ooo ^ooo ooo> ooo> ^ooo ooo> ovoo
^ooo **** **** **** **** **** **** **** **** **** **** EEEE
Q-learning 算法
Q ( s t , a t ) ← Q ( s t , a t ) + α [ R t + γ m a x a Q ( s t + 1 , a ) − Q ( s t , a t ) ] Q(s t ,a t )←Q(s t ,a t )+α[R t +γmax a Q(s t+1 ,a)−Q(s t ,a t )] Q(st,at)←Q(st,at)+α[Rt+γmaxaQ(st+1,a)−Q(st,at)]逻辑:用 "下一个状态的最优动作价值" 来更新当前动作价值(不管下一个动作实际选了啥,只取最大值)。
Q-learning 的算法流程(白话版)
- 先给所有 "状态 - 动作" 的 Q 值随便设个初始值;
- 重复和环境互动(E 个序列):
- 从初始状态 s 开始;
- 每一步(T 步):
- 用 ε-greedy 策略选当前状态 s 的动作 a(ε 概率随机选,1-ε 概率选当前 Q 最大的动作);
- 执行 a,得到奖励 r 和新状态 s';
- 用上面的公式更新 Q (s,a)(取 s' 下所有动作的最大 Q 值);
- 状态切换到 s',继续循环。
| 维度 | Sarsa(在线策略 on-policy) | Q-learning(离线策略 off-policy) |
|---|---|---|
| 下一个动作的 Q 值 | 用 当前策略实际选的下一个动作 a' 的 Q 值: Q ( s ′ , a ′ ) Q(s′,a′) Q(s′,a′) | 用下一个状态的最优动作的 Q 值: m a x a Q ( s ′ , a ) max_aQ(s′,a) maxaQ(s′,a) |
| 策略依赖 | 必须用当前策略(ε-greedy)采样的数据更新("边用边学") | 可以用任意策略采样的数据更新("学最优,用探索") |
Sarsa(在线) :更新 Q 值用的是 "当前策略实际和环境互动得到的数据"(自己用自己学)。
Q-learning(离线):更新 Q 值用的是 "最优策略的 Q 值",但和环境互动用的是探索策略(比如 ε-greedy)(学的是最优,用的是探索)。
| 术语 | 含义 |
|---|---|
| 行为策略 | 用于和环境交互、采集数据的策略(比如 ε-greedy 策略) |
| 目标策略 | 用于更新价值函数(如 Q 值)的策略 |
| 在线策略算法 | 行为策略 = 目标策略(用自己采集的数据更新自己) |
| 离线策略算法 | 行为策略 ≠ 目标策略(可用其他策略采集的数据更新目标策略) |
| 算法 | 数据依赖 | 策略关系 |
|---|---|---|
| Sarsa | 更新需要当前策略采集的五元组(s,a,r,s',a') | (a' 是当前策略选的动作) 行为策略 = 目标策略 |
| Q-learning | 更新仅需四元组(s,a,r,s') | (a' 是 s' 下的最优动作,无需当前策略采集) 行为策略≠目标策略 |
python
class QLearning:
""" Q-learning算法 """
def __init__(self, ncol, nrow, epsilon, alpha, gamma, n_action=4):
self.Q_table = np.zeros([nrow * ncol, n_action]) # 初始化Q(s,a)表格
self.n_action = n_action # 动作个数
self.alpha = alpha # 学习率
self.gamma = gamma # 折扣因子
self.epsilon = epsilon # epsilon-贪婪策略中的参数
def take_action(self, state): #选取下一步的操作
if np.random.random() < self.epsilon:
action = np.random.randint(self.n_action)
else:
action = np.argmax(self.Q_table[state])
return action
def best_action(self, state): # 用于打印策略
Q_max = np.max(self.Q_table[state])
a = [0 for _ in range(self.n_action)]
for i in range(self.n_action):
if self.Q_table[state, i] == Q_max:
a[i] = 1
return a
def update(self, s0, a0, r, s1):
td_error = r + self.gamma * self.Q_table[s1].max(
) - self.Q_table[s0, a0]
self.Q_table[s0, a0] += self.alpha * td_error
python
np.random.seed(0)
epsilon = 0.1
alpha = 0.1
gamma = 0.9
agent = QLearning(ncol, nrow, epsilon, alpha, gamma)
num_episodes = 500 # 智能体在环境中运行的序列的数量
return_list = [] # 记录每一条序列的回报
for i in range(10): # 显示10个进度条
# tqdm的进度条功能
with tqdm(total=int(num_episodes / 10), desc='Iteration %d' % i) as pbar:
for i_episode in range(int(num_episodes / 10)): # 每个进度条的序列数
episode_return = 0
state = env.reset()
done = False
while not done:
action = agent.take_action(state)
next_state, reward, done = env.step(action)
episode_return += reward # 这里回报的计算不进行折扣因子衰减
agent.update(state, action, reward, next_state)
state = next_state
return_list.append(episode_return)
if (i_episode + 1) % 10 == 0: # 每10条序列打印一下这10条序列的平均回报
pbar.set_postfix({
'episode':
'%d' % (num_episodes / 10 * i + i_episode + 1),
'return':
'%.3f' % np.mean(return_list[-10:])
})
pbar.update(1)
episodes_list = list(range(len(return_list)))
plt.plot(episodes_list, return_list)
plt.xlabel('Episodes')
plt.ylabel('Returns')
plt.title('Q-learning on {}'.format('Cliff Walking'))
plt.show()
action_meaning = ['^', 'v', '<', '>']
print('Q-learning算法最终收敛得到的策略为:')
print_agent(agent, env, action_meaning, list(range(37, 47)), [47])
Iteration 0: 100%|███████████████████████████████████████| 50/50 [00:00<00:00, 759.53it/s, episode=50, return=-105.700]
Iteration 1: 100%|███████████████████████████████████████| 50/50 [00:00<00:00, 964.28it/s, episode=100, return=-70.900]
Iteration 2: 100%|██████████████████████████████████████| 50/50 [00:00<00:00, 1253.54it/s, episode=150, return=-56.500]
Iteration 3: 100%|██████████████████████████████████████| 50/50 [00:00<00:00, 1474.54it/s, episode=200, return=-46.500]
Iteration 4: 100%|██████████████████████████████████████| 50/50 [00:00<00:00, 1821.22it/s, episode=250, return=-40.800]
Iteration 5: 100%|██████████████████████████████████████| 50/50 [00:00<00:00, 1856.86it/s, episode=300, return=-20.400]
Iteration 6: 100%|██████████████████████████████████████| 50/50 [00:00<00:00, 2089.23it/s, episode=350, return=-45.700]
Iteration 7: 100%|██████████████████████████████████████| 50/50 [00:00<00:00, 2387.31it/s, episode=400, return=-32.800]
Iteration 8: 100%|██████████████████████████████████████| 50/50 [00:00<00:00, 2506.76it/s, episode=450, return=-22.700]
Iteration 9: 100%|██████████████████████████████████████| 50/50 [00:00<00:00, 2638.76it/s, episode=500, return=-61.700]
Q-learning算法最终收敛得到的策略为:
^ooo ovoo ovoo ^ooo ^ooo ovoo ooo> ^ooo ^ooo ooo> ooo> ovoo
ooo> ooo> ooo> ooo> ooo> ooo> ^ooo ooo> ooo> ooo> ooo> ovoo
ooo> ooo> ooo> ooo> ooo> ooo> ooo> ooo> ooo> ooo> ooo> ovoo
^ooo **** **** **** **** **** **** **** **** **** **** EEEE


