Running Noisy DQN on CartPole-v1

Setup: gym 0.26.1, CartPole-v1, NoisyNet DQN

NoisyNet simply replaces the weights and biases of an ordinary Linear layer with mu + sigma * epsilon. It is a very simple change, yet it can noticeably improve DQN's performance.
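To make the idea concrete, here is a minimal sketch of the noisy parameterization for a single layer. The tensor names (mu_w, sigma_w, etc.) are illustrative only; the noise scheme matches the independent-Gaussian NoisyLinear used later in the code.

```python
import torch

in_features, out_features = 4, 2                          # illustrative sizes
mu_w = torch.randn(out_features, in_features) * 0.1       # learnable weight mean
sigma_w = torch.full((out_features, in_features), 0.017)  # learnable weight noise scale
mu_b = torch.zeros(out_features)                          # learnable bias mean
sigma_b = torch.full((out_features,), 0.017)              # learnable bias noise scale

x = torch.randn(1, in_features)
eps_w = torch.randn_like(mu_w)                            # fresh Gaussian noise each forward pass
eps_b = torch.randn_like(mu_b)

# w = mu + sigma * epsilon, b = mu + sigma * epsilon
y = x @ (mu_w + sigma_w * eps_w).t() + (mu_b + sigma_b * eps_b)
print(y.shape)  # torch.Size([1, 2])
```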

Compared with the vanilla DQN from before, only two things change: the Linear layers become NoisyLinear, and in the agent's take_action the ε-greedy policy is replaced by taking the argmax directly. See the code below for details.

The implementation in this post follows Wang Shusen's book Deep Reinforcement Learning.

To quote the book: Noisy DQN is inherently stochastic, which encourages exploration and plays the same role as the ε-greedy policy; using a_t = argmax_a Q(s, a, ε; μ, σ) directly as the behavior policy (here ε denotes the sampled noise, not an exploration rate) works better than ε-greedy.
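For comparison, here is a hedged sketch of the two action-selection rules; q_net, state, epsilon and action_dim are placeholders, not names from the code below.

```python
import random

def epsilon_greedy_action(q_net, state, epsilon, action_dim):
    # classic DQN exploration: random action with probability epsilon
    if random.random() < epsilon:
        return random.randrange(action_dim)
    return q_net(state).argmax().item()

def noisy_argmax_action(q_net, state):
    # Noisy DQN: the network's own parameter noise provides the exploration,
    # so the behavior policy is just a greedy argmax over the noisy Q values
    return q_net(state).argmax().item()
```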

```python
import gym
import torch
from torch import nn
from torch.nn import functional as F
import numpy as np
import random
import collections
from tqdm import tqdm
import matplotlib.pyplot as plt
from d2l import torch as d2l
import rl_utils
import math

class ReplayBuffer:
    """Experience replay buffer"""
    def __init__(self, capacity):
        self.buffer = collections.deque(maxlen=capacity) # FIFO queue
    
    def add(self, state, action, reward, next_state, done): # add a transition to the buffer
        self.buffer.append((state, action, reward, next_state, done))
    
    def sample(self, batch_size): # sample batch_size transitions from the buffer
        transitions = random.sample(self.buffer, batch_size)
        state, action, reward, next_state, done = zip(*transitions)
        return np.array(state), action, reward, np.array(next_state), done
    
    def size(self): # current number of transitions in the buffer
        return len(self.buffer)

class NoisyLinear(nn.Linear):
    """Linear layer whose weights and biases are mu + sigma * epsilon, with
    independent Gaussian noise; mu and sigma are learned, epsilon is resampled."""
    def __init__(self, in_features, out_features, sigma_init=0.017, bias=True):
        super().__init__(in_features, out_features, bias)
        # sigma is a trainable parameter; epsilon lives in buffers and is not trained
        self.sigma_weight = nn.Parameter(torch.full((out_features, in_features), sigma_init))
        self.register_buffer("epsilon_weight", torch.zeros(out_features, in_features))
        if bias:
            self.sigma_bias = nn.Parameter(torch.full((out_features,), sigma_init))
            self.register_buffer("epsilon_bias", torch.zeros(out_features))
        self.reset_parameters()

    def reset_parameters(self):
        # uniform initialization of the mean parameters, following the NoisyNet paper
        std = math.sqrt(3 / self.in_features)
        self.weight.data.uniform_(-std, std)
        if self.bias is not None:
            self.bias.data.uniform_(-std, std)
        
    def forward(self, x, is_training=True):
        if not is_training:
            # evaluation: use only the mean parameters, no noise
            return F.linear(x, self.weight, self.bias)
        # acting / training: draw fresh noise on every forward pass
        self.epsilon_weight.normal_()
        bias = self.bias
        if bias is not None:
            self.epsilon_bias.normal_()
            bias = bias + self.sigma_bias * self.epsilon_bias
        return F.linear(x, self.weight + self.sigma_weight * self.epsilon_weight, bias)
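
# Quick sanity check (illustrative, not part of the original script): two
# forward passes through the same NoisyLinear layer give different outputs,
# because fresh epsilon noise is drawn each time, while is_training=False
# skips the noise and makes the output deterministic.
# layer = NoisyLinear(4, 2); x = torch.randn(1, 4)
# print(layer(x), layer(x))                  # outputs differ
# print(layer(x, is_training=False))         # noise-free, repeatable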

class Q(nn.Module):
    def __init__(self, state_dim, hidden_dim, action_dim):
        super().__init__()
        self.fc1 = NoisyLinear(state_dim, hidden_dim)
        self.fc2 = NoisyLinear(hidden_dim, action_dim)
    def forward(self, x, is_training=True):
        x = F.relu(self.fc1(x, is_training)) # ReLU activation after the hidden layer
        return self.fc2(x, is_training)

class DQN:
    """DQN algorithm with noisy Q-networks"""
    def __init__(self, state_dim, hidden_dim, action_dim, lr, gamma, target_update, device):
        self.action_dim = action_dim
        self.q = Q(state_dim, hidden_dim, action_dim).to(device) # Q network
        self.target_q = Q(state_dim, hidden_dim, action_dim).to(device) # target network
        self.target_q.load_state_dict(self.q.state_dict())  # copy the initial parameters
        self.optimizer = torch.optim.Adam(self.q.parameters(), lr=lr)
        self.gamma = gamma
        self.target_update = target_update # target-network update frequency
        self.count = 0 # counter of update steps
        self.device = device
    
    def take_action(self, state): # no ε-greedy here: the parameter noise already drives exploration
        state = torch.tensor(np.array([state]), dtype=torch.float).to(self.device)
        action = self.q(state).argmax().item()
        return action
    
    def update(self, transition_dict):
        states = torch.tensor(transition_dict['states'], dtype=torch.float).to(self.device)
        actions = torch.tensor(transition_dict['actions']).reshape(-1,1).to(self.device)
        rewards = torch.tensor(transition_dict['rewards'], dtype=torch.float).reshape(-1,1).to(self.device)
        next_states = torch.tensor(transition_dict['next_states'], dtype=torch.float).to(self.device)
        dones = torch.tensor(transition_dict['dones'], dtype=torch.float).reshape(-1,1).to(self.device)
        
        q_values = self.q(states).gather(1, actions) # Q values of the taken actions
        # maximum Q value of the next state, from the target network
        max_next_q_values = self.target_q(next_states).max(1)[0].reshape(-1,1)
        q_targets = rewards + self.gamma * max_next_q_values * (1 - dones) # TD target
        loss = F.mse_loss(q_values, q_targets) # mean squared TD error
        self.optimizer.zero_grad() # clear gradients, which accumulate by default
        loss.backward() # backpropagation
        self.optimizer.step() # update the parameters
        
        if self.count % self.target_update == 0:
            self.target_q.load_state_dict(self.q.state_dict())
        self.count += 1
```

```python
lr = 2e-3
num_episodes = 500
hidden_dim = 128
gamma = 0.98
target_update = 10
buffer_size = 10000
minimal_size = 500
batch_size = 64
device = d2l.try_gpu()
print(device)

env_name = "CartPole-v1"
env = gym.make(env_name)
random.seed(0)
np.random.seed(0)
torch.manual_seed(0)
replay_buffer = ReplayBuffer(buffer_size)
state_dim = env.observation_space.shape[0]
action_dim = env.action_space.n
agent = DQN(state_dim, hidden_dim, action_dim, lr, gamma, target_update, device)
return_list = []

for i in range(10):
    with tqdm(total=int(num_episodes/10), desc=f'Iteration {i}') as pbar:
        for i_episode in range(int(num_episodes/10)):
            episode_return = 0
            state = env.reset()[0]
            done, truncated= False, False
            while not done and not truncated :
                action = agent.take_action(state)
                next_state, reward, done, truncated, info = env.step(action)
                replay_buffer.add(state, action, reward, next_state, done)
                state = next_state
                episode_return += reward
                # start training the Q network only after the buffer holds enough data
                if replay_buffer.size() > minimal_size:
                    b_s, b_a, b_r, b_ns, b_d = replay_buffer.sample(batch_size)
                    transition_dict = {'states': b_s, 'actions': b_a, 'next_states': b_ns, 'rewards': b_r, 'dones': b_d}
                    agent.update(transition_dict)
            return_list.append(episode_return)
            if (i_episode+1) % 10 == 0:
                pbar.set_postfix({'episode': '%d' % (num_episodes / 10 * i + i_episode+1), 
                                  'return': '%.3f' % np.mean(return_list[-10:])})
            pbar.update(1)
            
episodes_list = list(range(len(return_list)))
plt.plot(episodes_list, return_list)
plt.xlabel('Episodes')
plt.ylabel('Returns')
plt.title(f'Noisy DQN on {env_name}')
plt.show()

mv_return = rl_utils.moving_average(return_list, 9)
plt.plot(episodes_list, mv_return)
plt.xlabel('Episodes')
plt.ylabel('Returns')
plt.title(f'Noisy DQN on {env_name}')
plt.show()
```

This time I ran the Jupyter file in PyCharm; the results are shown below:


Compared with the earlier plain DQN (see the previous post for details), the performance is noticeably better.
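As a possible follow-up (not part of the original script), the trained agent can also be evaluated with the parameter noise switched off by passing is_training=False through the Q network. A minimal sketch, reusing the agent and env_name defined above:

```python
def evaluate(agent, env, n_episodes=10):
    """Roll out the greedy policy with the parameter noise disabled."""
    returns = []
    for _ in range(n_episodes):
        state = env.reset()[0]
        done, truncated, total = False, False, 0.0
        while not done and not truncated:
            with torch.no_grad():
                s = torch.tensor(np.array([state]), dtype=torch.float).to(agent.device)
                action = agent.q(s, is_training=False).argmax().item()
            state, reward, done, truncated, _ = env.step(action)
            total += reward
        returns.append(total)
    return np.mean(returns)

# print(evaluate(agent, gym.make(env_name)))
```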
