[nanoGPT] 文本生成 | 自回归采样 | `generate`方法

第六章:文本生成与采样

欢迎回来

第五章:检查点与预训练模型加载中,我们学会了如何保存模型习得的知识,甚至加载强大的预训练GPT-2模型。现在,经过所有准备和训练,终于迎来激动人心的环节:让我们的模型创作文本

想象你培养了一位才华横溢的年轻作家。你已教授其语法、词汇甚至特定文风(如莎士比亚风格)。现在,你给出开头句子,要求其创造性续写。这正是文本生成与采样的核心。

本章重点在于使用训练好的GPT模型,基于给定提示生成新颖连贯的文本。

其流程包括:输入初始文本→预测最可能的下个标记→重复追加标记直至达到目标长度。

用户可通过temperature(控制随机性)和top_k(限制候选标记范围)等参数调节生成文本的创造性。

文本生成的核心任务

核心功能是:使用训练好的GPT模型,基于初始提示生成类人文本

例如输入:"在遥远的银河系..."

模型应据此续写合理的故事段落。

核心机制:自回归采样

GPT模型通过自回归采样 逐标记生成文本,如同谨慎的作家:写一个词→思考下一个词→再写→再思考,循环往复。

基本流程:

  1. 输入提示:提供初始标记序列(开头文本)
  2. 预测下个标记 :模型根据当前序列预测下一个最可能标记
  3. 扩展序列:将预测标记追加到序列
  4. 循环执行:用新序列重复预测,直至达到目标长度

控制台:sample.py脚本

nanoGPT中负责文本生成的主脚本是sample.py。该脚本加载训练好的模型(或预训练模型)并生成文本,同时提供类似第三章:配置系统中的"调节旋钮"来控制文本风格与创造性。

典型调用方式:

bash 复制代码
python sample.py --start="你好,我叫" --num_samples=1 --max_new_tokens=50

该命令指示脚本以"你好,我叫"开头,生成1段50个标记的文本。

创造性控制参数

sample.py提供多个关键参数调节生成过程:

  • --start :初始提示文本。可以是简单字符串或文件路径(如FILE:prompt.txt
  • --max_new_tokens:控制生成的新标记数量。若提示含5个标记且设为50,最终输出为55个标记
  • --temperature :浮点数(如0.8/1.0/1.2)控制输出随机性
    • =1.0:按模型原始概率采样
    • <1.0(如0.5):输出更确定化,偏向高频词
    • >1.0(如1.2):增加非常用词概率,输出更天马行空
    • 类比:视作文本"辣度",低值保守,高值奔放
  • --top_k :整数(如5/200)限制候选标记范围
    • 仅考虑词表中前k个最可能标记
    • =1时总是选择最高概率标记,输出确定性高但易重复
    • =200时在前200个候选中选择,平衡创造性与连贯性
    • 类比:如同为作家提供精选的200个候选词

生成实例

1. 从莎士比亚微调模型生成

首先按第四章训练字符级莎士比亚模型:

bash 复制代码
python data/shakespeare_char/prepare.py
python train.py config/train_shakespeare_char.py

训练完成后(或设置always_save_checkpoint=True实时保存),在out-shakespeare-char目录生成ckpt.pt文件。生成示例:

bash 复制代码
python sample.py --out_dir=out-shakespeare-char --start="ROMEO:" --num_samples=1 --max_new_tokens=100 --temperature=0.8 --top_k=50

示例输出

复制代码
ROMEO:
I am the death of all the land.
What, art thou come? and what will be the day?
I will not be gone.

这是小型字符级模型的输出,虽不完美但遵循了"ROMEO:"模式。注意通过out_dir指定模型路径。

2. 从预训练GPT-2 XL模型生成

使用大型预训练GPT-2模型(需先按第五章配置):

bash 复制代码
python sample.py \
    --init_from=gpt2-xl \
    --start="生命、宇宙及万物的终极答案是什么?" \
    --num_samples=1 \
    --max_new_tokens=100 \
    --temperature=0.7 \
    --top_k=20

示例输出

复制代码
生命、宇宙及万物的终极答案是什么?
根据《银河系漫游指南》的记载,答案是42。

但问题在于,人们问错了问题。真正的问题应该是"哪个问题的答案是42?"。这个原问题比表面看起来复杂得多,已成为哲学界长期辩论的主题。

大型模型配合BPE分词(见第一章)能生成更高质量的文本,甚至引经据典

核心实现:generate方法

sample.py的核心是调用GPT类中的generate方法,该方法实现逐标记生成逻辑

生成流程

假设输入提示为"The cat sat":

  1. 初始化 :将提示转为标记ID序列(如[10,20,30]),获取max_new_tokens等参数
  2. 循环生成 :开始最多50次的循环(对应max_new_tokens
  3. 上下文截断 :若当前序列超模型block_size,仅保留最近部分
  4. 模型预测 :将当前序列输入模型,获取所有可能下个标记的原始分数(logits)
  5. 温度调节 :用temperature调整分数分布,高值使分布更平缓
  6. Top-K过滤 :若设置top_k,仅保留前k个高分标记
  7. 概率计算 :通过softmax将分数转为概率
  8. 标记采样:依概率随机选取一个标记
  9. 序列扩展:将新标记追加到当前序列
  10. 循环继续 :用扩展后的序列重复预测
  11. 返回结果:达到目标长度后返回完整标记序列

代码(model.py)

model.pygenerate方法的核心逻辑(使用torch.no_grad()避免计算梯度):

python 复制代码
# 摘自model.py(简化版)
@torch.no_grad()
def generate(self, idx, max_new_tokens, temperature=1.0, top_k=None):
    for _ in range(max_new_tokens):
        # 1. 上下文截断
        idx_cond = idx if idx.size(1) <= self.config.block_size else idx[:, -self.config.block_size:]

        # 2. 获取预测分数
        logits, _ = self(idx_cond)
        logits = logits[:, -1, :] # 仅取最后位置的预测

        # 3. 温度调节
        logits = logits / temperature

        # 4. Top-K过滤
        if top_k is not None:
            v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
            logits[logits < v[:, [-1]]] = -float('Inf')
        
        # 5. 概率计算与采样
        probs = F.softmax(logits, dim=-1)
        idx_next = torch.multinomial(probs, num_samples=1)

        # 6. 序列扩展
        idx = torch.cat((idx, idx_next), dim=1)

    return idx
  • self(idx_cond):调用模型前向传播(见第二章),获取每个位置的预测分数
  • logits[:, -1, :]:仅保留序列末位的预测(即下一个标记)
  • torch.topk:找出前k个高分标记
  • torch.multinomial:依概率分布随机采样
  • torch.cat:将新标记追加到序列

最终sample.py通过解码器(见第一章)将标记ID序列转为可读文本

小结

我们探索了文本生成与采样的精彩世界,学会了如何:

  • 使用sample.py脚本控制生成过程
  • 通过temperaturetop_k调节文本创造性
  • 从微调模型或预训练大模型生成文本
  • 理解generate方法的自回归采样机制

现在我们的模型已能创作文本,下一站是让训练与生成过程更加高效

下一章:性能与效率工具

相关推荐
NAGNIP6 小时前
一文搞懂深度学习中的通用逼近定理!
人工智能·算法·面试
冬奇Lab7 小时前
一天一个开源项目(第36篇):EverMemOS - 跨 LLM 与平台的长时记忆 OS,让 Agent 会记忆更会推理
人工智能·开源·资讯
冬奇Lab7 小时前
OpenClaw 源码深度解析(一):Gateway——为什么需要一个"中枢"
人工智能·开源·源码阅读
AngelPP11 小时前
OpenClaw 架构深度解析:如何把 AI 助手搬到你的个人设备上
人工智能
宅小年11 小时前
Claude Code 换成了Kimi K2.5后,我再也回不去了
人工智能·ai编程·claude
九狼11 小时前
Flutter URL Scheme 跨平台跳转
人工智能·flutter·github
ZFSS11 小时前
Kimi Chat Completion API 申请及使用
前端·人工智能
天翼云开发者社区12 小时前
春节复工福利就位!天翼云息壤2500万Tokens免费送,全品类大模型一键畅玩!
人工智能·算力服务·息壤
知识浅谈12 小时前
教你如何用 Gemini 将课本图片一键转为精美 PPT
人工智能
Ray Liang13 小时前
被低估的量化版模型,小身材也能干大事
人工智能·ai·ai助手·mindx