24/8/17算法笔记 CQL算法离线学习

离线学习:不需要更新数据

CQL(Conservative Q-Learning)算法是一种用于离线强化学习的方法,它通过学习一个保守的Q函数来解决标准离线RL方法可能由于数据集和学习到的策略之间的分布偏移而导致的过高估计问题 。CQL算法的核心思想是在Q值的基础上增加一个正则化项(regularizer),从而得到真实动作值函数的下界估计。这种方法在理论上被证明可以产生当前策略的真实值下界,并且可以进行策略评估和策略提升的过程 。

CQL算法通过修改值函数的备份方式,添加正则化项来实现保守性。在优化过程中,CQL旨在找到一个Q函数,该函数在给定策略下的期望值低于其真实值。这通过在Q学习的目标函数中添加一个惩罚项来实现,该惩罚项限制了策略π下Q函数的期望值不能偏离数据分布Q函数的期望值 。

CQL算法的实现相对简单,只需要在现有的深度Q学习和行动者-评论家实现的基础上添加少量代码。在实验中,CQL在多个领域和数据集上的表现优于现有的离线强化学习方法,尤其是在学习复杂和多模态数据分布时,通常可以使学习策略获得2到5倍的最终回报 。

此外,CQL算法的一个关键优势是它提供了一种有效的解决方案,可以在不与环境进行额外交互的情况下,利用先前收集的静态数据集学习有效的策略。这使得CQL在自动驾驶和医疗机器人等领域具有潜在的应用价值,这些领域中与环境的交互次数在成本和风险方面都是有限的 。

总的来说,CQL算法通过其保守的Q函数估计和正则化策略,为离线强化学习领域提供了一种有效的策略学习框架,并在理论和实践上都显示出了其有效性

import gym
from matplotlib import pyplot as plt
import numpy as np
import random
%matplotlib inline
#创建环境
env = gym.make('Pendulum-v1')
env.reset()

#打印游戏
def show():
    plt.imshow(env.render(mode='rgb_array'))
    plt.show()

定义sac模型,代码略http://t.csdnimg.cn/ic2HX

定义teacher模型

#定义teacher模型
teacher = SAC()

teacher.train(
    torch.tandn(5,3),
    torch.randn(5,1),
    torch.randn(5,1),
    torch.randn(5,3),
    torch.zeros(5,1).long(),
)

定义Data类

#样本池
datas = []

#向样本池中添加N条数据,删除M条最古老的数据
def update_data():
    #初始化游戏
    state = env.reset()
    
    #玩到游戏结束为止
    over = False
    while not over:
        #根据当前状态得到一个动作
        action = get_action(state)
        
        #执行当作,得到反馈
        next_state,reward,over, _ = env.step([action])
        
        #记录数据样本
        datas.append((states,action,reward,next_state,over))
        
        #更新游戏状态,开始下一个当作
        state = next_state
    #数据上限,超出时从最古老的开始删除
    while len(datas)>10000:
        datas.pop(0)
        
#获取一批数据样本
def get_sample():
    samples = random.sample(datas,64)
    #[b,4]
    state = torch.FloatTensor([i[0]for i in samples]).reshape(-1,3)
    #[b,1]
    action = torch.LongTensor([i[1]for i in samples]).reshape(-1,1)
    #[b,1]
    reward = torch.FloatTensor([i[2]for i in samples]).reshape(-1,1)
    #[b,4]
    next_state = torch.FloatTensor([i[3]for i in samples]).reshape(-1,3)
    #[b,1]
    over = torch.LongTensor([i[4]for i in samples]).reshape(-1,1)

    return state,action,reward,next_state,over

state,action,reward,next_state,over=get_sample()

state[:5],action[:5],reward[:5],next_state[:5],over[:5]

data = Data()
data.update_data(teacher),data.get_sample()

训练teacher模型

#训练teacher模型
for epoch in range(100):
    #更新N条数据
    datat.update_data(teacher)
    
    #每次更新过数据后,学习N次
    for i in range(200):
        teacher.train(*data.get_sample())
        
    if epoch%10==0:
        test_result = sum([teacher.test(play=False)for _ in range(10)])/10
        print(epoch,test_result)

定义CQL模型

class CQL(SAC):
    def __init__(self):
        super().__init__()
    def _get_loss_value(self,model_value,target,state,action,next_state):
        #计算value
        value = model_value(state,action)
        
        #计算loss,value的目标是要贴近target
        loss_value = self.loss_fn(value,tarfet)
        """以上与SAC相同,以下是CQL部分"""
        
        #把state复制5彼遍
        state = state.unsqueeze(dim=1)
        state = state.repeat(1,5,1).reshape(-1,3)
        #把next_state复制5遍
        next_state = next_state.unsqueeze(1)
        next_state = next_state.repeat(1,5,1).reshape(-1,3)
        
        #随机一批动作,数量是数据量的5倍,值域在-1到1之间
        rand_action = torch.empty([len(state),1]).uniform_(-1,1)
        
        #计算state的动作和熵
        curr_action,next_entropy = self..mdoel_action(next_state)
        #计算三方动作的value
        value_rand = model_value(state,rand_action).reshape(-1,5,1)
        value_curr = model_value(state,curr_action).reshape(-1,5,1)
        value_next = model_value(state,next_action).reshape(-1,5,1)

        curr_entropy = curr_entropy.detach().reshape(-1,5,1)
        next_entropy = next_entropy.detach().reshape(-1,5,1)
        
        #三份value分别减去他们的熵
        value_rand -=mat.log(0.5)
        value_curr -=curr_entropy
        value_next -=next_entropy
        
        #拼合三份value
        value_cat = torch.cat([value_rand,value_curr,value_next],dim=1)
        
        #等价t.logsumexp(dim=1),t.exp().sum(dim=1).log()
        loss_cat = torch.logsumexp(value_cat,dim =1).mean()
        
        #在原本的loss上增加上这一部分
        loss_value += 5.0*(loss_cat - value.mean())
        """差异到此为止"""

学生模型

student = CQL()
student.train(
    torch.randn(5,3),
    torch.randn(5,1),
    torch.randn(5,1),
    torch.randn(5,3),
    torch.zeros(5,1)long(),
)

离线训练,训练过程中完全不更新数据

#训练N次,训练过程中不需要更新数据
for i in range(50000):
    #采样一批数据
    student.train(*data.get_sample())
    
    if i%2000 ==0:
        test_result = sum([student.test(play = False) for _ in range(10)])
        print(i,test_result)
相关推荐
Power20246663 分钟前
NLP论文速读|LongReward:基于AI反馈来提升长上下文大语言模型
人工智能·深度学习·机器学习·自然语言处理·nlp
数据猎手小k6 分钟前
AIDOVECL数据集:包含超过15000张AI生成的车辆图像数据集,目的解决旨在解决眼水平分类和定位问题。
人工智能·分类·数据挖掘
好奇龙猫11 分钟前
【学习AI-相关路程-mnist手写数字分类-win-硬件:windows-自我学习AI-实验步骤-全连接神经网络(BPnetwork)-操作流程(3) 】
人工智能·算法
沉下心来学鲁班26 分钟前
复现LLM:带你从零认识语言模型
人工智能·语言模型
数据猎手小k26 分钟前
AndroidLab:一个系统化的Android代理框架,包含操作环境和可复现的基准测试,支持大型语言模型和多模态模型。
android·人工智能·机器学习·语言模型
YRr YRr35 分钟前
深度学习:循环神经网络(RNN)详解
人工智能·rnn·深度学习
sp_fyf_20241 小时前
计算机前沿技术-人工智能算法-大语言模型-最新研究进展-2024-11-01
人工智能·深度学习·神经网络·算法·机器学习·语言模型·数据挖掘
多吃轻食1 小时前
大模型微调技术 --> 脉络
人工智能·深度学习·神经网络·自然语言处理·embedding
香菜大丸1 小时前
链表的归并排序
数据结构·算法·链表
jrrz08281 小时前
LeetCode 热题100(七)【链表】(1)
数据结构·c++·算法·leetcode·链表