强化学习笔记——4策略迭代、值迭代、TD算法

基于策略迭代的贝尔曼方程和基于值迭代的贝尔曼方程,关系还是不太理解

首先梳理一下:

通过贝尔曼方程将强化学习转化为值迭代和策略迭代两种问题

求解上述两种贝尔曼方程有三种方法:DP(有模型),MC(无模型),TD(DP和MC结合)
这三种只是方法,既可以用于求值迭代也可以用于求解策略迭代


我总结就是:值迭代方法通过求最优价值函数,可以间接得到 最优策略

策略迭代是:初始化一个随机策略,然后按照当前策略迭代价值函数
再进行策略改进,二者交替直到策略基本不发生变化。



上述就是贝尔曼最优公式的过程,求解最优的策略
详细见
V(s)求解举例

直接看值迭代伪代码:

  1. 遍历每个状态S,对每个状态S遍历所有动作A
  2. 计算Q值
  3. 对于每个状态S选择Q值最大的那个动作作为更新的策略 ,最大Q值作为新的V(s)

策略迭代:分两步policy Evalution策略评估(就是求值函数),policy improvement(策略更新)

  1. 策略评估中,如何通过求解贝尔曼方程得到值函数?
  2. 策略更新中,为什么新策略Πk+1就比原策略Πk好?
  3. 为什么策略迭代可以找到最优策略?
  4. 值迭代和策略迭代直接什么关系?

policy Evalution本身也是个迭代要循环


Q4:策略迭代用到了值迭代的结果,是基于值收敛的。

伪代码:

  1. 进入PolicyEvaluation,目的求解收敛的VΠk。对于每个状态S迭代。
  2. 计算每个状态S下每个动作A的Q值,选择最大的作为策略Πk+1
  3. 不断重复(一个1,2步骤表示一回合 )

对比两个伪代码发现:值迭代的值函数计算不强调某策略(Vk),因为它遍历所有状态的所有动作策略,然后计算Q值选最优动作为策略
策略迭代:计算值函数强调是某一策略(VΠk),在某一个具体策略下求出值函数,然后再遍历所有状态的所有动作,然后计算Q值选最优动作为更新的策略

=======================================================================

上述两方法,不可避免要求Q值。

蒙特卡洛方法,通过无模型方法求解Q值

从一个s,a出发走很多个回合计算回报平局值,即为Q(s,a)

有些改进 蒙特卡洛方法不用走很多个回合计算回报平局值,只一个回合得到回报,然后作为Q

TD算法: 无模型求解贝尔曼方程

包含一系列:TD0,SARSA,Qlearning,DQN

的都是求解贝尔曼公式:但有的求解基于值函数刻画的贝尔曼公式,有的求解基于动作价值函数刻画的贝尔曼公式

它结合了动态规划(DP)和蒙特卡洛方法(MC)的优点

基于表格的TD算法总结:

TD算法只是相当于做策略评估,不负责policy improvement

cpp 复制代码
实现SARSA和Qlearning算法
import numpy as np
from collections import defaultdict

class QLearning:
    def __init__(self, env, alpha=0.1, gamma=0.99, epsilon=0.1):
        self.env = env
        self.alpha = alpha  # 学习率
        self.gamma = gamma  # 折扣因子
        self.epsilon = epsilon  # 探索率
        
        # 初始化Q表
        self.Q = defaultdict(lambda: np.zeros(len(env.action_space))) #用于创建一个长度为 len(env.action_space) 的全零数组。
        
    def choose_action(self, state):
        if np.random.rand() < self.epsilon:
            # 随机选择动作索引
            action_idx = np.random.choice(len(self.env.action_space))
            return self.env.action_space[action_idx]  # 探索
        else:
            # 选择Q值最大的动作
            action_idx = np.argmax(self.Q[state])
            return self.env.action_space[action_idx]  # 利用
        
    def learn(self, state, action, reward, next_state, done):
        # 将状态转换为可哈希的键

        next_state_key = next_state
        current_q = self.Q[state][self.env.action_space.index(action)]
        max_next_q = np.max(self.Q[next_state_key])
        
        # Q-learning更新公式
        new_q = current_q + self.alpha * (reward + self.gamma * max_next_q - current_q)
        self.Q[state][self.env.action_space.index(action)] = new_q

class SARSA:
    def __init__(self, env, alpha=0.1, gamma=0.99, epsilon=0.1):
        self.env = env
        self.alpha = alpha  # 学习率
        self.gamma = gamma  # 折扣因子
        self.epsilon = epsilon  # 探索率
        
        # 初始化Q表
        self.Q = defaultdict(lambda: np.zeros(len(env.action_space)))
        
    def choose_action(self, state):
        
        if np.random.rand() < self.epsilon: #以概率 ϵ 随机选择动作
            # 随机选择动作索引
            action_idx = np.random.choice(len(self.env.action_space))
            return self.env.action_space[action_idx]  # 探索
        else:
            # 选择Q值最大的动作
            action_idx = np.argmax(self.Q[state])
            return self.env.action_space[action_idx]  # action_idx动作索引,返回具体动作(0, 1), (1, 0), (0, -1), (-1, 0), (0, 0)
            
    def learn(self, state, action, reward, next_state, next_action, done):
            
        next_state_key = next_state
        current_q = self.Q[state][self.env.action_space.index(action)]
        next_q = self.Q[next_state_key][self.env.action_space.index(next_action)]
         # SARSA更新公式
        new_q = current_q + self.alpha * (reward + self.gamma * next_q - current_q) #一步TD更新
        
        # 更新Q表
        self.Q[state][self.env.action_space.index(action)] = new_q

上述使用Q表每次记录下来Q值,下次(s,a)可以直接读取Q值

还有一种方法是用函数、神经网络计算Q值,输入(s,a)输出Q,然后梯度下降优化函数的参数,使得Q值计算更准确。

相关推荐
CYRUS_STUDIO1 小时前
常用加解密算法介绍
算法·安全·逆向
木亦汐丫3 小时前
【大模型系列篇】国产开源大模型DeepSeek-V3技术报告解析
sft·rl·mtp·mla·deepseekmoe·fp8 混合精度训练·dualpipe算法
weixin_535854223 小时前
快手,蓝禾,优博讯,三七互娱,顺丰,oppo,游卡,汤臣倍健,康冠科技,作业帮,高途教育25届春招内推
java·前端·python·算法·硬件工程
C_V_Better3 小时前
Java 导出 PDF 文件:从入门到实战
java·开发语言·算法·pdf
mit6.8244 小时前
[Lc(2)滑动窗口_1] 长度最小的数组 | 无重复字符的最长子串 | 最大连续1的个数 III | 将 x 减到 0 的最小操作数
数据结构·c++·算法·leetcode
zjoy_22334 小时前
【数据结构】什么是栈||栈的经典应用||分治递归||斐波那契问题和归并算法||递归实现||顺序栈和链栈的区分
java·c语言·开发语言·数据结构·c++·算法·排序算法
油泼辣子多加5 小时前
【华为OD机考】华为OD笔试真题解析(20)--投篮大赛
数据结构·算法·华为od
修己xj5 小时前
算法系列之数据结构-Huffman树
算法
CodeJourney.6 小时前
Deepseek助力思维导图与流程图制作:高效出图新选择
数据库·人工智能·算法
Liudef066 小时前
Stable Diffusion模型高清算法模型类详解
人工智能·算法·ai作画·stable diffusion