基于 Q-Learning 算法和 CNN 的强化学习实现方案

import numpy as np

import tensorflow as tf

from collections import deque

import random

class QLearningCNN:

def init(self, state_shape, action_size, learning_rate=0.001,

gamma=0.99, epsilon=1.0, epsilon_min=0.01,

epsilon_decay=0.995, memory_size=10000, batch_size=32):

self.state_shape = state_shape # 输入状态的形状,例如(84, 84, 4)

self.action_size = action_size # 动作空间大小

self.learning_rate = learning_rate

self.gamma = gamma # 折扣因子

self.epsilon = epsilon # 探索率

self.epsilon_min = epsilon_min

self.epsilon_decay = epsilon_decay

self.memory = deque(maxlen=memory_size) # 经验回放缓冲区

self.batch_size = batch_size

创建主网络和目标网络

self.model = self._build_model()

self.target_model = self._build_model()

self.update_target_model()

def _build_model(self):

"""构建用于Q值近似的CNN模型"""

model = tf.keras.Sequential([

卷积层1

tf.keras.layers.Conv2D(32, (8, 8), strides=(4, 4), activation='relu',

input_shape=self.state_shape),

卷积层2

tf.keras.layers.Conv2D(64, (4, 4), strides=(2, 2), activation='relu'),

卷积层3

tf.keras.layers.Conv2D(64, (3, 3), strides=(1, 1), activation='relu'),

全连接层

tf.keras.layers.Flatten(),

tf.keras.layers.Dense(512, activation='relu'),

输出层,每个动作对应一个Q值

tf.keras.layers.Dense(self.action_size, activation='linear')

])

model.compile(loss='mse', optimizer=tf.keras.optimizers.Adam(learning_rate=self.learning_rate))

return model

def update_target_model(self):

"""更新目标网络权重,从主网络复制权重"""

self.target_model.set_weights(self.model.get_weights())

def remember(self, state, action, reward, next_state, done):

"""将经验存储到回放缓冲区"""

self.memory.append((state, action, reward, next_state, done))

def act(self, state):

"""根据当前状态选择动作,使用ε-贪婪策略"""

if np.random.rand() <= self.epsilon:

随机探索

return random.randrange(self.action_size)

利用当前模型选择最优动作

act_values = self.model.predict(state, verbose=0)

return np.argmax(act_values[0])

def replay(self):

"""从经验回放中学习"""

if len(self.memory) < self.batch_size:

return

从经验回放缓冲区中随机采样一批数据

minibatch = random.sample(self.memory, self.batch_size)

states = np.array([i[0] for i in minibatch])

actions = np.array([i[1] for i in minibatch])

rewards = np.array([i[2] for i in minibatch])

next_states = np.array([i[3] for i in minibatch])

dones = np.array([i[4] for i in minibatch])

states = np.squeeze(states)

next_states = np.squeeze(next_states)

计算目标Q值

target = rewards + self.gamma * np.amax(self.target_model.predict_on_batch(next_states), axis=1) * (1 - dones)

选择当前预测值

target_f = self.model.predict_on_batch(states)

更新目标动作的Q值

for i, action in enumerate(actions):

target_f[i][action] = target[i]

训练网络

self.model.fit(states, target_f, epochs=1, verbose=0)

衰减探索率

if self.epsilon > self.epsilon_min:

self.epsilon *= self.epsilon_decay

def load(self, name):

"""加载模型权重"""

self.model.load_weights(name)

self.update_target_model()

def save(self, name):

"""保存模型权重"""

self.model.save_weights(name)

使用示例

if name == "main":

示例参数,适用于处理84x84像素的游戏画面

state_shape = (84, 84, 4) # 4帧堆叠的游戏画面

action_size = 4 # 假设游戏有4个可能的动作

创建代理

agent = QLearningCNN(state_shape, action_size)

训练循环(伪代码示例)

episodes = 1000

for e in range(episodes):

初始化环境

state = env.reset() # 假设env是游戏环境

state = np.reshape(state, [1, 84, 84, 4])

done = False

total_reward = 0

while not done:

选择动作

action = agent.act(state)

执行动作

next_state, reward, done, _ = env.step(action)

next_state = np.reshape(next_state, [1, 84, 84, 4])

存储经验

agent.remember(state, action, reward, next_state, done)

更新状态

state = next_state

total_reward += reward

训练网络

agent.replay()

定期更新目标网络

if e % 10 == 0:

agent.update_target_model()

print(f"Episode: {e+1}/{episodes}, Total Reward: {total_reward}, Epsilon: {agent.epsilon:.2f}")

这个实现包含了以下关键组件:

  1. CNN 网络结构:使用三层卷积层处理图像输入,提取特征
  2. Q-Learning 算法:使用 ε- 贪婪策略进行探索和利用
  3. 经验回放机制:通过存储和随机采样历史经验提高学习效率
  4. 目标网络:使用独立的目标网络提高训练稳定性

要使用这个代码,你需要安装以下依赖:

  • tensorflow >= 2.0
  • numpy
  • random
相关推荐
lifallen4 分钟前
深入浅出 Arrays.sort(DualPivotQuicksort):如何结合快排、归并、堆排序和插入排序
java·开发语言·数据结构·算法·排序算法
jingfeng5145 分钟前
数据结构排序
数据结构·算法·排序算法
能工智人小辰31 分钟前
Codeforces Round 509 (Div. 2) C. Coffee Break
c语言·c++·算法
kingmax5421200832 分钟前
CCF GESP202503 Grade4-B4263 [GESP202503 四级] 荒地开垦
数据结构·算法
carpell33 分钟前
【语义分割专栏】3:Segnet实战篇(附上完整可运行的代码pytorch)
人工智能·python·深度学习·计算机视觉·语义分割
岁忧37 分钟前
LeetCode 高频 SQL 50 题(基础版)之 【高级字符串函数 / 正则表达式 / 子句】· 上
sql·算法·leetcode
智能汽车人1 小时前
自动驾驶---SD图导航的规划策略
人工智能·机器学习·自动驾驶
mengyoufengyu1 小时前
DeepSeek11-Ollama + Open WebUI 搭建本地 RAG 知识库全流程指南
人工智能·深度学习·deepseek
Tianyanxiao1 小时前
华为×小鹏战略合作:破局智能驾驶深水区的商业逻辑深度解析
大数据·人工智能·经验分享·华为·金融·数据分析
rit84324991 小时前
基于BP神经网络的语音特征信号分类
人工智能·神经网络·分类