【人工智障生成日记1】从零开始训练本地小语言模型


🎯 从零开始训练本地小语言模型:MiniGPT + TinyStories(4090Ti)

🧭 项目背景

本项目旨在以学习为目的,从头构建一个完整的本地语言模型训练管线。目标是:

  • ✅ 不依赖外部云计算
  • ✅ 完全本地运行(RTX 4090Ti)
  • ✅ 从零构建数据加载、模型结构、训练与推理逻辑
  • ✅ 阶段性掌握 LLM 微调与部署的关键技能

🛠️ 开发环境

项目 配置
操作系统 Windows 10
GPU NVIDIA RTX 4090Ti
CUDA 驱动 版本 12.1(cu121
Python 版本 3.10
虚拟环境 .venv310(指定 Python 3.10)

📦 项目结构

text 复制代码
toy-transformer/
├── data_loader.py        # 加载 TinyStories 数据集
├── model.py              # MiniGPT 模型实现
├── train_resume.py       # 支持断点训练的主循环
├── generate.py           # 推理与生成函数
├── checkpoint_latest.pth # 自动保存的训练权重
├── .venv310/             # 虚拟环境

🧠 技术路线

1. 数据加载

  • 使用 HuggingFace datasets 加载 TinyStories
  • Tokenizer 使用 GPT-2 默认分词器
  • 启用 paddingtruncation,统一 max_length=128

2. 模型构建

  • 自定义实现 MiniGPT

    • 小型 Transformer(Embedding + 多层 Self-Attention + Linear head)
    • 使用 GPT-2 的 vocab
    • 无 pretraining,全从零学起

3. 模型训练

  • 使用 torch.nn.CrossEntropyLoss,忽略 pad_token_id
  • 优化器为 AdamW
  • 使用 PyTorch AMP (torch.amp.autocast) 启用混合精度
  • 使用 GradScaler 动态控制精度
  • 使用 tqdm 进度条显示训练状态
  • 支持自动保存断点(checkpoint_latest.pth

4. 推理逻辑

  • 自定义 generate_text() 实现逐 token 自回归生成
  • 使用 softmax + sampling,多轮测试发现 collapse 问题

✅ 成果展示

训练状态

  • 成功训练了 10000 steps,loss 降至 0.05 以下
  • 支持断点恢复训练,训练速度约 25 it/s

推理结果(初步)

Prompt: "Once upon a time"

Output: time time time time time time...

Prompt: "Deng Chao is"

Output: is is is is is is is is is...

🚨 模型已出现 token collapse 问题(复读)------为后续结构调整和防过拟合提供重要依据。


🚧 存在问题与后续优化

问题 原因 计划
文本输出高度重复 模型 collapse,过拟合高频 token 使用 dropout,top-k sampling,减少训练步数
模型容量不足 结构太小,表达力差 增加 embedding 和层数;或微调现成模型
学不到语义 没有预训练语料带来的 inductive bias 尝试 distilgpt2 + TinyStories 微调

📅 今日工作小结

  • ✅ 成功建立 GPU 环境并启用 CUDA(Python 3.10 + PyTorch cu121)
  • ✅ 完整实现数据加载、tokenizer、模型结构与训练管线
  • ✅ 训练了 10000 steps,保存了断点模型
  • ✅ 实现推理接口并测试多个 prompt
  • ✅ 发现模型陷入 token collapse,为后续优化积累经验

🚀 明日计划

  • 替换 softmax 采样为 top-k + temperature 采样,缓解 collapse
  • 加入 dropout / layernorm 防止模型坍缩
  • 尝试使用 distilgpt2 微调,比较从零训练 vs 微调效果
  • 训练可视化(tensorboard / wandb)

🧨 本地语言模型训练踩坑记录(2025.05)


🧱 环境搭建相关

❌ 坑 1:安装了 PyTorch,但无法使用 CUDA

  • 表现torch.cuda.is_available() 返回 False,训练跑在 CPU 上,GPU 利用率为 0%

  • 原因:初始环境为 Python 3.13,而 PyTorch GPU 版本不支持该版本

  • 解决方案

    1. 安装 Python 3.10(py -3.10

    2. 使用 python3.10 -m venv .venv310 创建虚拟环境

    3. 使用官方源安装支持 CUDA 的 PyTorch:

      bash 复制代码
      pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu121

🧪 数据处理与训练相关

❌ 坑 2:模型在 GPU 上训练速度不升反降

  • 表现:CUDA 激活后训练反而更慢

  • 原因 :虽然模型 to(device),但输入数据没有显式 .to(device)

  • 解决方案

    • 使用:

      python 复制代码
      input_ids = batch["input_ids"].to(device)

      而不是:

      python 复制代码
      batch["input_ids"].to(device)  # ⚠ 无效!

❌ 坑 3:tqdm 报错 IProgress not found

  • 表现 :使用 from tqdm.notebook import tqdm 报错

  • 解决方案

    • 快速替换为:

      python 复制代码
      from tqdm import tqdm
    • 或安装依赖:

      bash 复制代码
      pip install ipywidgets
      jupyter nbextension enable --py widgetsnbextension

📦 模型训练相关

❌ 坑 4:训练 loss 降不下来 or 降到 0.0000 太快

  • 表现:训练 1 个 epoch 后 loss ≈ 0,后续 epoch 训练跳过

  • 原因 :训练步数被 step_count >= max_steps 提前终止,epoch 实际未执行

  • 解决方案

    • 使用 total_step 替代 step_count 并每轮累加
    • 或改为基于 max_epochs 控制训练轮数

❌ 坑 5:训练后模型生成"词语复读机"(collapse)

  • 表现 :生成 output 全是 "time time time...""is is is..."

  • 原因

    • 模型太小,表达能力差
    • 学习率太大或步数太多导致过拟合高频 token
  • 解决方案

    • 启用 dropout 正则
    • 使用 top-k + temperature 控制采样策略
    • 更换为 distilgpt2 微调方案或扩大学习语料

🔐 安全性提示

❌ 坑 6:PyTorch 警告 torch.load() 存在安全隐患

  • 表现 :加载 checkpoint 时出现 FutureWarning: weights_only=False

  • 解决方案(建议但非必须):

    • 明确添加参数:

      python 复制代码
      torch.load(checkpoint_path, weights_only=True)

相关推荐
摆烂仙君1 小时前
LoRA(Low-Rank Adaptation)
人工智能·计算机视觉
杰瑞学AI2 小时前
深度学习中的分布偏移问题及其解决方法
人工智能·深度学习·机器学习·ai
摩尔线程2 小时前
推测解码算法在 MTT GPU 的应用实践
算法·语言模型·大模型·gpu算力·gpu·摩尔线程
学算法的程霖2 小时前
分享|16个含源码和数据集的计算机视觉实战项目
人工智能·pytorch·深度学习·机器学习·计算机视觉·目标跟踪·研究生
带电的小王2 小时前
【动手学深度学习】2.3. 线性代数
人工智能·深度学习·线性代数
Listennnn2 小时前
点云(point cloud):自动驾驶的“三维扫描图“
人工智能·机器学习·自动驾驶
土拨鼠不是老鼠2 小时前
windows 下用yolov5 训练模型 给到opencv 使用
人工智能·opencv·yolo
小橘子就是小橘子2 小时前
9大开源AI智能体概况
人工智能·开源·ai agent
moonsims2 小时前
无人机桥梁检测如何通过数据存储、边缘AI、无线通讯等技术路线,提升检测效率
人工智能
moonsims3 小时前
无人机桥梁巡检
人工智能