基于 Q-Learning 算法和 CNN 的强化学习实现方案

import numpy as np

import tensorflow as tf

from collections import deque

import random

class QLearningCNN:

def init(self, state_shape, action_size, learning_rate=0.001,

gamma=0.99, epsilon=1.0, epsilon_min=0.01,

epsilon_decay=0.995, memory_size=10000, batch_size=32):

self.state_shape = state_shape # 输入状态的形状,例如(84, 84, 4)

self.action_size = action_size # 动作空间大小

self.learning_rate = learning_rate

self.gamma = gamma # 折扣因子

self.epsilon = epsilon # 探索率

self.epsilon_min = epsilon_min

self.epsilon_decay = epsilon_decay

self.memory = deque(maxlen=memory_size) # 经验回放缓冲区

self.batch_size = batch_size

创建主网络和目标网络

self.model = self._build_model()

self.target_model = self._build_model()

self.update_target_model()

def _build_model(self):

"""构建用于Q值近似的CNN模型"""

model = tf.keras.Sequential([

卷积层1

tf.keras.layers.Conv2D(32, (8, 8), strides=(4, 4), activation='relu',

input_shape=self.state_shape),

卷积层2

tf.keras.layers.Conv2D(64, (4, 4), strides=(2, 2), activation='relu'),

卷积层3

tf.keras.layers.Conv2D(64, (3, 3), strides=(1, 1), activation='relu'),

全连接层

tf.keras.layers.Flatten(),

tf.keras.layers.Dense(512, activation='relu'),

输出层,每个动作对应一个Q值

tf.keras.layers.Dense(self.action_size, activation='linear')

])

model.compile(loss='mse', optimizer=tf.keras.optimizers.Adam(learning_rate=self.learning_rate))

return model

def update_target_model(self):

"""更新目标网络权重,从主网络复制权重"""

self.target_model.set_weights(self.model.get_weights())

def remember(self, state, action, reward, next_state, done):

"""将经验存储到回放缓冲区"""

self.memory.append((state, action, reward, next_state, done))

def act(self, state):

"""根据当前状态选择动作,使用ε-贪婪策略"""

if np.random.rand() <= self.epsilon:

随机探索

return random.randrange(self.action_size)

利用当前模型选择最优动作

act_values = self.model.predict(state, verbose=0)

return np.argmax(act_values[0])

def replay(self):

"""从经验回放中学习"""

if len(self.memory) < self.batch_size:

return

从经验回放缓冲区中随机采样一批数据

minibatch = random.sample(self.memory, self.batch_size)

states = np.array([i[0] for i in minibatch])

actions = np.array([i[1] for i in minibatch])

rewards = np.array([i[2] for i in minibatch])

next_states = np.array([i[3] for i in minibatch])

dones = np.array([i[4] for i in minibatch])

states = np.squeeze(states)

next_states = np.squeeze(next_states)

计算目标Q值

target = rewards + self.gamma * np.amax(self.target_model.predict_on_batch(next_states), axis=1) * (1 - dones)

选择当前预测值

target_f = self.model.predict_on_batch(states)

更新目标动作的Q值

for i, action in enumerate(actions):

target_f[i][action] = target[i]

训练网络

self.model.fit(states, target_f, epochs=1, verbose=0)

衰减探索率

if self.epsilon > self.epsilon_min:

self.epsilon *= self.epsilon_decay

def load(self, name):

"""加载模型权重"""

self.model.load_weights(name)

self.update_target_model()

def save(self, name):

"""保存模型权重"""

self.model.save_weights(name)

使用示例

if name == "main":

示例参数,适用于处理84x84像素的游戏画面

state_shape = (84, 84, 4) # 4帧堆叠的游戏画面

action_size = 4 # 假设游戏有4个可能的动作

创建代理

agent = QLearningCNN(state_shape, action_size)

训练循环(伪代码示例)

episodes = 1000

for e in range(episodes):

初始化环境

state = env.reset() # 假设env是游戏环境

state = np.reshape(state, [1, 84, 84, 4])

done = False

total_reward = 0

while not done:

选择动作

action = agent.act(state)

执行动作

next_state, reward, done, _ = env.step(action)

next_state = np.reshape(next_state, [1, 84, 84, 4])

存储经验

agent.remember(state, action, reward, next_state, done)

更新状态

state = next_state

total_reward += reward

训练网络

agent.replay()

定期更新目标网络

if e % 10 == 0:

agent.update_target_model()

print(f"Episode: {e+1}/{episodes}, Total Reward: {total_reward}, Epsilon: {agent.epsilon:.2f}")

这个实现包含了以下关键组件:

  1. CNN 网络结构:使用三层卷积层处理图像输入,提取特征
  2. Q-Learning 算法:使用 ε- 贪婪策略进行探索和利用
  3. 经验回放机制:通过存储和随机采样历史经验提高学习效率
  4. 目标网络:使用独立的目标网络提高训练稳定性

要使用这个代码,你需要安装以下依赖:

  • tensorflow >= 2.0
  • numpy
  • random
相关推荐
艾莉丝努力练剑3 分钟前
【C/C++】类和对象(上):(一)类和结构体,命名规范——两大规范,新的作用域——类域
java·c语言·开发语言·c++·学习·算法
AndrewHZ15 分钟前
【图像处理基石】如何对遥感图像进行实例分割?
图像处理·人工智能·python·大模型·实例分割·detectron2·遥感图像分割
TDengine (老段)33 分钟前
TDengine 中 TDgp 中添加机器学习模型
大数据·数据库·算法·机器学习·数据分析·时序数据库·tdengine
CodeShare43 分钟前
某中心将举办机器学习峰会
人工智能·机器学习·数据科学
那就摆吧1 小时前
U-Net vs. 传统CNN:为什么医学图像分割需要跳过连接?
人工智能·神经网络·cnn·u-net·医学图像
深度学习实战训练营1 小时前
中英混合的语音识别XPhoneBERT 监督的音频到音素的编码器结合 f0 特征LID
人工智能·音视频·语音识别
WADesk---瓜子1 小时前
用 AI 自动生成口型同步视频,短视频内容也能一人完成
人工智能·音视频·语音识别·流量运营·用户运营
星环科技TDH社区版1 小时前
AI Agent 的 10 种应用场景:物联网、RAG 与灾难响应
人工智能·物联网
时序之心2 小时前
ICML 2025 | 深度剖析时序 Transformer:为何有效,瓶颈何在?
人工智能·深度学习·transformer