Q-learning 极简教程

一、前言

Q-learning 是一种基于价值估计的方法,在深度强化学习之前,发挥巨大作用。Q-learning 也叫表格 Q-learning。其目的在于学习一个状态-动作到 Q 值映射的表格。在学习完表格后,Q-learning 算法通过既定的策略选择动作。

今天我们就来实现一下 Q-learning,并解决FrozenLake 问题。本文不涉及数学推导部分。

二、Q-learning

2.1 策略

首先我们假设已经得到了一个 Q-table,这个表格的键是状态-动作,而值是其对应的价值。在已有Q-table 的情况下,我们最简单的想法是在状态 s_t 下,选择 Q 值最大的a_t 动作执行。这种算法被称为贪婪策略,其实现如下:

python 复制代码
def greedy_policy(qtable, state):
    return np.argmax(qtable[state][:])

但有时候贪婪并不一定是最好的,比如:

如果选择贪婪算法,则执行的流程是 s0->a1->s1->a3->s3,获得回报 2。而实际上最佳的路线回报是 10.3。由于使用贪婪策略,我们永远无法探索到最优路线。为此,可以使用ε-greedy 策略替代。

ε-greedy 的思想是有大部分时间执行贪婪策略,小部分时间执行随机策略,这样就给学习带来了探索的能力。其实现非常简单:

python 复制代码
def epsilon_greedy_policy(qtable, state, epsilon):
    if random() > epsilon:
        return greedy_policy(qtable, state)
    else:
        return env.action_space.sample()

2.2 Q值更新

Q 值的定义如下:
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> Q ( s , a ) = E [ R t + 1 + γ R t + 2 + γ 2 R t + 3 + . . . ∣ S t = s , A t = a ] Q(s,a) = E[R_{t+1} + γR_{t+2} + γ²R_{t+3} + ... | S_t=s, A_t=a] </math>Q(s,a)=E[Rt+1+γRt+2+γ2Rt+3+...∣St=s,At=a]

其含义就是在状态 S_t 执行动作 A_t 后到游戏结束获得的总折扣价值。

最优Q函数Q*(s,a)满足贝尔曼最优方程:
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> Q ∗ ( s , a ) = E [ r + γ m a x a ′ Q ∗ ( s ′ , a ′ ) ] Q^*(s,a) = E[r + γ max_a' Q^*(s',a')] </math>Q∗(s,a)=E[r+γmaxa′Q∗(s′,a′)]

下面是最关键的 Q 函数更新计算:
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> Q ( s , a ) ← Q ( s , a ) + α [ r + γ m a x a ′ Q ( s ′ , a ′ ) − Q ( s , a ) ] Q(s,a) ← Q(s,a) + α[r + γ max_a' Q(s',a') - Q(s,a)] </math>Q(s,a)←Q(s,a)+α[r+γmaxa′Q(s′,a′)−Q(s,a)]

这里我们只需要专注于更新函数即可,这里关注右侧,有如下几个部分:

  • 首先是Q(s,a),是当前 Q-table 中存在的值,属于已知量。
  • α是学习率超参数。
  • r 是当前动作价值,由环境给出。
  • γ 是折扣超参数。
  • max a` Q(s`,a`),表示执行动作 a,得到状态s`,状态 s`对应的最大 Q 值。

r + γ max_a' Q(s',a') - Q(s,a)可以看作最优 Q 值相对当前 Q 值的提升,这里简称 A,而我们更新后 的 Q 值为Q(s,a)+αA,即为小幅提升 Q 值。

另外 A 可能是负数,因此也会存在降低 Q 值的情况。这样我们就理解了 Q-learning 最核心的内容。

Q 值更新的操作使用一行代码就可以完成:

Python 复制代码
q_table[state][action] = q_table[state][action] + lr * (
        reward + gamma * np.max(q_table[new_state]) - q_table[state][action]
)

三、Q-learning实现

Q-learning 算法流程如下:

  1. 初始化 Q-table
  2. 使用ε-greedy 算法选择动作
  3. 环境执行动作
  4. 更新 Q-table
  5. 循环往复

在开始前,我们实现两个辅助函数,用于评估和记录结果。

3.1 评估 agent

评估 agent 要做的就是,传入 Q-table,游玩 n 次游戏,记录得分即可:

python 复制代码
def evaluate_agent(env, max_steps, n_eval_episodes, q_table):
    episode_rewards = []
    for episode in tqdm(range(n_eval_episodes)):
        state, info = env.reset()
        total_rewards_ep = 0
        for step in range(max_steps):
            # 使用贪婪算法
            action = greedy_policy(q_table, state)
            new_state, reward, terminated, truncated, info = env.step(action)
            total_rewards_ep += reward

            if terminated or truncated:
                break

            state = new_state
        episode_rewards.append(total_rewards_ep)
    mean_reward = np.mean(episode_rewards)
    std_reward = np.std(episode_rewards)
    return mean_reward, std_reward

这里代码比较简单,唯一需要注意的是,我们使用 greedy 而非ε-greedy,因为在学习阶段我们需要探索,而实际游玩则只需向着最高价值方向移动。

3.2 记录结果

这里我们使用gymnasium作为环境,并用 imageio 保存结果:

python 复制代码
def record_video(env: gym.Env, q_table, output_path, fps=1):
    images = []
    state, info = env.reset()
    img = env.render()
    images.append(img)

    for i in range(99):
        action = np.argmax(q_table[state][:])
        state, reward, terminated, truncated, info = env.step(action)
        img = env.render()
        images.append(img)
        if terminated:
            break
    imageio.mimsave(output_path, [np.array(img) for i, img in enumerate(images)], fps=fps)

3.3 Q-learning的训练

现在我们按照开头的步骤实现 Q-learning 的训练:

python 复制代码
def initialize_q_table(state_space, action_space):
    return np.zeros((state_space, action_space))


def train(n_training_episodes, lr, gamma, min_epsilon, max_epsilon, decay_rate, env: gym.Env, max_steps, q_table):
    for episode in tqdm(range(n_training_episodes)):
        epsilon = min_epsilon + (max_epsilon - min_epsilon) * np.exp(-decay_rate * episode)
        state, info = env.reset()

        for step in range(max_steps):
            # 2、使用ε-greedy策略选择动作
            action = epsilon_greedy_policy(q_table, state, epsilon)
            # 3、执行动作
            new_state, reward, terminated, truncated, info = env.step(action)
            # 4、更新Q值
            q_table[state][action] = q_table[state][action] + lr * (
                    reward + gamma * np.max(q_table[new_state]) - q_table[state][action]
            )
            
            if terminated or truncated:
                break

            state = new_state

        if episode != 0 and episode % 1000 == 0:
            mean_reward, std_reward = evaluate_agent(env, max_steps, 100, q_table, None)
            print(f"Mean_reward={mean_reward:.2f} +/- {std_reward:.2f}")
            record_video(env, q_table, 'out.mp4')
    return q_table
    
    
def main():
    global env
    # 超参数
    n_training_episodes = 10000
    lr = 0.7
    n_eval_episodes = 100
    env_id = "FrozenLake-v1"
    max_steps = 99
    gamma = 0.99
    eval_seed = []

    max_epsilon = 1.0
    min_epsilon = 0.05
    decay_rate = 0.0005

    env = gym.make(env_id, map_name="4x4", is_slippery=False, render_mode="rgb_array")
    # 1、初始化Q-table
    q_table = initialize_q_table(env.observation_space.n, env.action_space.n)
    train(n_training_episodes, lr, gamma, min_epsilon, max_epsilon, decay_rate, env, max_steps, q_table)

这里我们用零矩阵初始化 Q-table,因此最开始是完全随机的。其余部分在上面都有详细介绍,这里不再复述。

四、总结

Q-learning 是一个非常简单的算法,使用上面代码我们还可以解决 Taxi-v3 问题。Q-learning 算法是经典的强化学习算法,但是由于要构建 Q-table,因此要求状态-动作的组合是有限的,且不能太多。像atari游戏则无法用 Q-learning 算法解决,由此就出现了 DQN 算法。感兴趣的读者可以查看相关资料。

相关推荐
晨晖221 小时前
顺序查找:c语言
c语言·开发语言·算法
LYFlied21 小时前
【每日算法】LeetCode 64. 最小路径和(多维动态规划)
数据结构·算法·leetcode·动态规划
Salt_07281 天前
DAY44 简单 CNN
python·深度学习·神经网络·算法·机器学习·计算机视觉·cnn
货拉拉技术1 天前
AI拍货选车,开启拉货新体验
算法
MobotStone1 天前
一夜蒸发1000亿美元后,Google用什么夺回AI王座
算法
Wang201220131 天前
RNN和LSTM对比
人工智能·算法·架构
xueyongfu1 天前
从Diffusion到VLA pi0(π0)
人工智能·算法·stable diffusion
永远睡不够的入1 天前
快排(非递归)和归并的实现
数据结构·算法·深度优先
cheems95271 天前
二叉树深搜算法练习(一)
数据结构·算法
sin_hielo1 天前
leetcode 3074
数据结构·算法·leetcode