【零基础学AI】第26讲:循环神经网络(RNN)与LSTM - 文本生成

本节课你将学到

  • 理解RNN的基本原理和局限性
  • 掌握LSTM的结构和工作机制
  • 学会使用TensorFlow构建文本生成模型
  • 实现一个莎士比亚风格文本生成器

开始之前

环境要求

  • Python 3.8+
  • 需要安装的包:
    • tensorflow==2.8.0
    • numpy==1.21.0
    • matplotlib==3.4.0

前置知识

  • 神经网络基础(第23讲)
  • 文本处理基础(第14讲文本分类)
  • 序列数据处理概念

核心概念

为什么需要RNN?

想象你在读一本小说:

传统神经网络的缺陷

  • 每个单词独立处理,无法记住前文
  • 固定输入尺寸,无法处理变长文本
  • 无法捕捉时间序列中的依赖关系

RNN的解决方案

  • 循环连接:保留隐藏状态传递历史信息
  • 变长输入:理论上可处理任意长度序列
  • 时间展开:每个时间步共享相同权重

RNN基本结构

复制代码
时间步展开:
X₀ → [RNN单元] → h₀ → y₀
X₁ → [RNN单元] → h₁ → y₁ 
X₂ → [RNN单元] → h₂ → y₂
...

关键组件:

  • 隐藏状态(h):携带历史信息的"记忆"
  • 权重共享:所有时间步使用相同的U/W/V参数
  • 输出计算h_t = tanh(U·x_t + W·h_{t-1} + b)

RNN的梯度消失问题

当网络较深时(时间步很多):

  • 梯度通过多个tanh函数连乘
  • 梯度指数级缩小 → 早期时间步无法有效学习
  • 导致RNN难以学习长距离依赖

LSTM(长短期记忆网络)

就像一个有记忆管理系统的智能笔记本:

三个控制门

  1. 遗忘门:决定丢弃哪些历史信息
  2. 输入门:决定存储哪些新信息
  3. 输出门:决定输出哪些信息

细胞状态©:贯穿整个时间步的"记忆通道"

数学表示:

复制代码
遗忘门:f_t = σ(W_f·[h_{t-1}, x_t] + b_f)
输入门:i_t = σ(W_i·[h_{t-1}, x_t] + b_i)
候选值:C̃_t = tanh(W_C·[h_{t-1}, x_t] + b_C)
细胞状态:C_t = f_t * C_{t-1} + i_t * C̃_t
输出门:o_t = σ(W_o·[h_{t-1}, x_t] + b_o)
隐藏状态:h_t = o_t * tanh(C_t)

文本生成原理

字符级文本生成流程

  1. 输入种子文本(如"The cat sat on the")
  2. 预测下一个字符的概率分布
  3. 从分布中采样一个字符(如'a')
  4. 将新字符添加到输入,重复过程

关键参数

  • 温度(Temperature) :控制生成随机性
    • 高温 → 更多随机性/创造性
    • 低温 → 更保守/可预测

代码实战

1. 准备莎士比亚文本数据

python 复制代码
import tensorflow as tf
import numpy as np
import os
import time

# 下载莎士比亚文本
path_to_file = tf.keras.utils.get_file(
    'shakespeare.txt',
    'https://storage.googleapis.com/download.tensorflow.org/data/shakespeare.txt'
)

# 读取并预览数据
text = open(path_to_file, 'rb').read().decode(encoding='utf-8')
print(f'文本长度: {len(text)} 字符')
print(text[:250])  # 打印前250个字符

# 创建字符到ID的映射
vocab = sorted(set(text))
char2idx = {u:i for i, u in enumerate(vocab)}
idx2char = np.array(vocab)

# 文本向量化
text_as_int = np.array([char2idx[c] for c in text])
print(f'{text[:13]} → {text_as_int[:13]}')

2. 创建训练样本和批次

python 复制代码
# 创建训练样本(输入序列和目标序列)
seq_length = 100  # 每个输入序列的长度
examples_per_epoch = len(text) // (seq_length + 1)

char_dataset = tf.data.Dataset.from_tensor_slices(text_as_int)

sequences = char_dataset.batch(seq_length + 1, drop_remainder=True)

# 定义输入到目标的映射函数
def split_input_target(chunk):
    input_text = chunk[:-1]
    target_text = chunk[1:]
    return input_text, target_text

dataset = sequences.map(split_input_target)

# 批处理配置
BATCH_SIZE = 64
BUFFER_SIZE = 10000

dataset = dataset.shuffle(BUFFER_SIZE).batch(BATCH_SIZE, drop_remainder=True)

3. 构建LSTM模型

python 复制代码
# 模型参数
vocab_size = len(vocab)
embedding_dim = 256
rnn_units = 1024

def build_model(vocab_size, embedding_dim, rnn_units, batch_size):
    model = tf.keras.Sequential([
        # 嵌入层:将字符ID转换为密集向量
        tf.keras.layers.Embedding(
            vocab_size,
            embedding_dim,
            batch_input_shape=[batch_size, None]
        ),
        
        # LSTM层
        tf.keras.layers.LSTM(
            rnn_units,
            return_sequences=True,  # 返回完整序列
            stateful=True,          # 保持批次间状态
            recurrent_initializer='glorot_uniform'
        ),
        
        # 输出层:预测下一个字符的概率
        tf.keras.layers.Dense(vocab_size)
    ])
    return model

model = build_model(
    vocab_size=len(vocab),
    embedding_dim=embedding_dim,
    rnn_units=rnn_units,
    batch_size=BATCH_SIZE
)

model.summary()

4. 训练模型

python 复制代码
# 定义损失函数
def loss(labels, logits):
    return tf.keras.losses.sparse_categorical_crossentropy(
        labels, logits, from_logits=True)

# 编译模型
model.compile(optimizer='adam', loss=loss)

# 配置检查点保存
checkpoint_dir = './training_checkpoints'
checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt_{epoch}")

checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(
    filepath=checkpoint_prefix,
    save_weights_only=True)

# 训练模型
EPOCHS = 30
history = model.fit(
    dataset,
    epochs=EPOCHS,
    callbacks=[checkpoint_callback]
)

5. 文本生成函数

python 复制代码
# 恢复最新检查点
tf.train.latest_checkpoint(checkpoint_dir)
model = build_model(vocab_size, embedding_dim, rnn_units, batch_size=1)
model.load_weights(tf.train.latest_checkpoint(checkpoint_dir))
model.build(tf.TensorShape([1, None]))

# 文本生成函数
def generate_text(model, start_string, num_generate=1000, temperature=1.0):
    # 将起始字符串转换为数字(向量化)
    input_eval = [char2idx[s] for s in start_string]
    input_eval = tf.expand_dims(input_eval, 0)
    
    # 空列表存储结果
    text_generated = []
    
    # 重置模型状态
    model.reset_states()
    
    for i in range(num_generate):
        # 预测下一个字符
        predictions = model(input_eval)
        # 移除批次维度
        predictions = tf.squeeze(predictions, 0)
        
        # 使用温度参数重新缩放logits
        predictions = predictions / temperature
        # 采样预测字符ID
        predicted_id = tf.random.categorical(
            predictions, num_samples=1)[-1,0].numpy()
        
        # 将预测字符和之前的隐藏状态一起传递给模型作为下一步输入
        input_eval = tf.expand_dims([predicted_id], 0)
        
        # 将预测字符添加到生成文本
        text_generated.append(idx2char[predicted_id])
    
    return (start_string + ''.join(text_generated))

# 生成文本示例
print(generate_text(
    model,
    start_string="ROMEO: ",
    temperature=0.7,
    num_generate=500
))

6. 温度参数对比实验

python 复制代码
# 不同温度下的生成效果对比
temperatures = [0.2, 0.7, 1.2]  # 保守 → 平衡 → 随机

for temp in temperatures:
    print(f"\n=== 温度 {temp} ===")
    print(generate_text(
        model,
        start_string="QUEEN:",
        temperature=temp,
        num_generate=200
    ))

完整项目

项目结构

复制代码
lesson_26_rnn_lstm/
├── README.md
├── requirements.txt
├── text_generation.py     # 主程序文件
├── training_checkpoints/  # 训练保存点
├── utils/
│   ├── text_utils.py      # 文本处理工具
│   └── model_utils.py     # 模型工具
└── output/                # 生成示例
    ├── training_curve.png
    └── generated_text.txt

requirements.txt

txt 复制代码
tensorflow==2.8.0
numpy==1.21.0
matplotlib==3.4.0

text_generation.py

python 复制代码
import tensorflow as tf
import numpy as np
import os
import time
from utils.text_utils import load_and_preprocess_text
from utils.model_utils import build_lstm_model, generate_text
import matplotlib.pyplot as plt

def main():
    # 加载和预处理数据
    text, vocab, char2idx, idx2char = load_and_preprocess_text()
    text_as_int = np.array([char2idx[c] for c in text])
    
    # 创建训练数据集
    seq_length = 100
    char_dataset = tf.data.Dataset.from_tensor_slices(text_as_int)
    sequences = char_dataset.batch(seq_length + 1, drop_remainder=True)
    
    def split_input_target(chunk):
        input_text = chunk[:-1]
        target_text = chunk[1:]
        return input_text, target_text
    
    dataset = sequences.map(split_input_target)
    
    # 批处理配置
    BATCH_SIZE = 64
    BUFFER_SIZE = 10000
    dataset = dataset.shuffle(BUFFER_SIZE).batch(BATCH_SIZE, drop_remainder=True)
    
    # 构建模型
    vocab_size = len(vocab)
    embedding_dim = 256
    rnn_units = 1024
    
    model = build_lstm_model(
        vocab_size=vocab_size,
        embedding_dim=embedding_dim,
        rnn_units=rnn_units,
        batch_size=BATCH_SIZE
    )
    
    # 训练模型
    loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
    model.compile(optimizer='adam', loss=loss_fn)
    
    # 设置检查点
    checkpoint_dir = './training_checkpoints'
    checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt_{epoch}")
    checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(
        filepath=checkpoint_prefix,
        save_weights_only=True)
    
    EPOCHS = 30
    history = model.fit(
        dataset,
        epochs=EPOCHS,
        callbacks=[checkpoint_callback]
    )
    
    # 保存训练曲线
    plt.plot(history.history['loss'])
    plt.title('Training Loss')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.savefig('output/training_curve.png')
    
    # 文本生成示例
    model = build_lstm_model(vocab_size, embedding_dim, rnn_units, batch_size=1)
    model.load_weights(tf.train.latest_checkpoint(checkpoint_dir))
    model.build(tf.TensorShape([1, None]))
    
    generated_text = generate_text(
        model,
        start_string="ROMEO: ",
        temperature=0.7,
        num_generate=1000
    )
    
    # 保存生成文本
    with open('output/generated_text.txt', 'w') as f:
        f.write(generated_text)
    
    print("\n=== 生成文本示例 ===")
    print(generated_text[:500] + "...")

if __name__ == "__main__":
    main()

运行效果

控制台输出

复制代码
Epoch 1/30
172/172 [==============================] - 15s 75ms/step - loss: 2.7243
Epoch 2/30
172/172 [==============================] - 13s 75ms/step - loss: 1.9912
...
Epoch 30/30
172/172 [==============================] - 13s 75ms/step - loss: 1.2564

=== 生成文本示例 ===
ROMEO: I pray thee, good Mercutio, let's retire:
The day is hot, the Capulets abroad,
And, if we meet, we shall not scape a brawl;
For now, these hot days, is the mad blood stirring.

MERCUTIO: Thou art like one of those fellows that when he
enters the confines of a tavern claps me his sword
upon the table and says 'God send me no need of thee!'
and by the operation of the second cup draws
it on the drawer, when indeed there is no need.

BENVOLIO: Am I like such a fellow?

MERCUTIO: Come, come, thou art as hot a Jack in thy mood as
any in Italy, and as soon moved to be moody, and as
soon moody to be moved...

生成的文件

  • training_checkpoints/: 训练过程中的模型权重
  • output/training_curve.png: 训练损失曲线
  • output/generated_text.txt: 生成的莎士比亚风格文本

预期结果说明

  1. 训练损失应稳定下降:表明模型在学习文本模式
  2. 生成文本应有合理结构:包含角色对话、标点等
  3. 温度参数影响明显
    • 低温:更保守、重复性更高
    • 高温:更随机、可能有语法错误

常见问题

Q1: 如何提高生成文本质量?

改进方法:

  • 增加训练数据量
  • 使用更深的LSTM网络或多层LSTM
  • 尝试不同的温度参数
  • 延长训练时间

Q2: 为什么我的模型不收敛?

可能原因:

  • 学习率不合适(尝试调整Adam优化器的学习率)
  • 梯度爆炸(添加梯度裁剪)
  • 模型容量不足(增加LSTM单元数)

Q3: 如何应用到中文文本?

调整方案:

  1. 使用中文语料库
  2. 考虑使用分词后的结果
  3. 可能需要更大的嵌入维度
  4. 调整序列长度参数

Q4: 训练速度太慢怎么办?

优化建议:

  • 减小批处理大小
  • 使用GPU加速训练
  • 降低模型复杂度
  • 使用混合精度训练

课后练习

基础练习

  • 调整seq_length参数,观察对模型的影响
  • 尝试不同的温度参数,比较生成效果
  • 修改LSTM单元数量,评估模型性能变化

进阶挑战

  • 实现基于单词而非字符的文本生成
  • 添加第二个LSTM层创建更深网络
  • 使用GRU单元替代LSTM进行比较

项目扩展

  • 开发一个交互式文本生成Web应用
  • 训练专业领域文本生成器(如法律、医疗)
  • 结合Flask创建文本生成API服务

技术总结

通过本讲我们掌握了:

  1. RNN的基本原理和局限性
  2. LSTM网络的结构和优势
  3. 文本数据的预处理方法
  4. 使用TensorFlow构建文本生成模型
  5. 温度参数对生成效果的影响

RNN和LSTM是处理序列数据的强大工具,这些知识将为你学习更复杂的序列模型(如Transformer)奠定基础。

相关推荐
W.KN1 小时前
机器学习【二】KNN
人工智能·机器学习
糖葫芦君3 小时前
玻尔兹曼分布与玻尔兹曼探索
人工智能·算法·机器学习
TT-Kun3 小时前
PyTorch基础——张量计算
人工智能·pytorch·python
Monkey-旭6 小时前
Android Bitmap 完全指南:从基础到高级优化
android·java·人工智能·计算机视觉·kotlin·位图·bitmap
天若有情6737 小时前
【python】Python爬虫入门教程:使用requests库
开发语言·爬虫·python·网络爬虫·request
哪 吒8 小时前
OpenAI放大招:ChatGPT学习模式上线,免费AI智能家教
人工智能·学习·ai·chatgpt·gemini·deepseek
IT北辰8 小时前
用Python+MySQL实战解锁企业财务数据分析
python·mysql·数据分析
Lucky高8 小时前
selenium(WEB自动化工具)
python
老鱼说AI8 小时前
循环神经网络RNN原理精讲,详细举例!
人工智能·rnn·深度学习·神经网络·自然语言处理·语音识别
秃然想通8 小时前
掌握Python三大语句:顺序、条件与循环
开发语言·python·numpy