【附Python源码】使用minGPT训练自己的小型GPT语言模型

【附Python源码】使用minGPT训练自己的小型GPT语言模型

近年来,大语言模型的发展令人瞩目,但对于许多开发者而言,这些动辄数十亿参数的模型更像是一个黑盒。理解其内部机制、掌握从头构建语言模型的能力,对于深入理解 Transformer 架构具有重要意义。

本项目基于 Andrej Karpathy 的 minGPT 项目,实现从零开始训练一个字符级的中文语言模型。选择字符级而非词级别,一方面可以简化词汇表构建的复杂度,另一方面也能更直观地观察模型对字符序列的学习过程。

本项目将Andrej Karpathy 的 minGPT统一组织成Jupyter notebook源码,方便调试和学习!

源码地址:github.com/anjuxi/mini...

项目概述

本项目采用《西游记》作为训练语料,构建一个能够生成古典中文风格的 GPT 模型。模型参数量控制在数百万级别,在个人 GPU 上即可完成训练。整个实现包含数据预处理、模型定义、训练流程和文本生成四个核心模块。

数据预处理

字符级编码方案

与基于子词(Subword)的分词方案不同,字符级模型将每个字符视为独立的 token。这种方案的优势在于:

  1. 词汇表规模可控,通常为数千个字符
  2. 无需处理 OOV(Out-of-Vocabulary)问题
  3. 模型可以学习到字符级别的组合规律
python 复制代码
class CharDataset(Dataset):# 字符级数据集类。
    def __init__(self, data, block_size=128):
        self.block_size = block_size;
​
        # 获取所有唯一字符并排序。
        chars = sorted(list(set(data)));
        data_size, vocab_size = len(data), len(chars);
​
        print(f"数据集统计:");
        print(f"    总字符数: {data_size} !");
        print(f"    唯一字符数: {vocab_size} !");
​
        # 创建字符到索引、索引到字符的映射。
        self.stoi = {ch: i for i, ch in enumerate(chars)};
        self.itos = {i: ch for i, ch in enumerate(chars)};
        self.vocab_size = vocab_size;
        self.data = data;
​
    def get_vocab_size(self):
        return self.vocab_size;
​
    def get_block_size(self):
        return self.block_size;
​
    def __len__(self):
        return len(self.data) - self.block_size;
​
    def __getitem__(self, idx):
        # 截取(block_size + 1)长度的字符块。
        chunk = self.data[idx:idx + self.block_size + 1];
        # 将字符转换为索引。
        dix = [self.stoi[s] for s in chunk];
        # x是输入序列,y是目标序列(x向右偏移一位)。
        x = torch.tensor(dix[:-1], dtype=torch.long);
        y = torch.tensor(dix[1:], dtype=torch.long);
        return x, y;

__getitem__ 方法的核心在于构造自监督学习的训练样本:给定前 N 个字符,预测第 N+1 个字符。这种滑动窗口的采样方式,使得数据集的每个位置都能产生一个训练样本。

模型架构

1. 激活函数:GELU

Transformer 的前馈网络通常采用 GELU(Gaussian Error Linear Unit)作为激活函数。与 ReLU 相比,GELU 在负数区域具有平滑的梯度,有助于深层网络的训练稳定。

python 复制代码
class NewGELU(nn.Module):
    """
    GELU激活函数的实现,用于Transformer中的前馈网络。
    相比ReLU,GELU在负数区域有更平滑的梯度。
    """
    def forward(self, x):
        
        temp = 0.5 * x * (1.0 + torch.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * torch.pow(x, 3.0))));
        
        return temp;

该实现采用了 tanh 近似形式,计算效率优于原始的高斯积分定义。

2. 因果自注意力机制

自注意力机制是 Transformer 的核心。因果(Causal)约束确保模型在预测当前位置时,只能依赖已生成的历史信息,这是自回归语言模型的基本要求。

python 复制代码
from ast import Is
​
​
class CausalSelfAttention(nn.Module):
    """
    多头因果自注意力层,是Transformer的核心组件。
    "因果"意味着每个位置只能关注到它之前的位置,确保自回归生成。
    """
    def __init__(self, config):
        super().__init__();
        assert config.n_embd % config.n_head == 0;
​
        # 线性投影,同时生成Q、K、V。
        self.c_attn = nn.Linear(config.n_embd, 3*config.n_embd);
        # 输出投影。
        self.c_proj = nn.Linear(config.n_embd, config.n_embd);
        # Dropout正则化。
        self.attn_dropout = nn.Dropout(config.dropout);
        self.resid_dropout = nn.Dropout(config.dropout);
​
        # 注册因果掩码,防止关注到未来的token。
        self.register_buffer("bias", torch.tril(torch.ones(config.block_size, config.block_size))
                                     .view(1, 1, config.block_size, config.block_size));
        self.n_head = config.n_head;
        self.n_embd = config.n_embd;
​
    def forward(self, x):
​
        B, T, C = x.size(); # B句子数量、T句子长度、C句子特征维度。
    
        # 计算Q、K、V。
        q, k, v = self.c_attn(x).split(self.n_embd, dim=2);
​
        # 调整维度为多头形式。
        k = k.view(B, T, self.n_head, C//self.n_head).transpose(1, 2);
        q = q.view(B, T, self.n_head, C//self.n_head).transpose(1, 2);
        v = v.view(B, T, self.n_head, C//self.n_head).transpose(1, 2);
​
        # 计算注意力分数。
        # k转置:
        kt = k.transpose(-2, -1);
​
        # Q*K^T:
        att = q@kt;
​
        # 缩放:
        att = att*(1.0/math.sqrt(k.size(-1)));
​
        att = att.masked_fill(self.bias[:,:,:T,:T] == 0, float('-inf'));
​
        att = F.softmax(att, dim=-1);
​
        att = self.attn_dropout(att);
​
        # 加权求和并重组。
        y = att @ v;
​
        y = y.transpose(1, 2).contiguous().view(B, T, C);
        
        y = self.resid_dropout(self.c_proj(y));
        
        return y;

注意力计算遵循标准的缩放点积公式:
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> Attention ( Q , K , V ) = softmax ( Q K T d k ) V \text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V </math>Attention(Q,K,V)=softmax(dk QKT)V

其中缩放因子
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> d k \sqrt{d_k} </math>dk

防止点积结果过大导致 softmax 梯度消失。

3. Transformer 块

标准的 Transformer 块由注意力子层和前馈子层组成,采用 Pre-LN(Pre-Layer Normalization)架构,即在子层输入前进行归一化。

python 复制代码
class Block(nn.Module):
    """
    标准的Transformer块,包含注意力层和前馈网络。
    使用预归一化(Pre-LN)架构。
    """
    def __init__(self, config):
        super().__init__();
        self.ln_1 = nn.LayerNorm(config.n_embd);
        self.attn = CausalSelfAttention(config);
        self.ln_2 = nn.LayerNorm(config.n_embd);
        self.mlp = nn.ModuleDict(dict(
            c_fc=nn.Linear(config.n_embd, 4*config.n_embd),
            act=NewGELU(),
            c_proj=nn.Linear(4*config.n_embd, config.n_embd),
            dropout=nn.Dropout(config.dropout),
        ));
        m = self.mlp;
        self.mlpf = lambda x: m.dropout(m.c_proj(m.act(m.c_fc(x))));
​
    def forward(self, x):
        temp_1 = x;
​
        x = self.ln_1(x);
        
        x = self.attn(x);
        
        x = temp_1 + x;# 残差连接。temp_1是原来的x。
        # 第二个子层:
        temp_2 = x;
        
        x = self.ln_2(x);
        
        x = self.mlpf(x);
        
        x = temp_2 + x;
        
        return x;

前馈网络的中间维度通常设为输入维度的 4 倍,这是原始 Transformer 论文中的标准配置。

4. GPT 模型主体

python 复制代码
class GPT(nn.Module):
    """
    GPT语言模型的完整实现。
    包含词嵌入、位置嵌入、Transformer块堆叠和语言模型头。
    """
​
    def __init__(self, config):
        super().__init__();
        assert config.vocab_size is not None;
        assert config.block_size is not None;
        self.block_size = config.block_size;
​
        # Transformer主体。
        self.transformer = nn.ModuleDict(dict(
            wte=nn.Embedding(config.vocab_size, config.n_embd),
            wpe=nn.Embedding(config.block_size, config.n_embd),
            drop=nn.Dropout(config.dropout),
            h=nn.ModuleList([Block(config) for _ in range(config.n_layer)]),
            ln_f=nn.LayerNorm(config.n_embd),
        ));
        # 语言模型头,将隐藏状态映射到词汇表。
        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False);
​
        # 初始化权重。
        self.apply(self._init_weights);
        for pn, p in self.named_parameters():
            if pn.endswith('c_proj.weight'):
                torch.nn.init.normal_(p, mean=0.0, std=0.02 / math.sqrt(2 * config.n_layer));
​
        # 报告参数数量。
        n_params = sum(p.numel() for p in self.transformer.parameters());
        print(f"模型参数数量: {n_params / 1e6:.2f}M");
​
    def _init_weights(self, module):
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02);
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias);
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02);
        elif isinstance(module, nn.LayerNorm):
            torch.nn.init.zeros_(module.bias);
            torch.nn.init.ones_(module.weight);
​
    def configure_optimizers(self):
        """
        配置优化器,对偏置和归一化参数进行特殊处理。
        """
        decay = set();
        no_decay = set();
        whitelist_weight_modules = (torch.nn.Linear,);
        blacklist_weight_modules = (torch.nn.LayerNorm, torch.nn.Embedding);
​
        for mn, m in self.named_modules():
            for pn, p in m.named_parameters():
                fpn = f"{mn}.{pn}" if mn else pn;
​
                if pn.endswith('bias'):
                    no_decay.add(fpn);
                elif pn.endswith('weight') and isinstance(m, whitelist_weight_modules):
                    decay.add(fpn);
                elif pn.endswith('weight') and isinstance(m, blacklist_weight_modules):
                    no_decay.add(fpn);
​
        decay.remove('lm_head.weight');
​
        param_dict = {pn: p for pn, p in self.named_parameters()};
        inter_params = decay & no_decay;
        union_params = decay | no_decay;
​
        optim_groups = [
            {"params": [param_dict[pn] for pn in sorted(list(decay))], "weight_decay": 0.1},
            {"params": [param_dict[pn] for pn in sorted(list(no_decay))], "weight_decay": 0.0},
        ];
​
        optimizer = torch.optim.AdamW(optim_groups, lr=LEARNING_RATE, betas=(0.9, 0.95));
        return optimizer;
​
    def generate(self, idx, max_new_tokens, temperature=1.0, do_sample=False, top_k=None):
        """
        自回归生成文本。
        idx: 初始上下文,形状为(B, T)。
        max_new_tokens: 要生成的新token数量。
        temperature: 控制采样的随机性。
        do_sample: 是否采样,False则使用贪婪解码。
        top_k: 只从概率最高的k个token中采样。
        """
        for _ in range(max_new_tokens):
            # 截取到block_size。
            idx_cond = idx if idx.size(1) <= self.block_size else idx[:, -self.block_size:];
            logits, _ = self(idx_cond);
            logits = logits[:, -1, :] / temperature;
​
            # Top-k过滤。
            if top_k is not None:
                v, _ = torch.topk(logits, min(top_k, logits.size(-1)));
                logits[logits < v[:, [-1]]] = -float('Inf');
​
            # 计算概率。
            probs = F.softmax(logits, dim=-1);
​
            # 采样或贪婪选择。
            if do_sample:
                idx_next = torch.multinomial(probs, num_samples=1);
            else:
                _, idx_next = torch.topk(probs, k=1, dim=-1);
​
            idx = torch.cat((idx, idx_next), dim=1);
​
        return idx;
        
    def forward(self, idx, targets=None):
        torch.set_printoptions(sci_mode=False);
​
        device = idx.device;
        b, t = idx.size();
        assert t <= self.block_size, f"序列长度{t}超过最大长度{self.block_size}";
​
        # 词嵌入 + 位置嵌入。
        pos = torch.arange(0, t, dtype=torch.long, device=device).unsqueeze(0);
​
        tok_emb = self.transformer.wte(idx);
        tok_emb = self.transformer.wte(idx);
        
        pos_emb = self.transformer.wpe(pos);
​
        x = tok_emb + pos_emb;
        
        x = self.transformer.drop(x);
​
        # 通过Transformer块。
        for block in self.transformer.h:
            x = block(x);
        x = self.transformer.ln_f(x);
​
        # 计算语言模型损失。
        logits = self.lm_head(x);
        loss = None;
        if targets is not None:
            loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1);
​
        return logits, loss;

模型包含词嵌入(wte)和位置嵌入(wpe)两层可学习参数。位置嵌入采用绝对位置编码,最大序列长度由 block_size 限定。

5. 文本生成

生成过程采用自回归方式,每次预测下一个字符,并将其拼接到输入序列中继续预测。

python 复制代码
    def generate(self, idx, max_new_tokens, temperature=1.0, do_sample=False, top_k=None):
        """
        自回归生成文本。
        idx: 初始上下文,形状为(B, T)。
        max_new_tokens: 要生成的新token数量。
        temperature: 控制采样的随机性。
        do_sample: 是否采样,False则使用贪婪解码。
        top_k: 只从概率最高的k个token中采样。
        """
        for _ in range(max_new_tokens):
            # 截取到block_size。
            idx_cond = idx if idx.size(1) <= self.block_size else idx[:, -self.block_size:];
            logits, _ = self(idx_cond);
            logits = logits[:, -1, :] / temperature;
​
            # Top-k过滤。
            if top_k is not None:
                v, _ = torch.topk(logits, min(top_k, logits.size(-1)));
                logits[logits < v[:, [-1]]] = -float('Inf');
​
            # 计算概率。
            probs = F.softmax(logits, dim=-1);
​
            # 采样或贪婪选择。
            if do_sample:
                idx_next = torch.multinomial(probs, num_samples=1);
            else:
                _, idx_next = torch.topk(probs, k=1, dim=-1);
​
            idx = torch.cat((idx, idx_next), dim=1);
​
        return idx;

Temperature 参数控制采样的随机性:值越小,分布越尖锐,生成结果越确定;值越大,分布越平缓,生成结果越多样。Top-k 采样则限制了候选 token 的范围,避免选择概率极低的结果。

训练流程

优化器配置

采用 AdamW 优化器,并对参数进行分组处理:对权重矩阵应用权重衰减(weight decay),对偏置和归一化参数则不应用。

python 复制代码
    def configure_optimizers(self):
        """
        配置优化器,对偏置和归一化参数进行特殊处理。
        """
        decay = set();
        no_decay = set();
        whitelist_weight_modules = (torch.nn.Linear,);
        blacklist_weight_modules = (torch.nn.LayerNorm, torch.nn.Embedding);
​
        for mn, m in self.named_modules():
            for pn, p in m.named_parameters():
                fpn = f"{mn}.{pn}" if mn else pn;
​
                if pn.endswith('bias'):
                    no_decay.add(fpn);
                elif pn.endswith('weight') and isinstance(m, whitelist_weight_modules):
                    decay.add(fpn);
                elif pn.endswith('weight') and isinstance(m, blacklist_weight_modules):
                    no_decay.add(fpn);
​
        decay.remove('lm_head.weight');
​
        param_dict = {pn: p for pn, p in self.named_parameters()};
        inter_params = decay & no_decay;
        union_params = decay | no_decay;
​
        optim_groups = [
            {"params": [param_dict[pn] for pn in sorted(list(decay))], "weight_decay": 0.1},
            {"params": [param_dict[pn] for pn in sorted(list(no_decay))], "weight_decay": 0.0},
        ];
​
        optimizer = torch.optim.AdamW(optim_groups, lr=LEARNING_RATE, betas=(0.9, 0.95));
        return optimizer;

训练循环

python 复制代码
class Trainer:
    def __init__(self, config, model, train_dataset):
        self.config = config;
        self.model = model;
        self.train_dataset = train_dataset;
        self.callbacks = defaultdict(list);
        self.device = device;
​
        self.iter_num = 0;
        self.iter_time = 0.0;
        self.iter_dt = 0.0;
​
    def add_callback(self, onevent, callback):
        self.callbacks[onevent].append(callback);
​
    def set_callback(self, onevent, callback):
        self.callbacks[onevent] = [callback];
​
    def trigger_callbacks(self, onevent):
        for callback in self.callbacks.get(onevent, []):
            callback(self);
    def run(self):
        model, config = self.model, self.config;
​
        # 设置优化器。
        optimizer = model.configure_optimizers();
​
        # 设置数据加载器。
        train_loader = DataLoader(
            self.train_dataset,
            sampler=torch.utils.data.RandomSampler(
                self.train_dataset, replacement=True, num_samples=int(1e10)
            ),
            shuffle=False,
            pin_memory=True,
            batch_size=config.batch_size,
            num_workers=0,
        );
​
        model.train();
        self.iter_num = 0;
        self.iter_time = time.time();
        data_iter = iter(train_loader);
​
        # 记录损失历史。
        self.loss_history = [];
​
        while True:
            try:
                batch = next(data_iter);
            except StopIteration:
                data_iter = iter(train_loader);
                batch = next(data_iter);
​
            batch = [t.to(self.device) for t in batch];
            x, y = batch;
​
            # 前向传播。
            logits, loss = model(x, y);
​
            # 反向传播和优化。
            model.zero_grad(set_to_none=True);
            loss.backward();
            torch.nn.utils.clip_grad_norm_(model.parameters(), config.grad_norm_clip);
            optimizer.step();
​
            self.loss = loss;
            self.trigger_callbacks('on_batch_end');
            self.iter_num += 1;
​
            tnow = time.time();
            self.iter_dt = tnow - self.iter_time;
            self.iter_time = tnow;
​
            if self.iter_num % EVAL_INTERVAL == 0:
                print(f"step {self.iter_num}: loss = {loss.item():.6f}, time = {self.iter_dt:.3f}s")
​
​
            # 记录损失。
            self.loss_history.append(loss.item());
​
            # 终止条件。
            if MAX_ITERS is not None and self.iter_num >= MAX_ITERS:
                break;
​
        return self.loss_history;

训练过程中采用梯度裁剪(gradient clipping),将梯度范数限制在阈值以内,这对训练深层 Transformer 尤为重要。

实验配置与结果

本项目采用的超参数如下:

参数 说明
n_layer 2 Transformer 层数
n_head 2 注意力头数
n_embd 512 嵌入维度
block_size 128 上下文长度
batch_size 64 批量大小
learning_rate 15e-4 初始学习率
dropout 0.1 Dropout 比率

模型总参数量约为 8.52M,在单张消费级 GPU 上训练 1000 轮约需数分钟。训练完成后,模型能够生成具有一定连贯性的古典中文段落。以"孙"字为提示,模型可生成:

python 复制代码
生成结果: 孙行者道:"你是个'金箍儿'?"
   那呆子听得说,即忙纵筋斗云,直至殿上,把三藏与八戒、沙僧,径至前,
只见那门里走出一个大小妖,一齐下,就是一个大小妖,把那些妖精,一个个
个个个
喧哗,一齐上前,

虽然语义完整性和长程一致性仍有提升空间,但模型已展现出对字符组合和句式结构的学习能力。

总结

通过本项目,可以深入理解 GPT 模型的核心机制:

  1. 字符级建模简化了词汇表构建,但增加了序列长度
  2. 因果自注意力通过掩码机制实现自回归生成
  3. 残差连接和层归一化保证了深层网络的训练稳定
  4. 温度采样和 Top-k 策略平衡了生成的多样性和质量

对于希望深入理解 Transformer 架构的开发者,建议在此基础上尝试:增大模型规模、引入学习率衰减策略、或采用 Byte Pair Encoding(BPE)等子词分词方案。

源码地址:github.com/anjuxi/mini...

相关推荐
QuZero1 小时前
StampedLock Mechanism
java·算法
云泽8081 小时前
二叉树高阶笔试算法题精讲(二):非递归遍历与序列构造全解析
c++·算法·面试
小O的算法实验室2 小时前
2026年ESWA,基于固定机巢的无人机输电杆塔、变电站与配电杆混合巡检任务分配与路径规划,深度解析+性能实测
算法·论文复现·智能算法·智能算法改进
sali-tec4 小时前
C# 基于OpenCv的视觉工作流-章60-点点距离
图像处理·人工智能·opencv·算法·计算机视觉
nlpming5 小时前
OpenCode Skills 文档
算法
无限进步_5 小时前
二叉搜索树完全解析:从概念到实现与应用场景
c语言·开发语言·数据结构·c++·算法·github·visual studio
05候补工程师5 小时前
深度解构 ROS 2:如何手动调通 Nav2 A* 路径规划引擎
linux·人工智能·经验分享·算法·机器人
上弦月-编程5 小时前
【C语言逻辑题】谋杀案凶手是谁?——经典矛盾推理题详解
算法
天若有情6736 小时前
逆向玩家狂喜!用C++野生写法一键破解线性加密(不规范但巨好用)
开发语言·c++·算法