五分钟入门 强化学习---Q-Learning算法与实现

蒙特卡洛方法(MC)必须等一把游戏打完才能更新,而动态规划(DP)又必须偷看"游戏源码"(知道状态转移概率 P)。

Q-Learning 完美结合了这两者的优势:它既不需要知道环境规律,又可以走一步更新一步!

为了让你透彻理解,我将从它的执行流程、优劣势以及适用场景为你详细拆解。

一、 Q-Learning 的执行流程

Q-Learning 属于时间差分(Temporal Difference, TD)算法。它的核心思想是:用"下一步的估计值"来更新"当前的估计值"。它维护着一张巨大的 Q 表,记录着每一个状态下执行每一个动作的预期打分。

它的具体流程如下:

二、 Q-Learning 的优势

  1. 无模型学习 (Model-Free): 与动态规划不同,Q-Learning 完全不需要知道状态转移概率 P 和奖励函数 R。把它扔进一个完全未知的游戏里,它靠自己摸爬滚打就能学出来。

  2. 离策略 (Off-Policy) 的王者: 它是强化学习中最著名的离策略算法。注意看它的更新公式:它在实际行动时用的是 \epsilon-贪心策略(带随机乱走),但在更新 Q 表时,它假装自己未来一定会走最优的一步(用的是)。这意味着它可以一边勇敢甚至鲁莽地探索未知,一边在脑海中冷静地优化出一条完美的通关路线

  3. 单步更新 (TD Learning): 相比于蒙特卡洛方法必须等游戏结束(比如可能要走 1000 步)才能算总账,Q-Learning 每走一步就能立刻更新一次 Q 表。这让它的学习效率更高,并且可以处理没有终点的连续任务。

三、 Q-Learning 的劣势

  1. 维度灾难: Q-Learning 本质上还是"查表法"。如果面对围棋(状态空间无限大)或自动驾驶(连续的画面和方向盘角度),内存根本存不下这张 Q 表,算法会直接失效。

  2. 过估计偏差 (Overestimation Bias): 更新公式中的那个 \max 操作其实是个双刃剑。在游戏初期,Q 值都在随机波动,由于每次都取 \max,它很容易把由于随机噪声产生的高估值当成真实的奖励,导致智能体盲目乐观,去踩不该踩的坑。

  3. 收敛速度对超参数敏感: 学习率 \alpha 和探索率 \epsilon 的衰减策略如果没有设置好,算法很容易在终点附近来回震荡,或者陷入局部最优解出不来。

四、 适用场景

  • 离散且小规模的环境: 比如走迷宫、控制电梯调度、简单的棋盘游戏、或者把复杂问题高度抽象降维后的系统。

  • 规则完全未知的"黑盒"系统: 只要系统能接受输入(动作)并反馈当前状态和奖励,Q-Learning 就能工作。

  • 需要重用历史数据的场景: 因为它是离策略的,你可以把别人玩游戏产生的数据(或者是历史遗留的数据)直接丢给它,它照样能从中学到最优策略。

代码实现:

algorithm.py

python 复制代码
#!/usr/bin/env python3
# -*- coding: UTF-8 -*-

import numpy as np  # 导入numpy库,用于矩阵计算和查找最大值(np.max)

class Algorithm:  # 定义 Q-Learning 核心算法类
    def __init__(self, gamma, learning_rate, state_size, action_size):  # 初始化函数,传入折扣因子、学习率、状态空间和动作空间大小
        self.gamma = gamma  # 存储折扣因子 (Gamma),决定对未来长远奖励的重视程度
        self.learning_rate = learning_rate  # 存储学习率 (lr/Alpha),控制每次看清现实后修正记忆的幅度
        self.state_size = state_size  # 存储环境中的总状态数量
        self.action_size = action_size  # 存储智能体可以采取的总动作数量

        self.Q = np.ones([self.state_size, self.action_size])  # 初始化 Q 表,创建一个大小为 [状态数, 动作数] 且初始值全为 1 的矩阵

    def learn(self, list_sample_data):  # 定义学习函数,利用单步交互产生的样本数据来更新 Q 表
        sample = list_sample_data[0]  # 时间差分算法(TD)支持单步更新,直接取出当前批量中的第一个单步样本字典
        
        # 解包样本字典,分别提取出:当前状态 s、采取的动作 a、获得的即时奖励 reward 以及新进入的下一状态 next_state
        state, action, reward, next_state = (
            sample["state"],
            sample["action"],
            sample["reward"],
            sample["next_state"],
        )

        # 核心逻辑:计算时序差分误差 (TD Error),即"真实看到的总甜头 (眼前奖励 + 预演下一步的最大预期) - 以前的旧记忆"
        delta = reward + self.gamma * np.max(self.Q[next_state, :]) - self.Q[state, action]

        # 运用 Q-Learning 更新公式:新记忆 = 旧记忆 + 学习率 * 现实与想象的落差 (delta)
        self.Q[state, action] += self.learning_rate * delta

        return  # 当前单步的 Q 表价值更新完成,返回

train_workflow.py

python 复制代码
import time  # 导入time模块,用于控制定时日志上报和模型保存的时间戳
import math  # 导入math模块,用于计算基于指数衰减的自适应探索率(epsilon)
import os    # 导入os模块,用于获取当前运行进程的PID,方便区分监控数据

# 导入框架通用的数据帧结构、Q-learning所需的样本格式化函数和奖励塑形函数
from common_python.utils.common_func import Frame
from agent_q_learning.feature.definition import sample_process, reward_shaping
from tools.train_env_conf_validate import read_usr_conf  # 导入配置读取工具,用于加载环境参数
from tools.metrics_utils import get_training_metrics  # 导入指标拉取工具,用于获取系统级的训练数据
from common_python.utils.workflow_disaster_recovery import handle_disaster_recovery  # 导入环境容灾恢复组件

def workflow(envs, agents, logger=None, monitor=None, *args, **kwargs):  # 定义 Q-Learning 智能体的标准训练工作流函数
    try:  # 开启异常捕获,确保交互和训练过程中的未知错误能被安全记录而不致死机
        usr_conf = read_usr_conf("agent_q_learning/conf/train_env_conf.toml", logger)  # 读取并加载指定的 TOML 环境配置文件
        if usr_conf is None:  # 如果配置文件读取失败或路径不正确
            logger.error("usr_conf is None, please check agent_q_learning/conf/train_env_conf.toml")  # 记录致命错误日志
            return  # 配置文件损坏,直接中断并退出工作流

        env, agent = envs[0], agents[0]  # 从实例列表中分别取出当前使用的第一个环境对象和第一个智能体对象
        EPISODES = 10000  # 设定最大训练总回合数限制为 10000 局游戏

        monitor_data = {  # 初始化外部监控看板的数据字典
            "reward": 0,  # 设定初始的近期平均回报值为 0
        }
        last_report_monitor_time = 0  # 初始化上一次向网页监控大盘上报数据的时间戳
        last_get_training_metrics_time = 0  # 初始化上一次拉取底层训练性能指标的时间戳

        logger.info("Start Training...")  # 打印日志,宣告 Q-Learning 强化学习训练正式拉开帷幕
        start_time = time.time()  # 记录训练刚开始时的绝对时间,用于最终统计总耗时
        last_save_model_time = start_time  # 将上一次保存模型的时间戳初始化为当前训练启动时刻

        total_reward = 0  # 初始化近期累计奖励总和(在每个 15 秒上报周期结束后会清零)
        episode_count = 0  # 初始化近期完成的对局计数器(在每个 15 秒上报周期结束后会清零)
        win_count = 0  # 初始化智能体通关或获胜的绝对总局数,用于动态计算全局胜率

        for episode in range(EPISODES):  # 开始进入按回合(episode)进行的宏观大循环
            if time.time() - last_get_training_metrics_time > 15:  # 判断距离上次拉取系统指标是否已经超过了 15 秒
                last_get_training_metrics_time = time.time()  # 更新最近一次拉取系统指标的时间戳为当前时间
                training_metrics = get_training_metrics()  # 调用外部接口获取当前的训练框架系统级指标
                if training_metrics:  # 如果成功拿到了底层的系统运行指标
                    logger.info(f"training_metrics is {training_metrics}")  # 将这些底层指标打印到日志中供开发者调试

            env_obs = env.reset(usr_conf=usr_conf)  # 重置环境到初始关卡状态,并传入用户配置,获取本局游戏的初始环境观测

            if handle_disaster_recovery(env_obs, logger):  # 检查刚初始化的观测是否异常(如断线、服务器挂掉),进行容灾处理
                continue  # 如果触发了容灾挂起或重启,放弃当前对局,直接跳过并重新尝试开启下一局游戏

            obs_data = agent.observation_process(env_obs)  # 智能体对初始观测进行预处理(例如将2D坐标压缩映射为1D的状态索引)

            done = False  # 初始化单局游戏内部的结束标志位为 False
            agent.epsilon = 1.0  # 在每局游戏开始时,将当前回合的初始探索率重置为 1.0(最大化随机探索)

            while not done:  # 进入单局游戏内部的单步(Step)交互更新循环,直到游戏 done 才会跳出
                # 核心机制:让探索率 epsilon 随着训练回合数的增加,呈指数衰减。训练局数越多越倾向于稳妥利用,保底不低于 0.1
                agent.epsilon = max(0.1, agent.epsilon * math.exp(-(1 / EPISODES) * episode))

                act_data = agent.predict(list_obs_data=[obs_data])  # 智能体根据当前的 1D 状态和上面的 epsilon 探索率,通过 ε-贪心策略预测出动作数据包
                act_data = act_data[0]  # 从批量预测结果列表中取出当前步的唯一动作数据对象

                current_action = agent.action_process(act_data)  # 将智能体内部的动作数据包解包转换为环境物理引擎可直接执行的数字动作

                next_env_reward, next_env_obs = env.step(current_action)  # 将动作下发给环境执行,环境前推一步,返回即时奖励和下一步的新观测

                if handle_disaster_recovery(next_env_obs, logger):  # 交互一步后立即检查新观测是否触发服务器异常或网络波动
                    break  # 如果触发环境容灾事件,直接强行中断并放弃当前这局游戏

                terminated, truncated = next_env_obs["terminated"], next_env_obs["truncated"]  # 解析环境返回的结局标志:正常结束(死亡/通关) 或 超时截断

                next_obs_data = agent.observation_process(next_env_obs)  # 智能体继续将新返回的原始观测数据处理为下一步的 1D 新状态特征

                reward = reward_shaping(next_env_reward, next_env_obs)  # 调用自定义的奖励塑形函数,注入步数惩罚等干预信号,重新计算该步奖励

                done = terminated or truncated  # 只要触发了正常结束或超时截断,都认为当前这局游戏已经完结(done)
                if terminated:  # 如果是达成了正常的通关大结局(非超时截断)
                    win_count += 1  # 说明智能体本局获胜,将全局总胜利局数计数器加 1

                # 【Q-Learning单步核心】将当前状态特征、采取的动作、刚拿到的塑形奖励、以及下一步的新状态特征封装成一个单步 Frame 数据帧
                sample = Frame(
                    state=obs_data.feature,
                    action=current_action,
                    reward=reward,
                    next_state=next_obs_data.feature,
                )

                sample = sample_process([sample])  # 调用处理函数,将这个单步 Frame 数据帧转换为算法 learn 接口直接接收的字典格式样本
                agent.learn(sample)  # 【即时查表优化】不需要等游戏打完!立刻将单步样本送入大白话公式进行时序差分(TD Error)计算并更新 Q 表

                total_reward += reward  # 将当前单步获得的奖励累加到近期总奖励中,用于后续计算平均分
                obs_data = next_obs_data  # 状态无缝转移:将下一步的新状态特征赋给当前状态变量,为下一个 time step 做好准备

            episode_count += 1  # 成功跳出单局 while 循环说明一局游戏彻底打完,近期完成的回合计数器加 1
            now = time.time()  # 获取当前时刻的绝对时间戳

            is_converged = win_count / (episode + 1) > 0.9 and episode > 100  # 判定收敛规则:如果总对局数超过 100 局且全局历史总胜率稳定在 90% 以上

            if is_converged or now - last_report_monitor_time > 15:  # 如果已经达到了收敛标准,或者距离上一次监控汇报已经过去了 15 秒
                avg_reward = total_reward / episode_count  # 用当前周期的总得分除以这 15 秒内打完的局数,算出一个极其准确的近期平均回报
                logger.info(f"Episode: {episode + 1}, Avg Reward: {avg_reward}")  # 打印当前进行到的总局数以及近期对局的平均回报到日志中
                logger.info(f"Training Win Rate: {win_count / (episode + 1)}")  # 打印从训练启动至今的智能体全局历史总胜率
                monitor_data["reward"] = avg_reward  # 将这个近期平均回报写入待上报给外部监控大盘的指标字典中
                if monitor:  # 如果外部实例化并传入了有效的网页大盘监控组件
                    monitor.put_data({os.getpid(): monitor_data})  # 以当前进程的 PID 为 key,将监控数据推送至大盘的消息展示队列中

                total_reward = 0  # 监控周期汇报完毕,将周期累加奖励总和清零,准备统计下一个 15 秒
                episode_count = 0  # 将周期对局计数器清零
                last_report_monitor_time = now  # 更新上一次向监控大盘汇报数据的时间戳为当前时刻

            if is_converged:  # 再次确认刚才的收敛标志位
                logger.info(f"Training Converged at Episode: {episode + 1}")  # 在日志中高调宣布:模型已完美收敛,提前达成训练目的
                break  # 直接彻底切断并跳出最外层的 EPISODES 大循环,优雅地提前结束整个训练大工作流

            if now - last_save_model_time > 300:  # 检查距离上一次将策略和 Q 表持久化存盘是否已经过去了 300 秒(5分钟定时存档)
                logger.info(f"Saving Model at Episode: {episode + 1}")  # 打印定时存档日志,提示正在写入磁盘
                agent.save_model()  # 调用智能体的模型保存接口,将当前的 Q 表序列化为文件,防止因突发停电或进程被杀导致断训
                last_save_model_time = now  # 更新最近一次保存模型的时间戳为当前时刻

        end_time = time.time()  # 整个训练大循环彻底结束(跑满10000局或中途提前收敛),记录最终画上句号的绝对时间戳
        logger.info(f"Training Time for {episode + 1} episodes: {end_time - start_time} s")  # 统计并输出本次强化学习训练所消耗的总秒数
        agent.episodes = episode + 1  # 将智能体实际经历并跑完的对局总数写入其对象属性中备份

        agent.save_model()  # 训练功德圆满,执行最后一次最终模型的强制落盘,确保最新、最聪明的 Q 表策略被完美固化下来

    except Exception as e:  # 捕获上述庞大工作流在运行时抛出的任何未知的崩溃或严重代码异常
        raise RuntimeError(f"workflow error: {e}")  # 将捕获的底层异常统一包装为 RuntimeError 向上抛出,并附带具体的错误描述以供排查

agent.py

python 复制代码
import numpy as np  # 导入 numpy 库,用于矩阵查表和高效数学运算
from kaiwudrl.interface.agent import BaseAgent  # 从框架中导入智能体基类 BaseAgent
from common_python.utils.common_func import create_cls  # 导入用于动态创建数据结构类的辅助函数
from agent_q_learning.conf.conf import Config  # 导入包含了环境维度与超参数的配置类
from agent_q_learning.algorithm.algorithm import Algorithm  # 导入上文定义好的 Q-Learning 核心算法类

ObsData = create_cls("ObsData", feature=None)  # 动态创建一个名为 ObsData 的类,用于封装处理后的状态特征
ActData = create_cls("ActData", act=None)  # 动态创建一个名为 ActData 的类,用于封装智能体输出的动作数据

class Agent(BaseAgent):  # 定义 Q-Learning 智能体类,继承自 BaseAgent 基类
    def __init__(self, agent_type="player", device=None, logger=None, monitor=None) -> None:  # 初始化方法,接收类型、设备、日志和监控实例
        self.logger = logger  # 保存日志记录器实例,以便在类内部各处打印运行日志

        self.state_size = Config.STATE_SIZE  # 从配置类中读取状态空间总大小
        self.action_size = Config.ACTION_SIZE  # 从配置类中读取动作空间总大小
        self.learning_rate = Config.LEARNING_RATE  # 从配置类中读取时间差分更新的学习率(Alpha)
        self.gamma = Config.GAMMA  # 从配置类中读取折扣因子(Gamma)
        self.epsilon = Config.EPSILON  # 从配置类中读取初始随机探索率(Epsilon)
        self.episodes = Config.EPISODES  # 从配置类中读取最大训练局数限制
        
        # 实例化 Q-Learning 算法对象,传入折扣因子、学习率和状态/动作空间的大小
        self.algorithm = Algorithm(self.gamma, self.learning_rate, self.state_size, self.action_size)

        super().__init__(agent_type, device, logger, monitor)  # 调用父类 BaseAgent 的初始化方法,完成底层框架的对接

    def predict(self, list_obs_data):  # 定义预测方法,用于在训练阶段根据当前观测生成下一步动作
        state = list_obs_data[0].feature  # 从批量观测数据中取出第一条(当前步)的状态特征编码索引
        action = self._epsilon_greedy(state=state, epsilon=self.epsilon)  # 调用内部的 ε-贪心策略函数,决定是探索还是查表利用

        return [ActData(act=action)]  # 将计算出的动作整数包装成 ActData 对象,并以列表形式返回给训练工作流

    def exploit(self, env_obs):  # 定义利用方法,通常用于模型评估或正式比赛阶段(完全不包含随机探索机制)
        obs_data = self.observation_process(env_obs)  # 对环境传来的原始观测数据进行预处理和特征编码
        state = obs_data.feature  # 取出处理好的 1D 状态编码特征
        act_data = ActData(act=int(np.argmax(self.algorithm.Q[state, :])))  # 100% 贪心地去 Q 表对应的行里查分,找出评分最高的动作
        action = self.action_process(act_data)  # 将 ActData 动作对象解包为环境可直接执行的具体动作数值
        return action  # 返回要执行的绝对最优动作

    def _epsilon_greedy(self, state, epsilon=0.1):  # 定义 ε-贪心算法的具体实现,用于平衡探索(Exploration)与利用(Exploitation)
        if np.random.rand() <= epsilon:  # 生成一个 0~1 的随机数,如果小于等于设定的探索率 epsilon
            action = int(np.random.randint(0, self.action_size))  # 则触发随机探索:等概率在所有可选动作中随机抽一个
        else:  # 否则(概率为 1-epsilon)触发利用,即利用现有的 Q 表经验做出最优决策
            # 细节:如果发现 Q 表当前这一行所有动作的分数都完全相等(即游戏初期没探索过、分全为初始值)
            if np.all(self.algorithm.Q[state, :] == self.algorithm.Q[state, 0]):
                action = int(np.random.randint(0, self.action_size))  # 此时盲目取 argmax 永远会选第一个动作,为了打破平局,随机选一个
            else:  # 如果分数有高有低,已经分出优劣
                action = int(np.argmax(self.algorithm.Q[state, :]))  # 查表找到当前状态下评分最高的动作索引

        return action  # 返回最终决定的动作索引

    def learn(self, list_sample_data):  # 定义学习函数,接收单步时序差分的交互数据样本
        return self.algorithm.learn(list_sample_data)  # 直接透传给底层 Q-Learning 算法对象的 learn 方法,进行单步 Q 表实时修正

    def observation_process(self, env_obs):  # 定义观测预处理函数,将环境复杂的地图和物件字典状态压缩为一个单一的特征编码
        obs = env_obs["observation"]  # 从环境观测大字典中提取核心的 observation 数据段
        pos = [obs["frame_state"]["hero"]["pos"]["x"], obs["frame_state"]["hero"]["pos"]["z"]]  # 提取智能体在 2D 游戏地图上的绝对坐标 (x, z)

        pos_feature = int(pos[0] * 64 + pos[1])  # 地图降维:将 2D 坐标平面展开压缩为一个 1D 的一维数组索引

        treasure_status = [0] * 10  # 初始化一个长度为 10 的列表,用来记录地图上 10 个宝箱每一个的开闭状态
        for organ in obs["frame_state"]["organs"]:  # 循环遍历环境中所有的机关/物件数据
            if organ["sub_type"] == 1:  # 如果当前机关子类型为 1(代表这是一个可以互动的宝箱)
                treasure_status[organ["config_id"]] = int(organ["status"])  # 获取该宝箱的 ID 并将其开/闭状态(0或1)存入列表对应位置

        # 核心算法高阶拼接:将一维位置索引和 10 位二进制开闭状态结合起来,组合成一个全宇宙唯一的 1D 整数状态特征
        feature = int(1024 * pos_feature + sum([treasure_status[i] * (2**i) for i in range(10)]))

        return ObsData(feature=feature)  # 将最终拼接生成的超级状态特征封装成 ObsData 对象后返回

    def action_process(self, act_data):  # 定义动作处理函数,完成动作包到环境输入的转换
        return act_data.act  # 从算法预测生成的 ActData 包装类中,剥离出纯粹的数值动作指令返回给环境

    def save_model(self, path=None, id="1"):  # 定义模型保存函数,定期将训练成果(Q表)持久化落盘
        model_file_path = f"{path}/model.ckpt-{str(id)}.npy"  # 拼接出保存文件的完整绝对路径,遵循框架的 model.ckpt-id 命名规范
        np.save(model_file_path, self.algorithm.Q)  # 使用 numpy 的 save 方法,将整个 Q 表矩阵序列化存入指定的 .npy 数据文件
        self.logger.info(f"save model {model_file_path} successfully")  # 打印保存成功的日志信息

    def load_model(self, path=None, id="1"):  # 定义模型加载函数,用于断点续训或直接拉取模型进行效果评估
        model_file_path = f"{path}/model.ckpt-{str(id)}.npy"  # 拼接出需要读取的模型文件完整绝对路径
        try:  # 开启异常捕获,防止因为找不到模型文件而导致程序意外中止
            self.algorithm.Q = np.load(model_file_path)  # 从文件中加载 NumPy 数组并覆盖当前的算法 Q 表,恢复智能体的"所有记忆"
            self.logger.info(f"load model {model_file_path} successfully")  # 打印加载成功的日志信息
        except FileNotFoundError:  # 如果捕获到文件不存在异常
            self.logger.info(f"File {model_file_path} not found")  # 打印未找到模型文件的严重日志提示
            exit(1)  # 找不到模型无法进行评估,以异常状态码(1)强行退出系统进程
相关推荐
还不秃顶的计科生1 小时前
codex配置自动化visio/ppt
机器学习·visio
卡次卡次11 小时前
vibecoding起步之Claude Code的skills是什么,里面有什么文件,以ppt的一个skills举例
人工智能·opencv·powerpoint
AI服务老曹1 小时前
解耦异构算力:基于 Docker 与 GB28181/RTSP 的边缘计算 AI 视频管理平台架构设计与源码交付实践
人工智能·docker·边缘计算
小饕1 小时前
RAG 实战:文本切块(Text Chunking)从入门到精通
人工智能
多年小白1 小时前
【周末消息】2026年5月30日-6月1日
大数据·人工智能·深度学习·机器学习·金融
AI导出鸭PC端1 小时前
智谱清言清除符号:当LLM输出遭遇“结构性失序”,一份关于AI导出鸭的工程化测评
人工智能
weixin_468466851 小时前
Prometheus监控服务部署与实战指南
服务器·后端·python·docker·自动化·prometheus
花酒锄作田2 小时前
[Python]标准库argparse解析命令行参数使用介绍
python