Hands-On DQN for 2048 | Part 2! Design the ε-Greedy Policy Network and Run a Quick Training Pass!

Video link:

Hands-On DQN for 2048 | Part 2! Design the ε-Greedy Policy Network and Run a Quick Training Pass!

Code repository: LitchiCheng/DRL-learning: deep reinforcement learning

Concepts:

In DQN (Deep Q-Network), the Q stands for "Quality". The full term is the state-action value function, written Q(s,a).

Definition: Q(s,a) is the expected cumulative future reward the agent earns after taking action a in state s (that is, the "quality" of the long-term return).

Role:

The Q value is the core basis for decision-making in reinforcement learning. The agent compares the Q values of all available actions in the current state and chooses the one with the largest Q value (the "optimal action") in order to maximize cumulative reward.
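
Concretely, Q-learning nudges Q(s,a) toward a one-step bootstrapped target. With discount factor γ and a separate target network Q_target (both appear in the code below), the target for a transition (s, a, r, s', done) is

    y = r + γ · (1 − done) · max_a' Q_target(s', a')

which is exactly the target_q_values line in the training loop.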

The design has three parts:

  1. Deep Q-network definition: define a neural network in PyTorch to approximate the Q-value function.
  2. Experience replay: implement a replay buffer that stores the agent's transitions and samples them at random for training.
  3. Epsilon-greedy policy: a classic strategy for balancing exploration and exploitation; it addresses the core problem of keeping the agent from relying only on the best-known action and missing potentially better strategies (a standalone sketch follows this list).
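
A minimal sketch of the policy in isolation (the function select_action and its arguments are illustrative names, not part of the repository code): with probability epsilon the agent explores at random; otherwise it exploits the current Q network.

import random
import torch

def select_action(model, state, epsilon, num_actions):
    # Explore: with probability epsilon, pick a uniformly random action.
    if random.random() < epsilon:
        return random.randint(0, num_actions - 1)
    # Exploit: otherwise take the action with the highest predicted Q value.
    with torch.no_grad():
        q_values = model(state)  # expects state of shape (1, input_size)
    return torch.argmax(q_values, dim=1).item()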

The code is below.

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import random
from collections import deque
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors
from matplotlib.table import Table

# 2048 game environment
class Game2048:
    def __init__(self):
        self.board = np.zeros((4, 4), dtype=int)
        self.add_random_tile()
        self.add_random_tile()

    def add_random_tile(self):
        # Place a new tile on a random empty cell: 2 with 90% probability, 4 with 10%.
        empty_cells = np.argwhere(self.board == 0)
        if len(empty_cells) > 0:
            index = random.choice(empty_cells)
            self.board[index[0], index[1]] = 2 if random.random() < 0.9 else 4

    def move_left(self):
        reward = 0
        new_board = np.copy(self.board)
        for row in range(4):
            line = new_board[row]
            non_zero = line[line != 0]
            merged = []
            i = 0
            # Slide tiles left; each pair of equal neighbors merges once,
            # and the merged value is added to the reward.
            while i < len(non_zero):
                if i + 1 < len(non_zero) and non_zero[i] == non_zero[i + 1]:
                    merged.append(2 * non_zero[i])
                    reward += 2 * non_zero[i]
                    i += 2
                else:
                    merged.append(non_zero[i])
                    i += 1
            new_board[row] = np.pad(merged, (0, 4 - len(merged)), 'constant')
        if not np.array_equal(new_board, self.board):
            # Spawn a new tile only if the move actually changed the board.
            self.board = new_board
            self.add_random_tile()
        return reward

    def move_right(self):
        # Moving right = flip horizontally, move left, flip back.
        self.board = np.fliplr(self.board)
        reward = self.move_left()
        self.board = np.fliplr(self.board)
        return reward

    def move_up(self):
        # Moving up = transpose, move left, transpose back.
        self.board = self.board.T
        reward = self.move_left()
        self.board = self.board.T
        return reward

    def move_down(self):
        # Moving down = transpose, move right, transpose back.
        self.board = self.board.T
        reward = self.move_right()
        self.board = self.board.T
        return reward

    def step(self, action):
        if action == 0:
            reward = self.move_left()
        elif action == 1:
            reward = self.move_right()
        elif action == 2:
            reward = self.move_up()
        elif action == 3:
            reward = self.move_down()
        # The episode ends when the board is full and no two adjacent tiles
        # (horizontally or vertically) are equal, i.e. no merge is possible.
        done = not np.any(self.board == 0) and all([
            np.all(self.board[:, i] != self.board[:, i + 1]) for i in range(3)
        ]) and all([
            np.all(self.board[i, :] != self.board[i + 1, :]) for i in range(3)
        ])
        state = self.board.flatten()
        return state, reward, done

    def reset(self):
        self.board = np.zeros((4, 4), dtype=int)
        self.add_random_tile()
        self.add_random_tile()
        return self.board.flatten()

# Deep Q-network
class DQN(nn.Module):
    def __init__(self, input_size, output_size):
        super(DQN, self).__init__()
        self.fc1 = nn.Linear(input_size, 128)
        self.fc2 = nn.Linear(128, 128)
        self.fc3 = nn.Linear(128, output_size)

    def forward(self, x):
        x = torch.relu(self.fc1(x))
        x = torch.relu(self.fc2(x))
        return self.fc3(x)

# Experience replay buffer
class ReplayBuffer:
    def __init__(self, capacity):
        self.buffer = deque(maxlen=capacity)

    def add(self, state, action, reward, next_state, done):
        self.buffer.append((state, action, reward, next_state, done))

    def sample(self, batch_size):
        batch = random.sample(self.buffer, batch_size)
        states, actions, rewards, next_states, dones = zip(*batch)
        return np.array(states), np.array(actions), np.array(rewards), np.array(next_states), np.array(dones)

    def __len__(self):
        return len(self.buffer)

# Board visualization
def visualize_board(board, ax):
    ax.clear()
    table = Table(ax, bbox=[0, 0, 1, 1])
    nrows, ncols = board.shape
    width, height = 1.0 / ncols, 1.0 / nrows

    # Color map: light colors for small tiles, shading toward red for large ones
    cmap = mcolors.LinearSegmentedColormap.from_list("", ["white", "yellow", "orange", "red"])

    for (i, j), val in np.ndenumerate(board):
        color = cmap(np.log2(val + 1) / np.log2(2048 + 1)) if val > 0 else "white"
        table.add_cell(i, j, width, height, text=val if val > 0 else "",
                       loc='center', facecolor=color)

    ax.add_table(table)
    ax.set_axis_off()
    plt.draw()
    plt.pause(0.1)

# Training loop
def train():
    env = Game2048()
    input_size = 16
    output_size = 4
    model = DQN(input_size, output_size)
    target_model = DQN(input_size, output_size)
    target_model.load_state_dict(model.state_dict())
    target_model.eval()

    optimizer = optim.Adam(model.parameters(), lr=0.001)
    criterion = nn.MSELoss()
    replay_buffer = ReplayBuffer(capacity=10000)
    batch_size = 32
    gamma = 0.99             # discount factor
    epsilon = 1.0            # initial exploration rate
    epsilon_decay = 0.995    # multiplicative decay per episode
    epsilon_min = 0.01       # exploration floor
    update_target_freq = 10  # sync the target network every N episodes

    num_episodes = 1000
    fig, ax = plt.subplots()
    for episode in range(num_episodes):
        state = env.reset()
        state = torch.FloatTensor(state).unsqueeze(0)
        done = False
        total_reward = 0
        while not done:
            visualize_board(env.board, ax)
            if random.random() < epsilon:
                # Explore: random action
                action = random.randint(0, output_size - 1)
            else:
                # Exploit: greedy action; no_grad since this is inference only
                with torch.no_grad():
                    q_values = model(state)
                action = torch.argmax(q_values, dim=1).item()

            next_state, reward, done = env.step(action)
            next_state = torch.FloatTensor(next_state).unsqueeze(0)
            replay_buffer.add(state.squeeze(0).numpy(), action, reward, next_state.squeeze(0).numpy(), done)

            if len(replay_buffer) >= batch_size:
                states, actions, rewards, next_states, dones = replay_buffer.sample(batch_size)
                states = torch.FloatTensor(states)
                actions = torch.LongTensor(actions)
                rewards = torch.FloatTensor(rewards)
                next_states = torch.FloatTensor(next_states)
                dones = torch.FloatTensor(dones)
                q_values = model(states)
                # Q values of the actions actually taken in each sampled transition
                q_values = q_values.gather(1, actions.unsqueeze(1)).squeeze(1)
                # Max Q value in the next state, from the target network;
                # no_grad keeps gradients from flowing into the target
                with torch.no_grad():
                    next_q_values = target_model(next_states).max(1)[0]
                # TD target: y = r + gamma * (1 - done) * max_a' Q_target(s', a')
                target_q_values = rewards + gamma * (1 - dones) * next_q_values

                loss = criterion(q_values, target_q_values)
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

            state = next_state
            total_reward += reward

        if episode % update_target_freq == 0:
            target_model.load_state_dict(model.state_dict())

        epsilon = max(epsilon * epsilon_decay, epsilon_min)  # decay exploration each episode
        print(f"Episode {episode}: Total Reward = {total_reward}, Epsilon = {epsilon:.3f}")

    plt.close()

if __name__ == "__main__":
    train()
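
One caveat on the input representation: the network consumes raw tile values (0, 2, 4, ..., 2048), so input scales are very uneven. A common tweak for 2048 agents, not used in the repository code, is to feed log2-scaled values instead; a minimal sketch (the helper name preprocess is hypothetical):

import numpy as np

def preprocess(board_flat):
    # Hypothetical helper: log2-scale tiles so inputs stay in a small range.
    # Empty cells stay 0; 2 -> 1, 4 -> 2, ..., 2048 -> 11.
    out = np.zeros_like(board_flat, dtype=np.float32)
    nonzero = board_flat > 0
    out[nonzero] = np.log2(board_flat[nonzero])
    return out

To use it, apply preprocess to the flattened board before wrapping it in torch.FloatTensor, both when selecting actions and when storing transitions.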

Run it, and a matplotlib window animates the 2048 board as the agent plays, while the console prints the episode number, total reward, and current epsilon.
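
Once training finishes, it is instructive to watch the learned policy with exploration switched off (epsilon = 0). A minimal sketch, assuming you keep a reference to the trained model from train() (the helper play_greedy is illustrative, not part of the repository code):

def play_greedy(model, max_steps=2000):
    # Roll out one game, always taking the argmax-Q action.
    env = Game2048()
    state = torch.FloatTensor(env.reset()).unsqueeze(0)
    total_reward, done, steps = 0, False, 0
    while not done and steps < max_steps:
        with torch.no_grad():
            action = torch.argmax(model(state), dim=1).item()
        next_state, reward, done = env.step(action)
        state = torch.FloatTensor(next_state).unsqueeze(0)
        total_reward += reward
        steps += 1
    # A purely greedy policy can repeatedly pick a move that leaves the board
    # unchanged; the step cap prevents an infinite loop in that case.
    return total_reward, env.board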
