Q-learning 极简教程

一、前言

Q-learning 是一种基于价值估计的方法,在深度强化学习之前,发挥巨大作用。Q-learning 也叫表格 Q-learning。其目的在于学习一个状态-动作到 Q 值映射的表格。在学习完表格后,Q-learning 算法通过既定的策略选择动作。

今天我们就来实现一下 Q-learning,并解决FrozenLake 问题。本文不涉及数学推导部分。

二、Q-learning

2.1 策略

首先我们假设已经得到了一个 Q-table,这个表格的键是状态-动作,而值是其对应的价值。在已有Q-table 的情况下,我们最简单的想法是在状态 s_t 下,选择 Q 值最大的a_t 动作执行。这种算法被称为贪婪策略,其实现如下:

python 复制代码
def greedy_policy(qtable, state):
    return np.argmax(qtable[state][:])

但有时候贪婪并不一定是最好的,比如:

如果选择贪婪算法,则执行的流程是 s0->a1->s1->a3->s3,获得回报 2。而实际上最佳的路线回报是 10.3。由于使用贪婪策略,我们永远无法探索到最优路线。为此,可以使用ε-greedy 策略替代。

ε-greedy 的思想是有大部分时间执行贪婪策略,小部分时间执行随机策略,这样就给学习带来了探索的能力。其实现非常简单:

python 复制代码
def epsilon_greedy_policy(qtable, state, epsilon):
    if random() > epsilon:
        return greedy_policy(qtable, state)
    else:
        return env.action_space.sample()

2.2 Q值更新

Q 值的定义如下:
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> Q ( s , a ) = E [ R t + 1 + γ R t + 2 + γ 2 R t + 3 + . . . ∣ S t = s , A t = a ] Q(s,a) = E[R_{t+1} + γR_{t+2} + γ²R_{t+3} + ... | S_t=s, A_t=a] </math>Q(s,a)=E[Rt+1+γRt+2+γ2Rt+3+...∣St=s,At=a]

其含义就是在状态 S_t 执行动作 A_t 后到游戏结束获得的总折扣价值。

最优Q函数Q*(s,a)满足贝尔曼最优方程:
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> Q ∗ ( s , a ) = E [ r + γ m a x a ′ Q ∗ ( s ′ , a ′ ) ] Q^*(s,a) = E[r + γ max_a' Q^*(s',a')] </math>Q∗(s,a)=E[r+γmaxa′Q∗(s′,a′)]

下面是最关键的 Q 函数更新计算:
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> Q ( s , a ) ← Q ( s , a ) + α [ r + γ m a x a ′ Q ( s ′ , a ′ ) − Q ( s , a ) ] Q(s,a) ← Q(s,a) + α[r + γ max_a' Q(s',a') - Q(s,a)] </math>Q(s,a)←Q(s,a)+α[r+γmaxa′Q(s′,a′)−Q(s,a)]

这里我们只需要专注于更新函数即可,这里关注右侧,有如下几个部分:

  • 首先是Q(s,a),是当前 Q-table 中存在的值,属于已知量。
  • α是学习率超参数。
  • r 是当前动作价值,由环境给出。
  • γ 是折扣超参数。
  • max a` Q(s`,a`),表示执行动作 a,得到状态s`,状态 s`对应的最大 Q 值。

r + γ max_a' Q(s',a') - Q(s,a)可以看作最优 Q 值相对当前 Q 值的提升,这里简称 A,而我们更新后 的 Q 值为Q(s,a)+αA,即为小幅提升 Q 值。

另外 A 可能是负数,因此也会存在降低 Q 值的情况。这样我们就理解了 Q-learning 最核心的内容。

Q 值更新的操作使用一行代码就可以完成:

Python 复制代码
q_table[state][action] = q_table[state][action] + lr * (
        reward + gamma * np.max(q_table[new_state]) - q_table[state][action]
)

三、Q-learning实现

Q-learning 算法流程如下:

  1. 初始化 Q-table
  2. 使用ε-greedy 算法选择动作
  3. 环境执行动作
  4. 更新 Q-table
  5. 循环往复

在开始前,我们实现两个辅助函数,用于评估和记录结果。

3.1 评估 agent

评估 agent 要做的就是,传入 Q-table,游玩 n 次游戏,记录得分即可:

python 复制代码
def evaluate_agent(env, max_steps, n_eval_episodes, q_table):
    episode_rewards = []
    for episode in tqdm(range(n_eval_episodes)):
        state, info = env.reset()
        total_rewards_ep = 0
        for step in range(max_steps):
            # 使用贪婪算法
            action = greedy_policy(q_table, state)
            new_state, reward, terminated, truncated, info = env.step(action)
            total_rewards_ep += reward

            if terminated or truncated:
                break

            state = new_state
        episode_rewards.append(total_rewards_ep)
    mean_reward = np.mean(episode_rewards)
    std_reward = np.std(episode_rewards)
    return mean_reward, std_reward

这里代码比较简单,唯一需要注意的是,我们使用 greedy 而非ε-greedy,因为在学习阶段我们需要探索,而实际游玩则只需向着最高价值方向移动。

3.2 记录结果

这里我们使用gymnasium作为环境,并用 imageio 保存结果:

python 复制代码
def record_video(env: gym.Env, q_table, output_path, fps=1):
    images = []
    state, info = env.reset()
    img = env.render()
    images.append(img)

    for i in range(99):
        action = np.argmax(q_table[state][:])
        state, reward, terminated, truncated, info = env.step(action)
        img = env.render()
        images.append(img)
        if terminated:
            break
    imageio.mimsave(output_path, [np.array(img) for i, img in enumerate(images)], fps=fps)

3.3 Q-learning的训练

现在我们按照开头的步骤实现 Q-learning 的训练:

python 复制代码
def initialize_q_table(state_space, action_space):
    return np.zeros((state_space, action_space))


def train(n_training_episodes, lr, gamma, min_epsilon, max_epsilon, decay_rate, env: gym.Env, max_steps, q_table):
    for episode in tqdm(range(n_training_episodes)):
        epsilon = min_epsilon + (max_epsilon - min_epsilon) * np.exp(-decay_rate * episode)
        state, info = env.reset()

        for step in range(max_steps):
            # 2、使用ε-greedy策略选择动作
            action = epsilon_greedy_policy(q_table, state, epsilon)
            # 3、执行动作
            new_state, reward, terminated, truncated, info = env.step(action)
            # 4、更新Q值
            q_table[state][action] = q_table[state][action] + lr * (
                    reward + gamma * np.max(q_table[new_state]) - q_table[state][action]
            )
            
            if terminated or truncated:
                break

            state = new_state

        if episode != 0 and episode % 1000 == 0:
            mean_reward, std_reward = evaluate_agent(env, max_steps, 100, q_table, None)
            print(f"Mean_reward={mean_reward:.2f} +/- {std_reward:.2f}")
            record_video(env, q_table, 'out.mp4')
    return q_table
    
    
def main():
    global env
    # 超参数
    n_training_episodes = 10000
    lr = 0.7
    n_eval_episodes = 100
    env_id = "FrozenLake-v1"
    max_steps = 99
    gamma = 0.99
    eval_seed = []

    max_epsilon = 1.0
    min_epsilon = 0.05
    decay_rate = 0.0005

    env = gym.make(env_id, map_name="4x4", is_slippery=False, render_mode="rgb_array")
    # 1、初始化Q-table
    q_table = initialize_q_table(env.observation_space.n, env.action_space.n)
    train(n_training_episodes, lr, gamma, min_epsilon, max_epsilon, decay_rate, env, max_steps, q_table)

这里我们用零矩阵初始化 Q-table,因此最开始是完全随机的。其余部分在上面都有详细介绍,这里不再复述。

四、总结

Q-learning 是一个非常简单的算法,使用上面代码我们还可以解决 Taxi-v3 问题。Q-learning 算法是经典的强化学习算法,但是由于要构建 Q-table,因此要求状态-动作的组合是有限的,且不能太多。像atari游戏则无法用 Q-learning 算法解决,由此就出现了 DQN 算法。感兴趣的读者可以查看相关资料。

相关推荐
豆沙沙包?13 分钟前
2025年- H82-Lc190--322.零钱兑换(动态规划)--Java版
java·算法·动态规划
Coovally AI模型快速验证35 分钟前
数据集分享 | 电力检测数据集,助力AI守护电网安全
人工智能·算法·安全·计算机视觉·目标跟踪
IT古董1 小时前
【第二章:机器学习与神经网络概述】01.聚类算法理论与实践-(1)K-means聚类算法
人工智能·算法·聚类
小wanga2 小时前
leetcode-hot100
算法·leetcode·职场和发展
完美的奶酪2 小时前
Leetcode-930. 和相同的二元子数组
算法·leetcode
onlyzzr2 小时前
备战秋招版 --- 第12题:85.最大矩形
算法
GG不是gg3 小时前
二分算法深度解析
算法
喵~来学编程啦3 小时前
【全队项目】从GAN到ESRGAN的超分辨率处理
开发语言·python·算法
机器学习之心3 小时前
光伏功率预测 | RF随机森林多变量单步光伏功率预测(Matlab完整源码和数据)
算法·随机森林·matlab·多变量单步光伏功率预测
倔强的石头_3 小时前
【数据结构与算法】归并排序:从理论到实践的深度剖析
后端·算法