【强化学习】关于PPO收敛问题

文章目录


前言

写了好几个版本的PPO,很容易出现的问题。


1.注意两处维度须一致

  • 1.注意两处维度须一致

    markdown 复制代码
    1.surr1 = ratio * adv
    2.returns = adv + value
    又由于ratio = torch.exp(new_log_probs - old_log_probs)得到
    所以须保证log_prob 与 adv的维度一致;adv与value的维度一致
    即:log_prob与adv与value的维度一致
    
    而elegentRL 有此代码,推荐其代码库
    即: assert logprobs.shape == advantages.shape == reward_sums.shape == (buffer_size,)
     
     不然会出现 bx1 + b = bxb 的结构 ,这样就不对了 

2.梯度问题

  • 2.梯度问题

    即:其实就是求导问题:一般来说,目标值不需要梯度 例子,目标值为5, x-5求导为1 ,否则 对x - y 求导 就会报错。

    这里就是old_log,adv, return 不需要梯度。

    与adv 计算相关的value值 也可以有梯度,只要最终adv没有梯度就行。

    特别的torch.tensor(advantage_list, dtype=torch.float) (--- hand-on RL) 这种也是去掉梯度的方式,所以一般buffer里的数据都是无梯度的。

3.提高收敛因素

  • 3.提高收敛因素

    1.horizon 即每次更新的步数设置为2048时极大提高收敛效率

    2.state 归一到 -1~1

    3.action只有在环境需要限制大小的值时才要clip,在训练时的action 不要clip 保证与训练时计算的分布相同

其他

  • 其他

    torch.cat([values, last_value.unsqueeze(0)]) 不会起移除梯度,加梯度作用

    主要作用是 with torch.no_grad():

    markdown 复制代码
    with torch.no_grad():
         values = self.value(batch_states).squeeze(-1)
    
    print(values.requires_grad,1)
    values = torch.cat([values, last_value.unsqueeze(0)]) 
    print(values.requires_grad)
    
    False 1
    True
    markdown 复制代码
    with torch.no_grad():
         values = self.value(batch_states).squeeze(-1)
    
         print(values.requires_grad,1)
         values = torch.cat([values, last_value.unsqueeze(0)]) 
         print(values.requires_grad)
    
    False 1
    False
相关推荐
江上鹤.1481 小时前
Day 28 复习日
人工智能·python·机器学习
Hello.Reader1 小时前
从 0 到 1 跑通第一个 Flink ML 示例
大数据·python·flink
nwsuaf_huasir1 小时前
Elsevier投稿系统编译latex文件参考文献显示为问号
深度学习
DFT计算杂谈1 小时前
免注册下载各个版本Anaconda3/Miniconda3
python
oliveray1 小时前
动手搭建Flamingo(VQA)
人工智能·深度学习·vlms
进阶的小蜉蝣1 小时前
[Machine Learning] 机器学习中的Collate
人工智能·机器学习
虹科网络安全1 小时前
艾体宝干货 | Redis Python 开发系列#6 缓存、分布式锁与队列架构
redis·python·缓存
猎人everest1 小时前
Django Rest Framework (DRF) 核心知识体系梳理与深度讲解
后端·python·django
非著名架构师1 小时前
气象驱动的需求预测:零售企业如何通过气候数据分析实现库存精准控制
人工智能·深度学习·数据分析·transformer·风光功率预测·高精度天气预报数据