TensorFlow 实现循环神经网络

摘要:本文介绍了使用TensorFlow实现循环神经网络(RNN)的方法,重点针对MNIST手写数字分类任务。主要内容包括:1) RNN的基本原理,通过序列式处理保留上下文信息;2) 具体实现步骤:数据预处理、参数定义、LSTM单元构建、损失函数和优化器配置;3) 训练过程展示,包括批次训练和准确率评估。实验结果表明,该方法在测试集上取得了良好的分类效果,验证了RNN处理序列数据的有效性。代码实现完整展示了从数据加载到模型评估的全流程,为RNN的TensorFlow实践提供了参考范例。

目录

[TensorFlow 实现循环神经网络](#TensorFlow 实现循环神经网络)

[基于 TensorFlow 的循环神经网络实现](#基于 TensorFlow 的循环神经网络实现)

[步骤 1:导入所需模块](#步骤 1:导入所需模块)

[步骤 2:定义输入参数](#步骤 2:定义输入参数)

[步骤 3:定义循环神经网络计算函数并配置损失函数与优化器](#步骤 3:定义循环神经网络计算函数并配置损失函数与优化器)

[步骤 4:启动计算图并训练模型](#步骤 4:启动计算图并训练模型)

模型运行输出结果


TensorFlow 实现循环神经网络

循环神经网络是一类面向深度学习的算法,采用序列式处理方法。在传统神经网络中,我们通常假设每个输入和输出都与其他所有层相互独立,而循环神经网络之所以被称为 "循环",是因为它会以序列的方式执行数学运算。

以下是训练循环神经网络的具体步骤:

  1. 从数据集中输入一个特定的样本;
  2. 网络接收该样本,并利用随机初始化的变量完成相关计算;
  3. 计算得到预测结果;
  4. 将实际输出结果与预期值对比,得到误差值;
  5. 沿原计算路径反向传播误差,同时调整相关变量;
  6. 重复步骤 1 至步骤 5,直至确定用于输出结果的变量已得到合理定义;
  7. 应用这些优化后的变量,对未见过的新输入数据进行系统性的预测。

循环神经网络的示意图表示如下:

基于 TensorFlow 的循环神经网络实现

本节将介绍如何使用 TensorFlow 实现循环神经网络,具体步骤如下:

步骤 1:导入所需模块

TensorFlow 提供了多个专用库,用于实现循环神经网络模块,通过以下代码导入核心模块:

python 复制代码
from __future__ import print_function
import tensorflow as tf
from tensorflow.contrib import rnn
from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets("/tmp/data/", one_hot = True)

上述库的核心作用是定义输入数据,这是实现循环神经网络的基础环节。

步骤 2:定义输入参数

我们的核心目标是利用循环神经网络对图像进行分类,将每张图像的行视为一个像素序列。MNIST 数据集的图像尺寸固定为 28×28 像素,因此每个样本需处理 28 个序列,每个序列包含 28 个步骤,以下是输入参数的定义代码:

python

运行

python 复制代码
n_input = 28  # MNIST数据输入,图像尺寸28*28
n_steps = 28   # 序列步数
n_hidden = 128 # 隐藏层神经元数量
n_classes = 10 # 分类类别数(0-9数字)

# 定义TensorFlow计算图的输入占位符
x = tf.placeholder("float", [None, n_steps, n_input])
y = tf.placeholder("float", [None, n_classes])

# 定义权重和偏置项
weights = {
    'out': tf.Variable(tf.random_normal([n_hidden, n_classes]))
}
biases = {
    'out': tf.Variable(tf.random_normal([n_classes]))
}

步骤 3:定义循环神经网络计算函数并配置损失函数与优化器

通过自定义函数实现循环神经网络的核心计算逻辑,对比数据形状与当前输入形状,保证计算精度,同时定义损失函数、优化器和模型评估指标:

python 复制代码
def RNN(x, weights, biases):
    # 将输入数据按序列维度拆解
    x = tf.unstack(x, n_steps, 1)
    # 定义LSTM细胞单元
    lstm_cell = rnn.BasicLSTMCell(n_hidden, forget_bias=1.0)
    # 获取LSTM细胞的输出和状态
    outputs, states = rnn.static_rnn(lstm_cell, x, dtype = tf.float32)
    # 对最后一个时间步的输出做线性激活,得到预测结果
    return tf.matmul(outputs[-1], weights['out']) + biases['out']

# 得到模型预测值
pred = RNN(x, weights, biases)
# 定义交叉熵损失函数
cost = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits = pred, labels = y))
# 定义Adam优化器,最小化损失函数
optimizer = tf.train.AdamOptimizer(learning_rate = learning_rate).minimize(cost)
# 计算模型预测准确率
correct_pred = tf.equal(tf.argmax(pred,1), tf.argmax(y,1))
accuracy = tf.reduce_mean(tf.cast(correct_pred, tf.float32))
# 初始化所有全局变量
init = tf.global_variables_initializer()

步骤 4:启动计算图并训练模型

启动 TensorFlow 计算图执行计算,完成模型训练并测试模型准确率:

python 复制代码
with tf.Session() as sess:
    # 初始化变量
    sess.run(init)
    step = 1
    # 迭代训练,直至达到最大迭代次数
    while step * batch_size < training_iters:
        # 获取批次训练数据
        batch_x, batch_y = mnist.train.next_batch(batch_size)
        # 调整数据形状以匹配模型输入
        batch_x = batch_x.reshape((batch_size, n_steps, n_input))
        # 执行优化步骤
        sess.run(optimizer, feed_dict={x: batch_x, y: batch_y})
        # 定期打印训练结果
        if step % display_step == 0:
            # 计算批次准确率
            acc = sess.run(accuracy, feed_dict={x: batch_x, y: batch_y})
            # 计算批次损失值
            loss = sess.run(cost, feed_dict={x: batch_x, y: batch_y})
            # 打印迭代次数、损失值和准确率
            print("Iter " + str(step*batch_size) + ", Minibatch Loss= " + \
                "{:.6f}".format(loss) + ", Training Accuracy= " + \
                "{:.5f}".format(acc))
        step += 1
    # 打印训练完成提示
    print("Optimization Finished!")
    # 定义测试数据量
    test_len = 128
    # 准备测试数据并调整形状
    test_data = mnist.test.images[:test_len].reshape((-1, n_steps, n_input))
    test_label = mnist.test.labels[:test_len]
    # 打印测试准确率
    print("Testing Accuracy:", \
        sess.run(accuracy, feed_dict={x: test_data, y: test_label}))

模型运行输出结果

执行上述代码的终端命令及输出如下:

plaintext

python 复制代码
E:\Tensorflowproject>activate tensorflow
(tensorflow) E:\TensorFlowProject>python recurrent_network.py

运行过程中会出现部分 TensorFlow 弃用警告(提示后续版本将移除相关接口,建议使用 tf.data 等新接口替代),同时输出数据集解压信息,最终的训练迭代结果如下:

相关推荐
Coder_Boy_2 小时前
Java高级_资深_架构岗 核心面试知识点(AI整合+混合部署)
java·人工智能·spring boot·后端·面试·架构
2501_947908202 小时前
智远纳米科技量产100纳米级以下的材料引领纳米材料量产革命,形成「全球纳米材料障碍」
大数据·人工智能·科技
爱吃rabbit的mq2 小时前
第28章:MLOps基础:机器学习的DevOps
人工智能·机器学习·devops
7B_coder2 小时前
模型推理prefill和decode过程
人工智能·机器学习
阿钱真强道2 小时前
14 ThingsBoard实战:从零搭建设备配置+设备,完成MQTT温湿度上行/目标温度下行测试(对比JetLinks)
java·网络·python·网络协议
ssswywywht2 小时前
python练习
开发语言·python
PD我是你的真爱粉2 小时前
RabbitMQRPC与死信队列
后端·python·中间件
喵手2 小时前
Python爬虫实战:医院科室排班智能采集系统 - 从零构建合规且高效的医疗信息爬虫(附CSV导出 + SQLite持久化存储)!
爬虫·python·爬虫实战·零基础python爬虫教学·医院科室排版智能采集系统·采集医疗信息·采集医疗信息sqlite存储
X54先生(人文科技)2 小时前
20260212_Meta-CreationPower_Development_Log(启蒙灯塔起源团队开发日志)
人工智能·机器学习·架构·团队开发·零知识证明