增强学习(Reinforcement Learning)简介

增强学习(Reinforcement Learning)简介

增强学习是机器学习的一种范式,其核心目标是让智能体(Agent)通过与环境的交互,基于试错机制和延迟奖励反馈,学习如何选择最优动作以最大化长期累积回报。其核心要素包括:

• 状态(State):描述环境的当前信息(如棋盘布局、机器人传感器数据)。

• 动作(Action):智能体在特定状态下可执行的操作(如移动、下棋)。

• 奖励(Reward):环境对动作的即时反馈信号(如得分增加或惩罚)。

• 策略(Policy):从状态到动作的映射规则(如基于Q值选择动作)。

• 价值函数(Value Function):预测某状态或动作的长期回报(如Q-Learning中的Q表)。

与监督学习不同,增强学习无需标注数据,而是通过探索-利用权衡(Exploration vs Exploitation)自主学习。


使用PyTorch实现深度Q网络(DQN)演示

以下以CartPole-v0(平衡杆环境)为例,展示完整代码及解释:

  1. 环境与依赖库
python 复制代码
import gym
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
from collections import deque
import random

# 初始化环境
env = gym.make('CartPole-v1')
state_dim = env.observation_space.shape[0]
action_dim = env.action_space.n
  1. 定义DQN网络
python 复制代码
class DQN(nn.Module):
    def __init__(self, state_dim, action_dim):
        super(DQN, self).__init__()
        self.fc = nn.Sequential(
            nn.Linear(state_dim, 128),
            nn.ReLU(),
            nn.Linear(128, 64),
            nn.ReLU(),
            nn.Linear(64, action_dim)
        )
    
    def forward(self, x):
        return self.fc(x)
  1. 经验回放缓冲区(Replay Buffer)
python 复制代码
class ReplayBuffer:
    def __init__(self, capacity):
        self.buffer = deque(maxlen=capacity)
    
    def push(self, state, action, reward, next_state, done):
        self.buffer.append((state, action, reward, next_state, done))
    
    def sample(self, batch_size):
        batch = random.sample(self.buffer, batch_size)
        states, actions, rewards, next_states, dones = zip(*batch)
        return (
            torch.FloatTensor(states),
            torch.LongTensor(actions),
            torch.FloatTensor(rewards),
            torch.FloatTensor(next_states),
            torch.FloatTensor(dones)
        )
  1. 训练参数与初始化
python 复制代码
# 超参数
batch_size = 64
gamma = 0.99        # 折扣因子
epsilon_start = 1.0
epsilon_decay = 0.995
epsilon_min = 0.01
target_update = 10  # 目标网络更新频率

# 初始化网络与优化器
policy_net = DQN(state_dim, action_dim)
target_net = DQN(state_dim, action_dim)
target_net.load_state_dict(policy_net.state_dict())
optimizer = optim.Adam(policy_net.parameters(), lr=0.001)
buffer = ReplayBuffer(10000)
epsilon = epsilon_start
  1. 训练循环
python 复制代码
num_episodes = 500
for episode in range(num_episodes):
    state = env.reset()
    total_reward = 0
    
    while True:
        # ε-贪婪策略选择动作
        if random.random() < epsilon:
            action = env.action_space.sample()
        else:
            with torch.no_grad():
                q_values = policy_net(torch.FloatTensor(state))
                action = q_values.argmax().item()
        
        # 执行动作并存储经验
        next_state, reward, done, _ = env.step(action)
        buffer.push(state, action, reward, next_state, done)
        state = next_state
        total_reward += reward
        
        # 经验回放与网络更新
        if len(buffer.buffer) >= batch_size:
            states, actions, rewards, next_states, dones = buffer.sample(batch_size)
            
            # 计算目标Q值
            with torch.no_grad():
                next_q = target_net(next_states).max(1)[0]
                target_q = rewards + gamma * next_q * (1 - dones)
            
            # 计算当前Q值
            current_q = policy_net(states).gather(1, actions.unsqueeze(1))
            
            # 均方误差损失
            loss = nn.MSELoss()(current_q, target_q.unsqueeze(1))
            
            # 反向传播
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
        
        if done:
            break
    
    # 更新目标网络与ε
    if episode % target_update == 0:
        target_net.load_state_dict(policy_net.state_dict())
    epsilon = max(epsilon_min, epsilon * epsilon_decay)
    
    print(f"Episode {episode}, Reward: {total_reward}, Epsilon: {epsilon:.2f}")

关键点解释

  1. 经验回放(Replay Buffer):通过存储历史经验并随机采样,打破数据相关性,提升训练稳定性。
  2. 目标网络(Target Network):固定目标Q值计算网络,缓解训练震荡问题。
  3. ε-贪婪策略:平衡探索(随机动作)与利用(最优动作),逐步降低探索率。

结果与优化方向

• 预期效果:经过约200轮训练,智能体可稳定保持平衡超过195步(CartPole-v1的胜利条件)。

• 优化方法:

• 使用Double DQN或Dueling DQN改进Q值估计。

• 调整网络结构(如增加卷积层处理图像输入)。

• 引入优先级经验回放(Prioritized Experience Replay)。

完整代码及更多改进可参考PyTorch官方文档或强化学习框架(如Stable Baselines3)。

相关推荐
独处东汉1 分钟前
freertos开发空气检测仪之按键输入事件管理系统设计与实现
人工智能·stm32·单片机·嵌入式硬件·unity
你大爷的,这都没注册了1 分钟前
AI提示词,zero-shot,few-shot 概念
人工智能
AC赳赳老秦2 分钟前
DeepSeek 辅助科研项目申报:可行性报告与经费预算框架的智能化撰写指南
数据库·人工智能·科技·mongodb·ui·rabbitmq·deepseek
瑞华丽PLM10 分钟前
国产PLM软件源头厂家的AI技术应用与智能化升级
人工智能·plm·国产plm·瑞华丽plm·瑞华丽
koo36416 分钟前
pytorch深度学习笔记19
pytorch·笔记·深度学习
xixixi7777719 分钟前
基于零信任架构的通信
大数据·人工智能·架构·零信任·通信·个人隐私
玄同76522 分钟前
LangChain v1.0+ Prompt 模板完全指南:构建精准可控的大模型交互
人工智能·语言模型·自然语言处理·langchain·nlp·交互·知识图谱
Ryan老房26 分钟前
开源vs商业-数据标注工具的选择困境
人工智能·yolo·目标检测·计算机视觉·ai
取个鸣字真的难32 分钟前
Obsidian + CC:用AI 打造知识管理系统
人工智能·产品运营
困死,根本不会1 小时前
OpenCV摄像头实时处理:基于 HSV 颜色空间的摄像头实时颜色筛选工具
人工智能·opencv·计算机视觉