第45篇:文本生成实战:使用GPT-2创作故事——体验AI的“创造力”(项目实战)

文章目录

项目背景

在之前的项目中,我们处理的多是分类、预测等"理解型"任务。这次,我想带大家玩点不一样的------让AI"创造"一个故事。文本生成,尤其是开放式故事创作,是检验语言模型"智能"程度的绝佳试金石。我记得第一次用GPT-2生成文本时,那种看着连贯、甚至富有想象力的句子从模型里"流"出来的感觉,非常震撼。它不再是简单地复述,而是在"编造"。本项目我们就用Hugging Face transformers库,基于预训练的GPT-2模型,从零搭建一个故事生成器,亲身体验AI的"创造力"边界。

技术选型

为什么选择GPT-2,而不是更新、更大的GPT-3或GPT-4?这基于几个务实的考虑:

  1. 资源友好:GPT-2(特别是小号和小小号版本)可以在消费级GPU甚至CPU上运行,而GPT-3/4的API调用有成本,且大模型本地部署门槛极高。
  2. 开源可控:GPT-2完全开源,我们可以深入模型内部,调整生成策略,进行微调,学习整个流程。这对于理解和掌握文本生成技术至关重要。
  3. 效果足够:对于故事生成这个场景,GPT-2(特别是774M参数版本)的能力已经能产生令人惊喜的结果,足以让我们体验核心乐趣和技术要点。

因此,我们的技术栈非常清晰:

  • 核心模型 :Hugging Face transformers库中的 gpt2 (或 gpt2-medium)
  • 深度学习框架:PyTorch
  • 辅助工具torch, transformers, tqdm (用于进度条)

架构设计

这个项目的架构非常简单直接,是一个典型的"预训练模型+生成策略"流水线。我们不涉及复杂的服务部署,聚焦于生成逻辑本身。

复制代码
用户输入(故事开头/提示词)
        ↓
[文本预处理与Tokenization]
        ↓
[加载预训练GPT-2模型与分词器]
        ↓
[核心文本生成循环]
        ├── 策略选择:贪婪搜索、集束搜索、Top-k采样、Top-p采样
        └── 生成控制:最大长度、重复惩罚、温度参数
        ↓
[Token解码为文本]
        ↓
输出生成的故事段落

核心在于生成循环解码策略。不同的策略会极大影响生成故事的"创造性"、"连贯性"和"可读性"。

核心实现

让我们一步步用代码实现这个生成器。首先,确保环境已安装必要库:pip install transformers torch tqdm

1. 准备模型与分词器

python 复制代码
from transformers import GPT2LMHeadModel, GPT2Tokenizer
import torch

# 加载预训练模型和分词器
# 使用 `gpt2` (124M参数) 以在CPU或内存较小的GPU上快速运行
# 如果想效果更好,可以尝试 `gpt2-medium` (774M参数)
model_name = 'gpt2'
tokenizer = GPT2Tokenizer.from_pretrained(model_name)
model = GPT2LMHeadModel.from_pretrained(model_name)

# 将模型设置为评估模式(关闭dropout等训练层)
model.eval()

# 如果可用,使用GPU加速
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

踩坑提示1 :GPT-2的分词器默认不会添加padding token。如果后续需要批量生成,需要手动设置:tokenizer.pad_token = tokenizer.eos_token。我们本次是单条生成,可暂不处理。

2. 构建文本生成函数

这是项目的核心。我们将实现一个支持多种解码策略的生成函数。

python 复制代码
def generate_story(prompt, max_length=150, num_return_sequences=1, strategy='top_p', temperature=1.0, top_k=50, top_p=0.95):
    """
    使用GPT-2生成故事续写。
    
    参数:
        prompt: 故事开头的提示文本。
        max_length: 生成文本的最大总长度(包括提示)。
        num_return_sequences: 生成几个不同的故事版本。
        strategy: 解码策略,可选 'greedy', 'beam', 'top_k', 'top_p'。
        temperature: 温度参数,越高越随机,越低越确定。
        top_k: Top-k采样中的k值。
        top_p: Top-p(核)采样中的p值。
    """
    # 将提示文本编码为模型输入的token ID
    input_ids = tokenizer.encode(prompt, return_tensors='pt').to(device)
    
    # 根据策略配置生成参数
    generation_config = {
        'max_length': max_length,
        'num_return_sequences': num_return_sequences,
        'pad_token_id': tokenizer.eos_token_id, # 设置结束符也为padding符
        'do_sample': True, # 默认启用采样,对于贪婪和集束搜索会覆盖
        'temperature': temperature,
    }
    
    if strategy == 'greedy':
        # 贪婪搜索:每一步都选择概率最高的词
        generation_config.update({'do_sample': False, 'num_beams': 1})
    elif strategy == 'beam':
        # 集束搜索:每一步保留多个最有可能的序列
        generation_config.update({'do_sample': False, 'num_beams': 5, 'early_stopping': True})
    elif strategy == 'top_k':
        # Top-k采样:每一步从概率最高的k个词中随机选一个
        generation_config.update({'top_k': top_k})
    elif strategy == 'top_p':
        # Top-p(核)采样:每一步从累积概率超过p的最小词集合中随机选一个
        generation_config.update({'top_p': top_p, 'top_k': 0}) # top_k=0表示禁用top-k
    
    # 使用模型生成文本
    with torch.no_grad(): # 禁用梯度计算,加快推理速度,减少内存占用
        output_sequences = model.generate(
            input_ids=input_ids,
            **generation_config
        )
    
    # 解码生成的token ID为可读文本
    generated_stories = []
    for generated_sequence in output_sequences:
        # 跳过输入提示部分,只解码新生成的部分
        generated_sequence = generated_sequence[len(input_ids[0]):]
        text = tokenizer.decode(generated_sequence, skip_special_tokens=True, clean_up_tokenization_spaces=True)
        generated_stories.append(text)
    
    return generated_stories

3. 体验不同的生成策略

现在,让我们用一个简单的提示词,对比不同策略的效果。

python 复制代码
prompt = "在一个遥远的未来,机器人学会了做梦。"
print(f"提示词: {prompt}\n")

strategies = ['greedy', 'beam', 'top_k', 'top_p']
for s in strategies:
    print(f"--- 使用策略: {s.upper()} ---")
    stories = generate_story(prompt, max_length=100, strategy=s, temperature=0.8 if s in ['top_k', 'top_p'] else 1.0)
    print(stories[0][:200] + "...") # 打印前200个字符
    print()

运行这段代码,你会直观地看到:

  • 贪婪搜索(Greedy):生成的故事通常最连贯、最安全,但也最容易陷入重复循环(比如不停地重复"机器人做梦,做梦,做梦......"),缺乏新意。
  • 集束搜索(Beam Search):连贯性比贪婪搜索更好,能一定程度避免重复,但生成的故事可能过于"平庸"或模板化。
  • Top-k / Top-p 采样 :这是让故事变得"有趣"和"有创意"的关键。通过引入随机性,生成的情节往往更出人意料。Top-p(核采样)通常比Top-k更灵活和有效,因为它动态调整候选词集合的大小。

踩坑记录

在实际操作中,我遇到了几个典型问题,这里分享给大家:

  1. 生成结果重复或退化:这是文本生成的经典难题。表现为模型开始不断重复同一句话或词语。

    • 解决方案 :除了使用Top-p采样,还可以在generate函数中设置repetition_penalty参数(大于1.0,如1.2),对已出现过的token进行概率惩罚。或者使用no_repeat_ngram_size参数禁止特定长度的短语重复出现。
  2. 生成内容无关或跑题:模型可能会从"机器人做梦"突然跳到谈论"今天的天气"。

    • 解决方案 :这通常与提示词不够具体有关。尝试给出更详细、更具约束性的开头。例如:"在一个遥远的未来,一个负责清理城市的旧型号机器人,第一次在待机时体验到了类似人类做梦的数据流。它梦见了:"。此外,适当降低temperature值(如从1.0降到0.7)可以让生成内容更聚焦。
  3. 生成速度慢 :尤其是在使用beam search或生成长文本时。

    • 解决方案 :对于交互式应用,可以考虑使用更小的模型(distilgpt2)。在生成时,使用max_new_tokens参数替代max_length,精确控制新生成长度,避免不必要的计算。如果使用支持CUDA的GPU,确保模型和输入数据都已.to(device)
  4. 奇怪的分词和空格:生成文本中可能出现奇怪的符号或多余空格。

    • 解决方案 :确保在tokenizer.decode()时设置了skip_special_tokens=Trueclean_up_tokenization_spaces=True。对于中文或其他语言,可能需要使用专门的分词器(如gpt2-chinese)。

效果对比与项目扩展

经过多次尝试,我发现在故事生成任务上,"Top-p采样(p值0.9-0.95)配合适当的温度(0.7-0.9)" 是平衡创造性、连贯性和趣味性的最佳组合。贪婪和集束搜索更适合需要高确定性的任务,如代码补全或翻译。

项目扩展思路

  1. 微调模型:找一些科幻小说、童话故事的文本数据,对GPT-2进行微调,让它生成特定风格的故事。
  2. 构建Web应用:使用Gradio或Streamlit快速搭建一个交互界面,让用户输入提示词,选择风格,实时生成故事。
  3. 多轮对话式生成:模拟一个"AI说书人",用户输入"然后呢?"来推动故事发展,这需要维护一个不断增长的对话历史上下文。
  4. 加入条件控制:使用CTRL或Prompt Tuning等技术,控制故事的情感(悲伤/欢乐)、流派(科幻/武侠)等属性。

通过这个实战项目,我们不仅运行了一个GPT-2模型,更重要的是,我们亲手调试了那些控制AI"创造力"的旋钮------温度、Top-p、重复惩罚等。你会发现,所谓的AI创造力,目前很大程度上是"可控的随机性"。如何设置这些参数,让生成的故事既天马行空又不至于胡言乱语,正是这门技术的艺术所在。

动手试试吧,给你的AI一个开头,看看它会还你一个怎样的世界。

如有问题欢迎评论区交流,持续更新中...

相关推荐
程序员老邢1 小时前
【产品底稿 08】商助慧 AI 仿写实战复盘:RAG 知识库 + 大模型联动,一键生成技术底稿
人工智能·spring boot·后端·ai·语言模型·milvus
IT_陈寒1 小时前
JavaScript的闭包差点让我加班到凌晨
前端·人工智能·后端
AI服务老曹1 小时前
打破设备割裂:基于 GB28181 与 RTSP 的边缘计算 AI 视频平台架构解析(附源码交付与 Docker 部署)
人工智能·音视频·边缘计算
老王谈企服1 小时前
流程型制造业生产优化,未来将如何被大模型技术重构?2026智造深研:实在Agent驱动端到端生产闭环
大数据·网络·人工智能·ai·重构
老赵聊算法、大模型备案1 小时前
从剪映、即梦 AI 被罚,读懂 AI 生成内容标识硬性合规要求
人工智能·算法·安全·aigc
传说故事1 小时前
【论文阅读】通过homeostasis RL学习合成综合机器人行为
论文阅读·人工智能·机器人·具身智能
zhangfeng11331 小时前
LLaMA-Factory 保存 checkpoint 时崩溃解决办法 OOM 内存溢出(不是显存)
运维·服务器·人工智能·深度学习·llama
小程故事多_801 小时前
DeepSeek-V4技术报告全解读 从架构到Infra的全栈重构之路
人工智能·重构·架构·智能体
数智工坊1 小时前
【VarifocalNet(VFNet)论文阅读】:IoU-aware稠密目标检测,把定位质量塞进分类得分
论文阅读·人工智能·深度学习·目标检测·计算机视觉·分类·cnn