[nanoGPT] 文本生成 | 自回归采样 | `generate`方法

第六章:文本生成与采样

欢迎回来

第五章:检查点与预训练模型加载中,我们学会了如何保存模型习得的知识,甚至加载强大的预训练GPT-2模型。现在,经过所有准备和训练,终于迎来激动人心的环节:让我们的模型创作文本

想象你培养了一位才华横溢的年轻作家。你已教授其语法、词汇甚至特定文风(如莎士比亚风格)。现在,你给出开头句子,要求其创造性续写。这正是文本生成与采样的核心。

本章重点在于使用训练好的GPT模型,基于给定提示生成新颖连贯的文本。

其流程包括:输入初始文本→预测最可能的下个标记→重复追加标记直至达到目标长度。

用户可通过temperature(控制随机性)和top_k(限制候选标记范围)等参数调节生成文本的创造性。

文本生成的核心任务

核心功能是:使用训练好的GPT模型,基于初始提示生成类人文本

例如输入:"在遥远的银河系..."

模型应据此续写合理的故事段落。

核心机制:自回归采样

GPT模型通过自回归采样 逐标记生成文本,如同谨慎的作家:写一个词→思考下一个词→再写→再思考,循环往复。

基本流程:

  1. 输入提示:提供初始标记序列(开头文本)
  2. 预测下个标记 :模型根据当前序列预测下一个最可能标记
  3. 扩展序列:将预测标记追加到序列
  4. 循环执行:用新序列重复预测,直至达到目标长度

控制台:sample.py脚本

nanoGPT中负责文本生成的主脚本是sample.py。该脚本加载训练好的模型(或预训练模型)并生成文本,同时提供类似第三章:配置系统中的"调节旋钮"来控制文本风格与创造性。

典型调用方式:

bash 复制代码
python sample.py --start="你好,我叫" --num_samples=1 --max_new_tokens=50

该命令指示脚本以"你好,我叫"开头,生成1段50个标记的文本。

创造性控制参数

sample.py提供多个关键参数调节生成过程:

  • --start :初始提示文本。可以是简单字符串或文件路径(如FILE:prompt.txt
  • --max_new_tokens:控制生成的新标记数量。若提示含5个标记且设为50,最终输出为55个标记
  • --temperature :浮点数(如0.8/1.0/1.2)控制输出随机性
    • =1.0:按模型原始概率采样
    • <1.0(如0.5):输出更确定化,偏向高频词
    • >1.0(如1.2):增加非常用词概率,输出更天马行空
    • 类比:视作文本"辣度",低值保守,高值奔放
  • --top_k :整数(如5/200)限制候选标记范围
    • 仅考虑词表中前k个最可能标记
    • =1时总是选择最高概率标记,输出确定性高但易重复
    • =200时在前200个候选中选择,平衡创造性与连贯性
    • 类比:如同为作家提供精选的200个候选词

生成实例

1. 从莎士比亚微调模型生成

首先按第四章训练字符级莎士比亚模型:

bash 复制代码
python data/shakespeare_char/prepare.py
python train.py config/train_shakespeare_char.py

训练完成后(或设置always_save_checkpoint=True实时保存),在out-shakespeare-char目录生成ckpt.pt文件。生成示例:

bash 复制代码
python sample.py --out_dir=out-shakespeare-char --start="ROMEO:" --num_samples=1 --max_new_tokens=100 --temperature=0.8 --top_k=50

示例输出

复制代码
ROMEO:
I am the death of all the land.
What, art thou come? and what will be the day?
I will not be gone.

这是小型字符级模型的输出,虽不完美但遵循了"ROMEO:"模式。注意通过out_dir指定模型路径。

2. 从预训练GPT-2 XL模型生成

使用大型预训练GPT-2模型(需先按第五章配置):

bash 复制代码
python sample.py \
    --init_from=gpt2-xl \
    --start="生命、宇宙及万物的终极答案是什么?" \
    --num_samples=1 \
    --max_new_tokens=100 \
    --temperature=0.7 \
    --top_k=20

示例输出

复制代码
生命、宇宙及万物的终极答案是什么?
根据《银河系漫游指南》的记载,答案是42。

但问题在于,人们问错了问题。真正的问题应该是"哪个问题的答案是42?"。这个原问题比表面看起来复杂得多,已成为哲学界长期辩论的主题。

大型模型配合BPE分词(见第一章)能生成更高质量的文本,甚至引经据典

核心实现:generate方法

sample.py的核心是调用GPT类中的generate方法,该方法实现逐标记生成逻辑

生成流程

假设输入提示为"The cat sat":

  1. 初始化 :将提示转为标记ID序列(如[10,20,30]),获取max_new_tokens等参数
  2. 循环生成 :开始最多50次的循环(对应max_new_tokens
  3. 上下文截断 :若当前序列超模型block_size,仅保留最近部分
  4. 模型预测 :将当前序列输入模型,获取所有可能下个标记的原始分数(logits)
  5. 温度调节 :用temperature调整分数分布,高值使分布更平缓
  6. Top-K过滤 :若设置top_k,仅保留前k个高分标记
  7. 概率计算 :通过softmax将分数转为概率
  8. 标记采样:依概率随机选取一个标记
  9. 序列扩展:将新标记追加到当前序列
  10. 循环继续 :用扩展后的序列重复预测
  11. 返回结果:达到目标长度后返回完整标记序列

代码(model.py)

model.pygenerate方法的核心逻辑(使用torch.no_grad()避免计算梯度):

python 复制代码
# 摘自model.py(简化版)
@torch.no_grad()
def generate(self, idx, max_new_tokens, temperature=1.0, top_k=None):
    for _ in range(max_new_tokens):
        # 1. 上下文截断
        idx_cond = idx if idx.size(1) <= self.config.block_size else idx[:, -self.config.block_size:]

        # 2. 获取预测分数
        logits, _ = self(idx_cond)
        logits = logits[:, -1, :] # 仅取最后位置的预测

        # 3. 温度调节
        logits = logits / temperature

        # 4. Top-K过滤
        if top_k is not None:
            v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
            logits[logits < v[:, [-1]]] = -float('Inf')
        
        # 5. 概率计算与采样
        probs = F.softmax(logits, dim=-1)
        idx_next = torch.multinomial(probs, num_samples=1)

        # 6. 序列扩展
        idx = torch.cat((idx, idx_next), dim=1)

    return idx
  • self(idx_cond):调用模型前向传播(见第二章),获取每个位置的预测分数
  • logits[:, -1, :]:仅保留序列末位的预测(即下一个标记)
  • torch.topk:找出前k个高分标记
  • torch.multinomial:依概率分布随机采样
  • torch.cat:将新标记追加到序列

最终sample.py通过解码器(见第一章)将标记ID序列转为可读文本

小结

我们探索了文本生成与采样的精彩世界,学会了如何:

  • 使用sample.py脚本控制生成过程
  • 通过temperaturetop_k调节文本创造性
  • 从微调模型或预训练大模型生成文本
  • 理解generate方法的自回归采样机制

现在我们的模型已能创作文本,下一站是让训练与生成过程更加高效

下一章:性能与效率工具

相关推荐
用户5191495848454 分钟前
C#扩展成员全面解析:从方法到属性的演进
人工智能·aigc
柳鲲鹏6 分钟前
OpenCV: 光流法python代码
人工智能·python·opencv
金融小师妹28 分钟前
基于LSTM-GARCH模型:三轮黄金周期特征提取与多因子定价机制解构
人工智能·深度学习·1024程序员节
小蜜蜂爱编程30 分钟前
深度学习实践 - 使用卷积神经网络的手写数字识别
人工智能·深度学习·cnn
leiming634 分钟前
深度学习日记2025.11.20
人工智能·深度学习
速易达网络44 分钟前
tensorflow+yolo图片训练和图片识别系统
人工智能·python·tensorflow
智元视界1 小时前
从算法到城市智能:AI在马来西亚智慧城市建设中的系统应用
人工智能·科技·智慧城市·数字化转型·产业升级
Tezign_space1 小时前
技术方案|构建品牌KOS内容中台:三种架构模式与AI赋能实践
人工智能·架构·数字化转型·小红书·kos·内容营销·内容科技
嵌入式-老费1 小时前
自己动手写深度学习框架(pytorch训练第一个网络)
人工智能·pytorch·深度学习
小刘摸鱼中1 小时前
高频电子电路-振荡器的频率稳定度
网络·人工智能