MinMind
MiniMind 是一个面向学习和研究的极简大语言模型项目,适合初学者快速入门 LLM 的训练和部署。其轻量化设计和低成本实现是显著优势,但在模型能力、数据规模和多语言支持上存在一定局限性,更适合用作学习和实验。
MiniMind 优点
-
极简设计,低成本复现:
- 从零开始实现大语言模型,代码简洁明了,适合学习和研究。
- 仅需 2 小时 + 3 元成本即可训练出基础模型,门槛极低。
-
轻量化模型:
- 模型体积小,最低仅 25.8M,适合个人 GPU 设备快速训练和部署。
- 支持 Dense 和 MoE(混合专家)模型,提供多种参数规模选择。
-
全流程开源:
- 包含从预训练、监督微调(SFT)、LoRA 微调、RLHF(DPO)、模型蒸馏等全过程代码。
- 数据集清洗、分词器训练、推理服务等全链路开源,覆盖 LLM 的完整生命周期。
-
学习友好:
- 不依赖高度封装的第三方库,完全基于 PyTorch 原生实现,便于理解底层逻辑。
- 提供详细的训练步骤和教程,适合作为 LLM 入门学习的参考项目。
-
多样化支持:
- 支持单机单卡、单机多卡(DDP、DeepSpeed)训练,兼容主流框架(如 transformers、trl、peft)。
- 提供 OpenAI API 兼容的服务端,便于集成到第三方应用。
-
丰富的实验与数据:
- 提供多种高质量数据集(预训练、SFT、RLHF、蒸馏等),并开源数据清洗流程。
- 详细的实验记录和性能对比,便于复现和优化。
-
社区友好:
- 鼓励开源社区参与,提供详细的贡献者列表和鸣谢。
- 支持多种模型格式(PyTorch 原生、Transformers),便于使用和扩展。
MiniMind 缺点
-
模型能力有限:
- 由于模型参数较小(最低 25.8M),在复杂任务上的表现有限,难以与大规模模型(如 GPT-3、Llama)媲美。
- 推理能力和生成质量受限,尤其在长文本生成和复杂推理任务上表现不足。
-
数据规模不足:
- 预训练数据集规模较小(约 1.6GB),可能导致模型知识覆盖面有限。
- 缺乏针对性优化,部分任务(如推理、逻辑性)表现不佳。
-
幻觉问题:
- 在生成任务中,模型可能出现幻觉(生成不准确或不真实的内容),尤其在知识密集型问题上。
-
缺乏多语言支持:
- 主要针对中文优化,英文能力较弱,难以满足多语言场景需求。
-
性能瓶颈:
- 尽管支持 MoE 模型,但在大规模分布式训练和推理效率上仍有提升空间。
- 在主流基准测试(如 C-Eval、CMMLU)中表现一般,难以与更大规模模型竞争。
-
社区生态较小:
- 相较于 Hugging Face 等成熟生态,MiniMind 的社区规模和资源相对有限。
MinMind 安装
源码
bash
git clone https://github.com/jingyaogong/minimind.git
环境准备
ruby
pip install -r requirements.txt -i https://mirrors.aliyun.com/pypi/simple/
下载模型
bash
git clone https://hf-mirror.com/jingyaogong/MiniMind2
命令行问答
lua
# load=0: load from pytorch model, load=1: load from transformers-hf model
python eval_model.py --load 1 --model_mode 2
命令行提供了两种方式选择
自动测试
手动输入
MiniMind 命令行问答执行逻辑/eval_model.py
%% Mermaid 流程图
%% 美化样式:使用主题和节点样式
flowchart TD
A[启动程序] --> B[解析命令行参数]
B --> C[初始化模型]
C -->|args.load == 0| D[加载原生 Torch 权重]
C -->|args.load == 1| E[加载 Transformers 模型]
D --> F[加载 MiniMindLM 配置]
F --> G[加载权重文件]
G -->|args.lora_name != 'None'| H[应用 LoRA]
E --> I[加载预训练模型和分词器]
H --> J[返回模型和分词器]
I --> J[返回模型和分词器]
J --> K[获取 Prompt 数据]
K -->|args.model_mode == 0| L[加载预训练 Prompt 数据]
K -->|args.model_mode != 0| M[加载对话 Prompt 数据]
M -->|args.lora_name != 'None'| N[加载特定领域 Prompt 数据]
L --> O[选择测试模式]
M --> O
N --> O
O -->|自动测试| P[循环生成回答]
O -->|手动输入| Q[等待用户输入]
P --> R[生成新 Prompt]
Q --> R
R --> S[模型生成回答]
S --> T[输出回答]
T -->|继续测试| O
T -->|结束| U[退出程序]
%% 样式定义
style A fill:#f9f,stroke:#333,stroke-width:2px
style B fill:#bbf,stroke:#333,stroke-width:2px
style C fill:#bbf,stroke:#333,stroke-width:2px
style D fill:#bfb,stroke:#333,stroke-width:2px
style E fill:#bfb,stroke:#333,stroke-width:2px
style F fill:#ff9,stroke:#333,stroke-width:2px
style G fill:#ff9,stroke:#333,stroke-width:2px
style H fill:#ff9,stroke:#333,stroke-width:2px
style I fill:#ff9,stroke:#333,stroke-width:2px
style J fill:#bbf,stroke:#333,stroke-width:2px
style K fill:#bbf,stroke:#333,stroke-width:2px
style L fill:#bfb,stroke:#333,stroke-width:2px
style M fill:#bfb,stroke:#333,stroke-width:2px
style N fill:#bfb,stroke:#333,stroke-width:2px
style O fill:#bbf,stroke:#333,stroke-width:2px
style P fill:#ff9,stroke:#333,stroke-width:2px
style Q fill:#ff9,stroke:#333,stroke-width:2px
style R fill:#ff9,stroke:#333,stroke-width:2px
style S fill:#ff9,stroke:#333,stroke-width:2px
style T fill:#bbf,stroke:#333,stroke-width:2px
style U fill:#f99,stroke:#333,stroke-width:2px
开始训练
数据下载
从下文提供的数据集下载链接 下载需要的数据文件放到./dataset目录下
注:数据集须知 默认推荐下载pretrain_hq.jsonl + sft_mini_512.jsonl最快速度复现Zero聊天模型。
数据文件可自由选择,下文提供了多种搭配方案,可根据自己手头的训练需求和GPU资源进行适当组合。
预训练(学知识)/train_pretrain.py
python train_pretrain.py
%% Mermaid 流程图
%% 美化样式:使用主题和节点样式
flowchart TD
A[启动程序] --> B[解析命令行参数]
B --> C[初始化配置和环境]
C -->|DDP 模式| D[初始化分布式训练]
C -->|非 DDP 模式| E[跳过分布式初始化]
D --> F[初始化模型和分词器]
E --> F
F --> G[加载预训练数据集]
G --> H[初始化优化器和梯度缩放器]
H --> I[开始训练循环]
I --> J[训练单个 epoch]
J --> K[更新学习率]
K --> L[计算损失并反向传播]
L --> M[梯度裁剪和优化器更新]
M -->|保存间隔| N[保存模型检查点]
M -->|日志间隔| O[记录日志]
N --> P[继续训练下一步]
O --> P
P -->|完成所有 epoch| Q[结束训练]
%% 样式定义
style A fill:#f9f,stroke:#333,stroke-width:2px
style B fill:#bbf,stroke:#333,stroke-width:2px
style C fill:#bfb,stroke:#333,stroke-width:2px
style D fill:#ff9,stroke:#333,stroke-width:2px
style E fill:#ff9,stroke:#333,stroke-width:2px
style F fill:#bbf,stroke:#333,stroke-width:2px
style G fill:#bfb,stroke:#333,stroke-width:2px
style H fill:#ff9,stroke:#333,stroke-width:2px
style I fill:#bbf,stroke:#333,stroke-width:2px
style J fill:#bfb,stroke:#333,stroke-width:2px
style K fill:#ff9,stroke:#333,stroke-width:2px
style L fill:#ff9,stroke:#333,stroke-width:2px
style M fill:#ff9,stroke:#333,stroke-width:2px
style N fill:#f99,stroke:#333,stroke-width:2px
style O fill:#f99,stroke:#333,stroke-width:2px
style P fill:#bbf,stroke:#333,stroke-width:2px
style Q fill:#f9f,stroke:#333,stroke-width:2px
监督微调(学对话方式)/train_full_sft.py
python train_full_sft.py
%% Mermaid 样式美化
graph TD
A[程序开始] --> B[解析命令行参数]
B --> C[初始化配置]
C --> D[检查是否为DDP模式]
D -->|是| E[初始化分布式模式]
D -->|否| F[设置设备为单GPU或CPU]
E --> F
F --> G[初始化WandB日志记录]
G --> H[加载模型和分词器]
H --> I[加载训练数据集]
I --> J[初始化优化器和梯度缩放器]
J --> K[开始训练循环]
subgraph Training Loop
direction TB
K --> L[按Epoch循环]
L --> M[按Step循环]
M --> N[计算学习率]
N --> O[前向传播]
O --> P[计算损失]
P --> Q[反向传播]
Q --> R[梯度裁剪和优化器更新]
R --> S[记录日志]
S --> T[保存模型检查点]
T --> M
end
K --> U[程序结束]
%% 样式定义
classDef startEnd fill:#f9f,stroke:#333,stroke-width:2px;
classDef process fill:#bbf,stroke:#333,stroke-width:2px;
classDef decision fill:#f96,stroke:#333,stroke-width:2px;
class A,U startEnd;
class B,C,D,E,F,G,H,I,J,K,L,M,N,O,P,Q,R,S,T process;
class D decision;
transformers
transformers
提供了从模型加载、数据处理到训练和推理的全流程工具,涵盖了 NLP 和多模态任务的核心需求。通过模块化设计,用户可以快速加载预训练模型并应用于各种任务。
%% Mermaid 流程图
%% 美化样式:使用主题和节点样式
flowchart TD
A[初始化 transformers 模块] --> B[检查依赖项]
B -->|依赖项可用| C[定义 _import_structure]
B -->|依赖项不可用| D[加载 dummy 对象]
C --> E[导入基础工具类]
E --> F[定义模型相关模块]
F --> G[定义管道相关模块]
G --> H[定义优化器和训练器]
H --> I[完成模块初始化]
%% 关键解释
subgraph 依赖检查
B1[检查 is_torch_available] --> B2[检查 is_tf_available]
B2 --> B3[检查 is_flax_available]
B3 --> B4[检查其他依赖项]
end
subgraph 模块定义
F1[定义模型子模块] --> F2[定义模型配置]
F2 --> F3[定义模型预训练类]
G1[定义管道工具] --> G2[定义管道类型]
H1[定义优化器工具] --> H2[定义训练器工具]
end
%% 样式定义
style A fill:#f9f,stroke:#333,stroke-width:2px
style B fill:#bbf,stroke:#333,stroke-width:2px
style C fill:#bfb,stroke:#333,stroke-width:2px
style D fill:#f99,stroke:#333,stroke-width:2px
style E fill:#ff9,stroke:#333,stroke-width:2px
style F fill:#bbf,stroke:#333,stroke-width:2px
style G fill:#bfb,stroke:#333,stroke-width:2px
style H fill:#ff9,stroke:#333,stroke-width:2px
style I fill:#bbf,stroke:#333,stroke-width:2px
style B1 fill:#bbf,stroke:#333,stroke-width:1px
style B2 fill:#bbf,stroke:#333,stroke-width:1px
style B3 fill:#bbf,stroke:#333,stroke-width:1px
style B4 fill:#bbf,stroke:#333,stroke-width:1px
style F1 fill:#bfb,stroke:#333,stroke-width:1px
style F2 fill:#bfb,stroke:#333,stroke-width:1px
style F3 fill:#bfb,stroke:#333,stroke-width:1px
style G1 fill:#ff9,stroke:#333,stroke-width:1px
style G2 fill:#ff9,stroke:#333,stroke-width:1px
style H1 fill:#ff9,stroke:#333,stroke-width:1px
style H2 fill:#ff9,stroke:#333,stroke-width:1px
以下是 transformers
的类结构及其作用的简要描述,按照模块分类:
1. 核心模块
- PretrainedConfig
用于存储模型的配置参数,支持从预训练模型加载配置。 - PreTrainedModel
所有模型的基类,提供加载、保存和推理的通用方法。 - AutoModel 系列
自动化加载模型的工具类,根据模型名称自动加载对应的模型(如 AutoModelForCausalLM、AutoModelForSequenceClassification 等)。
2. 模型模块
- BertModel
BERT 模型的实现,支持文本分类、问答等任务。 - GPT2Model
GPT-2 模型的实现,主要用于文本生成任务。 - T5Model
T5 模型的实现,支持序列到序列任务(如翻译、摘要生成)。 - RobertaModel
RoBERTa 模型的实现,BERT 的改进版本,适用于分类和问答任务。 - DistilBertModel
DistilBERT 模型的实现,BERT 的轻量级版本,适用于推理速度要求较高的场景。 - WhisperModel
Whisper 模型的实现,专注于语音到文本的转换任务。
3. 数据处理模块
- AutoTokenizer
自动化加载分词器的工具类,根据模型名称加载对应的分词器。 - PreTrainedTokenizer
分词器的基类,提供文本标记化、解码等功能。 - DataCollator 系列
数据整理工具,用于将数据批处理为模型输入格式(如 DataCollatorForLanguageModeling、DataCollatorWithPadding)。
4. 管道模块
- Pipeline
高层次的任务接口,支持快速执行任务(如文本生成、分类、翻译等)。 - TextGenerationPipeline
用于文本生成任务的管道。 - TranslationPipeline
用于翻译任务的管道。 - ImageClassificationPipeline
用于图像分类任务的管道。
5. 优化与训练模块
- Trainer
通用训练器,支持模型训练、评估和保存。 - TrainingArguments
用于配置训练参数(如学习率、批量大小等)。 - AdamW
优化器,适用于 Transformer 模型,支持权重衰减。 - get_linear_schedule_with_warmup
学习率调度器,支持线性预热。
6. 特殊功能模块
- GenerationMixin
提供生成任务的通用方法(如 beam search、top-k sampling)。 - SequenceFeatureExtractor
序列特征提取工具,用于处理输入数据。 - FeatureExtractionMixin
特征提取的基类,支持从模型中提取中间层特征。
7. 依赖检查模块
- is_torch_available
检查是否安装了 PyTorch。 - is_tf_available
检查是否安装了 TensorFlow。 - is_flax_available
检查是否安装了 Flax。