强化学习原理python篇06——DQN

强化学习原理python篇05------DQN

本章全篇参考赵世钰老师的教材 Mathmatical-Foundation-of-Reinforcement-Learning Deep Q-learning 章节,请各位结合阅读,本合集只专注于数学概念的代码实现。

DQN 算法

1)使用随机权重 ( w ← 1.0 ) (w←1.0) (w←1.0)初始化目标网络 Q ( s , a , w ) Q(s, a, w) Q(s,a,w)和网络 Q ^ ( s , a , w ) \hat Q(s, a, w) Q^(s,a,w), Q Q Q和 Q ^ \hat Q Q^相同,清空回放缓冲区。

2)以概率ε选择一个随机动作a,否则 a = a r g m a x Q ( s , a , w ) a=argmaxQ(s,a,w) a=argmaxQ(s,a,w)。

3)在模拟器中执行动作a,观察奖励r和下一个状态s'。

4)将转移过程(s, a, r, s')存储在回放缓冲区中。

5)从回放缓冲区中采样一个随机的小批量转移过程。

6)对于回放缓冲区中的每个转移过程,如果片段在此步结束,则计算目标 y = r y=r y=r,否则计算 y = r + γ m a x Q ^ ( s , a , w ) y=r+\gamma max \hat Q(s, a, w) y=r+γmaxQ^(s,a,w) 。

7)计算损失: L = ( Q ( s , a , w ) -- y ) 2 L=(Q(s, a, w)--y)^2 L=(Q(s,a,w)--y)2。

8)固定网络 Q ^ ( s , a , w ) \hat Q(s, a, w) Q^(s,a,w)不变,通过最小化模型参数的损失,使用SGD算法更新 Q ( s , a ) Q(s, a) Q(s,a)。

9)每N步,将权重从目标网络 Q Q Q复制到 Q ^ ( s , a , w ) \hat Q(s, a, w) Q^(s,a,w) 。

10)从步骤2开始重复,直到收敛为止。

定义DQN网络

python 复制代码
import collections
import copy
import random
from collections import defaultdict
import math
import gym
import gym.spaces
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from gym.envs.toy_text import frozen_lake
from torch.utils.tensorboard import SummaryWriter

class Net(nn.Module):
    def __init__(self, obs_size, hidden_size, q_table_size):
        super(Net, self).__init__()

        self.net = nn.Sequential(
            # 输入为状态,样本为(1*n)
            nn.Linear(obs_size, hidden_size),
            nn.ReLU(),
            # nn.Linear(hidden_size, hidden_size),
            # nn.ReLU(),
            nn.Linear(hidden_size, q_table_size),
        )

    def forward(self, state):
        return self.net(state)


class DQN:
    def __init__(self, env, tgt_net, net):
        self.env = env
        self.tgt_net = tgt_net
        self.net = net

    def generate_train_data(self, batch_size, epsilon):

        state, _ = env.reset()
        train_data = []
        while len(train_data)<batch_size*2:
            q_table_tgt = self.tgt_net(torch.Tensor(state)).detach()
            if np.random.uniform(0, 1, 1) > epsilon:
                action = self.env.action_space.sample()
            else:
                action = int(torch.argmax(q_table_tgt))
            new_state, reward,terminated, truncted, info = env.step(action)
            train_data.append([state, action, reward, new_state, terminated])
            state = new_state
            if terminated:
                state, _ = env.reset()
                continue
        random.shuffle(train_data)                
        return train_data[:batch_size]

    def calculate_y_hat_and_y(self, batch):
        # 6)对于回放缓冲区中的每个转移过程,如果片段在此步结束,则计算目标$y=r$,否则计算$y=r+\gamma max \hat Q(s, a, w)$ 。
        y = []
        state_space = []
        action_space = []
        for state, action, reward, new_state, terminated in batch:
            # y值
            if terminated:
                y.append(reward)
            else:
                # 下一步的 qtable 的最大值
                q_table_net = self.net(torch.Tensor(np.array([new_state]))).detach()
                y.append(reward + gamma * float(torch.max(q_table_net)))
            # y hat的值
            state_space.append(state)
            action_space.append(action)
        idx = [list(range(len(action_space))), action_space]
        y_hat = self.tgt_net(torch.Tensor(np.array(state_space)))[idx]
        return y_hat, torch.tensor(y)

    def update_net_parameters(self, update=True):
        self.net.load_state_dict(self.tgt_net.state_dict())
      

初始化环境

python 复制代码
   # 初始化环境
env = gym.make("CartPole-v1")
# env = DiscreteOneHotWrapper(env)

hidden_num = 64
# 定义网络
net = Net(env.observation_space.shape[0],hidden_num, env.action_space.n)
tgt_net = Net(env.observation_space.shape[0],hidden_num, env.action_space.n)
dqn = DQN(env=env, net=net, tgt_net=tgt_net)

# 初始化参数
# dqn.init_net_and_target_net_weight()

# 定义优化器
opt = optim.Adam(tgt_net.parameters(), lr=0.001)


# 定义损失函数
loss = nn.MSELoss()

# 记录训练过程
# writer = SummaryWriter(log_dir="logs/DQN", comment="DQN")

开始训练

python 复制代码
gamma = 0.8
for i in range(10000):
    batch = dqn.generate_train_data(256, 0.8)
    y_hat, y = dqn.calculate_y_hat_and_y(batch)
    opt.zero_grad()
    l = loss(y_hat, y)
    l.backward()
    opt.step()

    print("MSE: {}".format(l.item()))
    if i % 5 == 0:
        dqn.update_net_parameters(update=True)

输出:

复制代码
MSE: 0.027348674833774567
MSE: 0.1803671419620514
MSE: 0.06523636728525162
MSE: 0.08363766968250275
MSE: 0.062360599637031555
MSE: 0.004909628536552191
MSE: 0.05730309337377548
MSE: 0.03543371334671974
MSE: 0.08458714932203293

可视化结果

python 复制代码
env = gym.make("CartPole-v1", render_mode = "human")
env = gym.wrappers.RecordVideo(env, video_folder="video")

state, info = env.reset()
total_rewards = 0

while True:
    q_table_state = dqn.tgt_net(torch.Tensor(state)).detach()
    # if np.random.uniform(0, 1, 1) > 0.9:
    #     action = env.action_space.sample()
    # else:
    action = int(torch.argmax(q_table_state))
    state, reward, terminated, truncted, info = env.step(action)
    if terminated:
        break
          
相关推荐
艾莉丝努力练剑22 分钟前
【C语言】学习过程教训与经验杂谈:思想准备、知识回顾(三)
c语言·开发语言·数据结构·学习·算法
Chasing__Dreams29 分钟前
python--杂识--18.1--pandas数据插入sqlite并进行查询
python·sqlite·pandas
彭泽布衣1 小时前
python2.7/lib-dynload/_ssl.so: undefined symbol: sk_pop_free
python·sk_pop_free
witton2 小时前
Go语言网络游戏服务器模块化编程
服务器·开发语言·游戏·golang·origin·模块化·耦合
喜欢吃豆2 小时前
从零构建MCP服务器:FastMCP实战指南
运维·服务器·人工智能·python·大模型·mcp
一个处女座的测试2 小时前
Python语言+pytest框架+allure报告+log日志+yaml文件+mysql断言实现接口自动化框架
python·mysql·pytest
枯萎穿心攻击2 小时前
ECS由浅入深第三节:进阶?System 的行为与复杂交互模式
开发语言·unity·c#·游戏引擎
Jerry Lau3 小时前
go go go 出发咯 - go web开发入门系列(一) helloworld
开发语言·前端·golang
nananaij3 小时前
【Python基础入门 re模块实现正则表达式操作】
开发语言·python·正则表达式
Micro麦可乐3 小时前
Java常用加密算法详解与实战代码 - 附可直接运行的测试示例
java·开发语言·加密算法·aes加解密·rsa加解密·hash算法