【FreeRL】我的深度学习库构建思想

文章目录


前言

代码实现在:https://github.com/wild-firefox/FreeRL

欢迎star

参考

目的是写出像TD3作者那样简单易懂的DRL代码,

由于参考了ElegentRL和Easy的库,from easy to elegent 故起名为freeRL,

free也是希望写出的代码可以随意的,自由的从此代码移植到自己的代码上。

python环境

python 复制代码
python 3.11.9
torch 2.3.1+cu121
gymnasium[all] 0.29.1
pygame 0.25.2 # 这个版本和gymnasium[all]0.29.1兼容

效果

在参数没有精细调整的情况下,在大多数的环境已能适用。

用DQN算法在LunarLander-v2环境下训练500个轮次的3个seed的效果:线为均值,阴影为方差

用 seed = 0 训练的模型评估,评估100个不同的seed的结果。

随机选择其中一个seed的结果,渲染环境。

已复现结果

1.DQN

2.DQN_Double

3.DQN_Dueling

4.DQN_PER

5.DQN_Noisy

6.DQN_N_Step

7.DQN_Categorical

8.DQN_Rainbow

其中:

1 实现在DQN_file/DQN.py

2-8 实现在DQN_file/DQN_with_tricks.py

综述

为了便于对算法的理解和改动,我将一个整体的算法训练和评估分离开来。

python 复制代码
DQN_file
├── learning_curves
│   ├── env_name_1
│	│   ├── DQN_3_seed.npy
│   │   └── DQN.png
│   └── env_name_2
├── results
│   ├── env_name_1
│	│	├── DQN_1
│	│	│	├── DQN_seed_0.npy
│	│	│	├── DQN.pt
│	│	│	├── evaluate.gif
│	│	│	├── evaluate.png
│	│	│	└── events.out.tfevents.
│	│	├── DQN_2
│	│	└── DQN_3
│   └── env_name_2
├── plot_learning_curves.py
├── evaluate.py
├── Buffer.py
└── DQN.py

首先看最下面几个具体的py文件

1.evaluate.py 实现评估。

2.plot_learning_curves.py实现多个seed的学习曲线的绘制和算法比较。

3.DQN.py 实现算法。

4.Buffer.py 实现经验池,经验池基本通用。

以DQN.py为算法.py举例

DQN.py(主要)

建议边打开github上DQN.py的代码边看。

算法实现

一个深度强化学习算法分三个部分实现:

1.Agent类:包括actor、critic、target_actor、target_critic、actor_optimizer、critic_optimizer、

2.DQN算法类:包括select_action,learn、save、load等方法,为具体的算法细节实现

3.main函数:实例化DQN类,主要参数的设置,训练、测试、保存模型等

这三个部分均在DQN.py里实现。

参数修改

参数修改 改三处:

1.MLP的hidden (此参数往往在第一部分开头实现)

2.main中args

3.dis_to_con中的离散转连续空间维度(针对无法转成连续域的算法,例:DQN)

对于1.需要单独修改的理由

hidden的层数和个数容易变化,且RL的许多的算法创新实现在MLP(Qnet,Actor,Critic处)会有新增参数。

对于2.

args 为主要的参数,算法独有或共有或保存位置的修改。

对于3.

主要针对DQN只能对离散环境适用,不能对连续环境适用,进行的转换。

将动作分配成多维的离散动作,使得算法可以适用,相对的,在采样环境时,需要将离散的动作转换成连续的动作。

基本的参数没有精细调整,这里DQN使用离散环境MountainCar-v0为基准来调整参数,以此能收敛为目标了,后发现此参数可以适用大多数其他环境,但不是全部。

使用MountainCar-v0的理由:环境的目标是到达最高的山峰,但环境中还有个次高的山峰,个人认为可以很好拟合出梯度中的次优解。

细节实现

1.对于不同的算法的实现,在代码中给出论文链接和不同实现。

2.在RL中使用常用的,通用的pytorch代码,易懂。见:【深度强化学习】常常使用的pytorch代码

3.区分env的terminated,truncated

4.区分训练时用的action(例:(-1,1))和env能接受的action_(例:(-3,3))

(区分3,4两点对于收敛有很大帮助。)

5.区分环境采样过程和训练过程,以提高算法的拓展性。

6.以max_episodes为终止条件,但是训练以step为最小单位。

显示训练,保存训练

1.训练时,使用tensorboard来显示实时的学习曲率。

在DQN_file(算法)文件夹下,D:FreeRL/DQN_file 终端里输入:
tensorboard --logdir=results/env_name

在跳出的http://localhost:6008/ 按住ctrl点击进入就行。

tensorboard保存的文件events.out.tfevents.和模型的位置一致。

保存模型的频率设置为总回合的1/4。

2.在results文件夹下,不同环境为文件夹名下,在算法(或算法+trick)为文件夹名里,(results/env_name/DQN_1)保存模型文件(DQN.pt)及其训练时每个episode的return值,以不同seed为区分(DQN_seed_0.npy)(此npy用于后续画学习曲率)

每进行一次训练文件夹后面的数DQN_n,n+1。

Buffer.py

在创建buffer时直接使用zeros来创建,比使用deque来创建在最后使用python基本数据再转成numpy再转成tensor速度要快。

这里使用numpy实现来使它更快一点。(参考elegentrl)

其他一些buffer的实现,都实现在此。

evaluate.py

实现对模型的评估,可设定评估的轮次数,设定是否保存渲染环境gif。

这里seed的设定值须与训练的seed值不同。

由于gymnasium可以设定env的seed。这里将环境的seed值设定为当前遍历的轮次,以实现seed的改变。

在gymnasium中,如果有实现任务所达到的return值,在画评估图时,以此为基线。

环境gif的保存,则是随机挑选其中一个回合进行保存。

此代码所得到的evaluate.png,evaluate.gif均保存在模型所在位置。(results/env/DQN_1/下)

(上述效果的最后两个图)

learning_curves

1.将不同的results/env/algorithm_trick_n下的DQN_seed_n.npy绘制成一个学习曲线

以均值为线,阴影为方差。

2.将比较的多个seed的episode_return 另保存为DQN_3_seed.npy方便后续比较。

3.可以选择是否比较此算法的其他trick算法。

可以设置seed_num大小,取决于你在环境的测试中,实验了几次不同的seed大小,这里仅使用seed =

0,10,100来进行绘制,当然也可以只进行一个seed的绘制。(这里有进行平滑处理,可以设置)

生成的学习曲线图为DQN.py 和保存的DQN_3_seed.npy保存在learning_curves/env/下

(上述效果 的第一张图为学习曲线图,已复现的结果为比较图)

相关推荐
点云SLAM7 分钟前
CVPR 2024 人脸方向总汇(人脸识别、头像重建、人脸合成和3D头像等)
深度学习·计算机视觉·人脸识别·3d人脸·头像重建
涛涛讲AI18 分钟前
扣子平台音频功能:让声音也能“智能”起来
人工智能·音视频·工作流·智能体·ai智能体·ai应用
霍格沃兹测试开发学社测试人社区20 分钟前
人工智能在音频、视觉、多模态领域的应用
软件测试·人工智能·测试开发·自动化·音视频
herosunly40 分钟前
2024:人工智能大模型的璀璨年代
人工智能·大模型·年度总结·博客之星
PaLu-LI1 小时前
ORB-SLAM2源码学习:Initializer.cc(13): Initializer::ReconstructF用F矩阵恢复R,t及三维点
c++·人工智能·学习·线性代数·ubuntu·计算机视觉·矩阵
呆呆珝1 小时前
RKNN_C++版本-YOLOV5
c++·人工智能·嵌入式硬件·yolo
笔触狂放1 小时前
第一章 语音识别概述
人工智能·python·机器学习·语音识别
ZzYH221 小时前
文献阅读 250125-Accurate predictions on small data with a tabular foundation model
人工智能·笔记·深度学习·机器学习
格林威1 小时前
BroadCom-RDMA博通网卡如何进行驱动安装和设置使得对应网口具有RDMA功能以适配RDMA相机
人工智能·数码相机·opencv·计算机视觉·c#
FL16238631291 小时前
汽车表面划痕刮伤检测数据集VOC+YOLO格式1221张1类别
深度学习·yolo·汽车