REINFORCE算法

1.算法描述

REINFORCE 算法是基于蒙特卡洛采样的无模型策略梯度方法,由 Williams 于 1992 年提出。其核心思想是:利用完整轨迹采样得到的未来累积回报 Gt​ 加权策略梯度,优化策略参数;通过增大高回报轨迹中动作的概率、降低低回报轨迹中动作的概率,从而提升策略性能。

2.代码实现

python 复制代码
# -*- coding: utf-8 -*-

import random
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import matplotlib.pyplot as plt

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

class GridWorld:
    def __init__(self, size=5):
        self.size = size
        self.goal = (size - 1, size - 1)
        self.reset()

    def reset(self):
        self.agent = (0, 0)
        return self._get_state()

    def _get_state(self):
        state = np.zeros((self.size, self.size), dtype=np.float32)
        state[self.agent] = 1.0
        return state.flatten()

    def step(self, action):
        x, y = self.agent
        if action == 0:   # up
            x -= 1
        elif action == 1: # down
            x += 1
        elif action == 2: # left
            y -= 1
        elif action == 3: # right
            y += 1

        x = np.clip(x, 0, self.size - 1)
        y = np.clip(y, 0, self.size - 1)
        self.agent = (x, y)

        if self.agent == self.goal:
            reward = 10.0
            done = True
        else:
            reward = -0.1
            done = False
        return self._get_state(), reward, done

class PolicyNet(nn.Module):
    def __init__(self, state_dim, action_dim):
        super(PolicyNet, self).__init__()
        self.fc1 = nn.Linear(state_dim, 64)
        self.fc2 = nn.Linear(64, 64)
        self.fc3 = nn.Linear(64, action_dim)

    def forward(self, x):
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        return F.softmax(self.fc3(x), dim=-1)

class REINFORCEAgent:
    def __init__(self, state_dim, action_dim, lr=0.001, gamma=0.99):
        self.gamma = gamma
        self.policy = PolicyNet(state_dim, action_dim).to(device)
        self.optimizer = optim.Adam(self.policy.parameters(), lr=lr)

    def select_action(self, state):
        state = torch.from_numpy(state).float().to(device)
        probs = self.policy(state)
        m = torch.distributions.Categorical(probs)
        action = m.sample()
        return action.item(), m.log_prob(action)

    def update(self, rewards, log_probs):
        returns = []
        G = 0
        for r in reversed(rewards):
            G = r + self.gamma * G
            returns.insert(0, G)

        returns = torch.tensor(returns, dtype=torch.float32).to(device)
        returns = (returns - returns.mean()) / (returns.std() + 1e-8)

        # 计算 Loss: -log_prob * Gt
        loss = []
        for log_prob, Gt in zip(log_probs, returns):
            loss.append(-log_prob * Gt)
        
        loss = torch.stack(loss).sum()

        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

def train():
    env = GridWorld(size=5)
    state_dim = 25
    action_dim = 4

    agent = REINFORCEAgent(state_dim, action_dim)
    rewards_history = []

    for episode in range(800):
        state = env.reset()
        log_probs, rewards = [], []

        for step in range(50):
            action, log_prob = agent.select_action(state)
            next_state, reward, done = env.step(action)

            log_probs.append(log_prob)
            rewards.append(reward)

            state = next_state
            if done:
                break

        agent.update(rewards, log_probs)
        total_reward = sum(rewards)
        rewards_history.append(total_reward)

        if episode % 50 == 0:
            print(f"Episode {episode}, total reward = {total_reward:.2f}")

    return agent, env, rewards_history

def visualize_policy(agent, size=5):
    arrow_map = {0: (0, 0.4), 1: (0, -0.4), 2: (-0.4, 0), 3: (0.4, 0)}

    plt.figure(figsize=(6, 6))
    for x in range(size):
        for y in range(size):
            state = np.zeros((size, size), dtype=np.float32)
            state[x, y] = 1.0
            state_tensor = torch.from_numpy(state.flatten()).to(device)
            
            with torch.no_grad():
                probs = agent.policy(state_tensor)
                action = torch.argmax(probs).item()
            
            dx, dy = arrow_map[action]
            plt.arrow(y, size - 1 - x, dx, dy, 
                      head_width=0.15, length_includes_head=True, color='blue')

    plt.scatter(size - 1, 0, c="red", s=200, marker="*", label="Goal")
    plt.grid(True)
    plt.xticks(range(size))
    plt.yticks(range(size))
    plt.title("Learned Policy (PyTorch REINFORCE)")
    plt.legend()
    plt.show()

if __name__ == "__main__":
    trained_agent, env, history = train()
    visualize_policy(trained_agent)

3.效果展示

bash 复制代码
Using device: cpu
Episode 0, total reward = 5.40
Episode 50, total reward = -5.00
Episode 100, total reward = 9.20
Episode 150, total reward = 8.90
Episode 200, total reward = 9.30
Episode 250, total reward = 9.10
Episode 300, total reward = 9.30
Episode 350, total reward = 9.30
Episode 400, total reward = 9.10
Episode 450, total reward = 9.10
Episode 500, total reward = 9.20
Episode 550, total reward = 9.20
Episode 600, total reward = 9.30
Episode 650, total reward = 9.30
Episode 700, total reward = 9.30
Episode 750, total reward = 9.30
相关推荐
把你拉进白名单5 小时前
7.OpenClaw源码解析——可靠消息投递
人工智能·llm·agent
劈星斩月5 小时前
机器学习之 定义与三大范式
人工智能·机器学习·监督学习·强化学习·无监督学习
触底反弹5 小时前
🎨 通义万相实战:用 Qwen 多模态 API 实现 AI 换装换姿势,10 行代码搞定!
vue.js·人工智能
属鼠哥5 小时前
一场正在发生的范式转变:Loop Engineering(循环工程)
人工智能·aiops
码农小旋风5 小时前
Claude Code 基础用法大全:对话、分析、修改、测试、Git 和工作流
人工智能·git·chatgpt·claude
Solis程序员5 小时前
MCP (Model Context Protocol):AI应用连接外部世界的标准协议
人工智能·microsoft·agent·skill·mcp
贵慜_Derek5 小时前
《从零实现 Agent 系统》连载 29|多 Agent 研究 Harness:Lead、Worker 与 Spawn
人工智能·架构·agent
枫子有风5 小时前
AI编程-Vibe coding(大厂常问问题)
人工智能·ai编程
枫叶林FYL5 小时前
BRIDGE:多模态查询的强化学习对齐与文本检索重构
人工智能·语言模型
leeyi5 小时前
Retriever 组件:让 Agent 学会「翻资料」的统一接口
人工智能·后端·agent