强化学习笔记——4策略迭代、值迭代、TD算法

基于策略迭代的贝尔曼方程和基于值迭代的贝尔曼方程,关系还是不太理解

首先梳理一下:

通过贝尔曼方程将强化学习转化为值迭代和策略迭代两种问题

求解上述两种贝尔曼方程有三种方法:DP(有模型),MC(无模型),TD(DP和MC结合)
这三种只是方法,既可以用于求值迭代也可以用于求解策略迭代


我总结就是:值迭代方法通过求最优价值函数,可以间接得到 最优策略

策略迭代是:初始化一个随机策略,然后按照当前策略迭代价值函数
再进行策略改进,二者交替直到策略基本不发生变化。



上述就是贝尔曼最优公式的过程,求解最优的策略
详细见
V(s)求解举例

直接看值迭代伪代码:

  1. 遍历每个状态S,对每个状态S遍历所有动作A
  2. 计算Q值
  3. 对于每个状态S选择Q值最大的那个动作作为更新的策略 ,最大Q值作为新的V(s)

策略迭代:分两步policy Evalution策略评估(就是求值函数),policy improvement(策略更新)

  1. 策略评估中,如何通过求解贝尔曼方程得到值函数?
  2. 策略更新中,为什么新策略Πk+1就比原策略Πk好?
  3. 为什么策略迭代可以找到最优策略?
  4. 值迭代和策略迭代直接什么关系?

policy Evalution本身也是个迭代要循环


Q4:策略迭代用到了值迭代的结果,是基于值收敛的。

伪代码:

  1. 进入PolicyEvaluation,目的求解收敛的VΠk。对于每个状态S迭代。
  2. 计算每个状态S下每个动作A的Q值,选择最大的作为策略Πk+1
  3. 不断重复(一个1,2步骤表示一回合 )

对比两个伪代码发现:值迭代的值函数计算不强调某策略(Vk),因为它遍历所有状态的所有动作策略,然后计算Q值选最优动作为策略
策略迭代:计算值函数强调是某一策略(VΠk),在某一个具体策略下求出值函数,然后再遍历所有状态的所有动作,然后计算Q值选最优动作为更新的策略

=======================================================================

上述两方法,不可避免要求Q值。

蒙特卡洛方法,通过无模型方法求解Q值

从一个s,a出发走很多个回合计算回报平局值,即为Q(s,a)

有些改进 蒙特卡洛方法不用走很多个回合计算回报平局值,只一个回合得到回报,然后作为Q

TD算法: 无模型求解贝尔曼方程

包含一系列:TD0,SARSA,Qlearning,DQN

的都是求解贝尔曼公式:但有的求解基于值函数刻画的贝尔曼公式,有的求解基于动作价值函数刻画的贝尔曼公式

它结合了动态规划(DP)和蒙特卡洛方法(MC)的优点

基于表格的TD算法总结:

TD算法只是相当于做策略评估,不负责policy improvement

cpp 复制代码
实现SARSA和Qlearning算法
import numpy as np
from collections import defaultdict

class QLearning:
    def __init__(self, env, alpha=0.1, gamma=0.99, epsilon=0.1):
        self.env = env
        self.alpha = alpha  # 学习率
        self.gamma = gamma  # 折扣因子
        self.epsilon = epsilon  # 探索率
        
        # 初始化Q表
        self.Q = defaultdict(lambda: np.zeros(len(env.action_space))) #用于创建一个长度为 len(env.action_space) 的全零数组。
        
    def choose_action(self, state):
        if np.random.rand() < self.epsilon:
            # 随机选择动作索引
            action_idx = np.random.choice(len(self.env.action_space))
            return self.env.action_space[action_idx]  # 探索
        else:
            # 选择Q值最大的动作
            action_idx = np.argmax(self.Q[state])
            return self.env.action_space[action_idx]  # 利用
        
    def learn(self, state, action, reward, next_state, done):
        # 将状态转换为可哈希的键

        next_state_key = next_state
        current_q = self.Q[state][self.env.action_space.index(action)]
        max_next_q = np.max(self.Q[next_state_key])
        
        # Q-learning更新公式
        new_q = current_q + self.alpha * (reward + self.gamma * max_next_q - current_q)
        self.Q[state][self.env.action_space.index(action)] = new_q

class SARSA:
    def __init__(self, env, alpha=0.1, gamma=0.99, epsilon=0.1):
        self.env = env
        self.alpha = alpha  # 学习率
        self.gamma = gamma  # 折扣因子
        self.epsilon = epsilon  # 探索率
        
        # 初始化Q表
        self.Q = defaultdict(lambda: np.zeros(len(env.action_space)))
        
    def choose_action(self, state):
        
        if np.random.rand() < self.epsilon: #以概率 ϵ 随机选择动作
            # 随机选择动作索引
            action_idx = np.random.choice(len(self.env.action_space))
            return self.env.action_space[action_idx]  # 探索
        else:
            # 选择Q值最大的动作
            action_idx = np.argmax(self.Q[state])
            return self.env.action_space[action_idx]  # action_idx动作索引,返回具体动作(0, 1), (1, 0), (0, -1), (-1, 0), (0, 0)
            
    def learn(self, state, action, reward, next_state, next_action, done):
            
        next_state_key = next_state
        current_q = self.Q[state][self.env.action_space.index(action)]
        next_q = self.Q[next_state_key][self.env.action_space.index(next_action)]
         # SARSA更新公式
        new_q = current_q + self.alpha * (reward + self.gamma * next_q - current_q) #一步TD更新
        
        # 更新Q表
        self.Q[state][self.env.action_space.index(action)] = new_q

上述使用Q表每次记录下来Q值,下次(s,a)可以直接读取Q值

还有一种方法是用函数、神经网络计算Q值,输入(s,a)输出Q,然后梯度下降优化函数的参数,使得Q值计算更准确。

相关推荐
emmmmXxxy3 小时前
leetcode刷题-贪心03
算法·leetcode·职场和发展
Oracle_6664 小时前
C基础算法与实现
c语言·算法
h^hh6 小时前
向下调整算法(详解)c++
数据结构·c++·算法
能源革命6 小时前
负荷预测算法模型
算法·能源
海绵波波1076 小时前
350.两个数组的交集 ②
数据结构·算法·leetcode
Evand J7 小时前
基于改进的强跟踪技术的扩展Consider Kalman滤波算法在无人机导航系统中的应用研究
算法·无人机
王老师青少年编程7 小时前
gesp(C++六级)(4)洛谷:B3874:[GESP202309 六级] 小杨的握手问题
开发语言·c++·算法·gesp·csp·信奥赛
h^hh7 小时前
priority_queue的创建_结构体类型(重载小于运算符)c++
开发语言·数据结构·c++·算法
Melancholy 啊8 小时前
细说机器学习算法之ROC曲线用于模型评估
人工智能·python·算法·机器学习·数据挖掘