24/8/17算法笔记 CQL算法离线学习

离线学习:不需要更新数据

CQL(Conservative Q-Learning)算法是一种用于离线强化学习的方法,它通过学习一个保守的Q函数来解决标准离线RL方法可能由于数据集和学习到的策略之间的分布偏移而导致的过高估计问题 。CQL算法的核心思想是在Q值的基础上增加一个正则化项(regularizer),从而得到真实动作值函数的下界估计。这种方法在理论上被证明可以产生当前策略的真实值下界,并且可以进行策略评估和策略提升的过程 。

CQL算法通过修改值函数的备份方式,添加正则化项来实现保守性。在优化过程中,CQL旨在找到一个Q函数,该函数在给定策略下的期望值低于其真实值。这通过在Q学习的目标函数中添加一个惩罚项来实现,该惩罚项限制了策略π下Q函数的期望值不能偏离数据分布Q函数的期望值 。

CQL算法的实现相对简单,只需要在现有的深度Q学习和行动者-评论家实现的基础上添加少量代码。在实验中,CQL在多个领域和数据集上的表现优于现有的离线强化学习方法,尤其是在学习复杂和多模态数据分布时,通常可以使学习策略获得2到5倍的最终回报 。

此外,CQL算法的一个关键优势是它提供了一种有效的解决方案,可以在不与环境进行额外交互的情况下,利用先前收集的静态数据集学习有效的策略。这使得CQL在自动驾驶和医疗机器人等领域具有潜在的应用价值,这些领域中与环境的交互次数在成本和风险方面都是有限的 。

总的来说,CQL算法通过其保守的Q函数估计和正则化策略,为离线强化学习领域提供了一种有效的策略学习框架,并在理论和实践上都显示出了其有效性

复制代码
import gym
from matplotlib import pyplot as plt
import numpy as np
import random
%matplotlib inline
#创建环境
env = gym.make('Pendulum-v1')
env.reset()

#打印游戏
def show():
    plt.imshow(env.render(mode='rgb_array'))
    plt.show()

定义sac模型,代码略http://t.csdnimg.cn/ic2HX

定义teacher模型

复制代码
#定义teacher模型
teacher = SAC()

teacher.train(
    torch.tandn(5,3),
    torch.randn(5,1),
    torch.randn(5,1),
    torch.randn(5,3),
    torch.zeros(5,1).long(),
)

定义Data类

复制代码
#样本池
datas = []

#向样本池中添加N条数据,删除M条最古老的数据
def update_data():
    #初始化游戏
    state = env.reset()
    
    #玩到游戏结束为止
    over = False
    while not over:
        #根据当前状态得到一个动作
        action = get_action(state)
        
        #执行当作,得到反馈
        next_state,reward,over, _ = env.step([action])
        
        #记录数据样本
        datas.append((states,action,reward,next_state,over))
        
        #更新游戏状态,开始下一个当作
        state = next_state
    #数据上限,超出时从最古老的开始删除
    while len(datas)>10000:
        datas.pop(0)
        
#获取一批数据样本
def get_sample():
    samples = random.sample(datas,64)
    #[b,4]
    state = torch.FloatTensor([i[0]for i in samples]).reshape(-1,3)
    #[b,1]
    action = torch.LongTensor([i[1]for i in samples]).reshape(-1,1)
    #[b,1]
    reward = torch.FloatTensor([i[2]for i in samples]).reshape(-1,1)
    #[b,4]
    next_state = torch.FloatTensor([i[3]for i in samples]).reshape(-1,3)
    #[b,1]
    over = torch.LongTensor([i[4]for i in samples]).reshape(-1,1)

    return state,action,reward,next_state,over

state,action,reward,next_state,over=get_sample()

state[:5],action[:5],reward[:5],next_state[:5],over[:5]

data = Data()
data.update_data(teacher),data.get_sample()

训练teacher模型

复制代码
#训练teacher模型
for epoch in range(100):
    #更新N条数据
    datat.update_data(teacher)
    
    #每次更新过数据后,学习N次
    for i in range(200):
        teacher.train(*data.get_sample())
        
    if epoch%10==0:
        test_result = sum([teacher.test(play=False)for _ in range(10)])/10
        print(epoch,test_result)

定义CQL模型

复制代码
class CQL(SAC):
    def __init__(self):
        super().__init__()
    def _get_loss_value(self,model_value,target,state,action,next_state):
        #计算value
        value = model_value(state,action)
        
        #计算loss,value的目标是要贴近target
        loss_value = self.loss_fn(value,tarfet)
        """以上与SAC相同,以下是CQL部分"""
        
        #把state复制5彼遍
        state = state.unsqueeze(dim=1)
        state = state.repeat(1,5,1).reshape(-1,3)
        #把next_state复制5遍
        next_state = next_state.unsqueeze(1)
        next_state = next_state.repeat(1,5,1).reshape(-1,3)
        
        #随机一批动作,数量是数据量的5倍,值域在-1到1之间
        rand_action = torch.empty([len(state),1]).uniform_(-1,1)
        
        #计算state的动作和熵
        curr_action,next_entropy = self..mdoel_action(next_state)
        #计算三方动作的value
        value_rand = model_value(state,rand_action).reshape(-1,5,1)
        value_curr = model_value(state,curr_action).reshape(-1,5,1)
        value_next = model_value(state,next_action).reshape(-1,5,1)

        curr_entropy = curr_entropy.detach().reshape(-1,5,1)
        next_entropy = next_entropy.detach().reshape(-1,5,1)
        
        #三份value分别减去他们的熵
        value_rand -=mat.log(0.5)
        value_curr -=curr_entropy
        value_next -=next_entropy
        
        #拼合三份value
        value_cat = torch.cat([value_rand,value_curr,value_next],dim=1)
        
        #等价t.logsumexp(dim=1),t.exp().sum(dim=1).log()
        loss_cat = torch.logsumexp(value_cat,dim =1).mean()
        
        #在原本的loss上增加上这一部分
        loss_value += 5.0*(loss_cat - value.mean())
        """差异到此为止"""

学生模型

复制代码
student = CQL()
student.train(
    torch.randn(5,3),
    torch.randn(5,1),
    torch.randn(5,1),
    torch.randn(5,3),
    torch.zeros(5,1)long(),
)

离线训练,训练过程中完全不更新数据

复制代码
#训练N次,训练过程中不需要更新数据
for i in range(50000):
    #采样一批数据
    student.train(*data.get_sample())
    
    if i%2000 ==0:
        test_result = sum([student.test(play = False) for _ in range(10)])
        print(i,test_result)
相关推荐
ManageEngineITSM3 分钟前
IT 服务自动化的时代:让效率与体验共进
运维·数据库·人工智能·自动化·itsm·工单系统
总有刁民想爱朕ha18 分钟前
AI大模型学习(17)python-flask AI大模型和图片处理工具的从一张图到多平台适配的简单方法
人工智能·python·学习·电商图片处理
浅川.2522 分钟前
xtuoj string
开发语言·c++·算法
302AI33 分钟前
体验升级而非颠覆,API成本直降75%:DeepSeek-V3.2-Exp评测
人工智能·llm·deepseek
韩非37 分钟前
if 语句对程序性能的影响
算法·架构
新智元37 分钟前
老黄押宝「美版 DeepSeek」!谷歌天才叛将创业,一夜吸金 20 亿美元
人工智能·openai
新智元39 分钟前
刚刚,全球首个 GB300 巨兽救场!一年烧光 70 亿,OpenAI 内斗 GPU 惨烈
人工智能·openai
用户9163574409540 分钟前
LeetCode热题100——15.三数之和
javascript·算法
Cathy Bryant43 分钟前
球极平面投影
经验分享·笔记·数学建模
小虎鲸001 小时前
PyTorch的安装与使用
人工智能·pytorch·python·深度学习