LM实现教程:基于 nanochat项目 从零开始理解大语言模型

LLM实现教程:从零开始理解大语言模型

基于 nanochat 仓库的深入解析

引言

nanochat 是一个完整的LLM实现,成本约100美元即可训练的ChatGPT类模型。本教程将带你深入理解:

  • LLM的基本原理
  • Transformer架构的核心实现
  • 完整的训练流程(预训练、微调、强化学习)
  • 各种优化技术和分布式训练

为何选择nanochat?

  1. 代码精简:约8K行代码,易于理解
  2. 端到端:从分词到部署的完整流程
  3. 约100美元:成本低
  4. 可配置性高:代码结构清晰,易于修改

LLM基础知识

什么是LLM?

大语言模型(Large Language Model, LLM)是一类基于Transformer架构的深度学习模型,通过学习大量文本数据来理解和生成自然语言。

核心概念

1. Token与Tokenization(分词)

为什么需要分词?

神经网络无法直接处理文本,需要将文本转换为数字序列。

BPE(Byte Pair Encoding) 是主流方法:

1:30:nanochat/tokenizer.py 复制代码
"""
BPE Tokenizer in the style of GPT-4.

Two implementations are available:
1) HuggingFace Tokenizer that can do both training and inference but is really confusing
2) Our own RustBPE Tokenizer for training and tiktoken for efficient inference
"""

import os
import copy
from functools import lru_cache

SPECIAL_TOKENS = [
    # every document begins with the Beginning of Sequence (BOS) token that delimits documents
    "<|bos|>",
    # tokens below are only used during finetuning to render Conversations into token ids
    "<|user_start|>", # user messages
    "<|user_end|>",
    "<|assistant_start|>", # assistant messages
    "<|assistant_end|>",
    "<|python_start|>", # assistant invokes python REPL tool
    "<|python_end|>",
    "<|output_start|>", # python REPL outputs back to assistant
    "<|output_end|>",
]

工作原理

  1. 从单个字符开始
  2. 统计最常出现的相邻符号对
  3. 合并成新符号
  4. 重复上述过程直到达到目标词汇表大小

nanochat的实现

155:182:nanochat/tokenizer.py 复制代码
class RustBPETokenizer:
    """Light wrapper around tiktoken (for efficient inference) but train with rustbpe"""

    def __init__(self, enc, bos_token):
        self.enc = enc
        self.bos_token_id = self.encode_special(bos_token)

    @classmethod
    def train_from_iterator(cls, text_iterator, vocab_size):
        # 1) train using rustbpe
        tokenizer = rustbpe.Tokenizer()
        # the special tokens are inserted later in __init__, we don't train them here
        vocab_size_no_special = vocab_size - len(SPECIAL_TOKENS)
        assert vocab_size_no_special >= 256, f"vocab_size_no_special must be at least 256, got {vocab_size_no_special}"
        tokenizer.train_from_iterator(text_iterator, vocab_size_no_special, pattern=SPLIT_PATTERN)
        # 2) construct the associated tiktoken encoding for inference
        pattern = tokenizer.get_pattern()
        mergeable_ranks_list = tokenizer.get_mergeable_ranks()
        mergeable_ranks = {bytes(k): v for k, v in mergeable_ranks_list}
        tokens_offset = len(mergeable_ranks)
        special_tokens = {name: tokens_offset + i for i, name in enumerate(SPECIAL_TOKENS)}
        enc = tiktoken.Encoding(
            name="rustbpe",
            pat_str=pattern,
            mergeable_ranks=mergeable_ranks, # dict[bytes, int] (token bytes -> merge priority rank)
            special_tokens=special_tokens, # dict[str, int] (special token name -> token id)
        )
        return cls(enc, "<|bos|>")
2. Transformer架构

核心组件

  1. Embeddings(嵌入层):将token id转换为向量
  2. Attention(注意力机制):让模型关注相关的上下文
  3. Feed-Forward(前馈网络):非线性变换
  4. Layer Norm(层归一化):稳定训练
138:156:nanochat/gpt.py 复制代码
class GPT(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.transformer = nn.ModuleDict({
            "wte": nn.Embedding(config.vocab_size, config.n_embd),
            "h": nn.ModuleList([Block(config, layer_idx) for layer_idx in range(config.n_layer)]),
        })
        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
        # To support meta device initialization, we init the rotary embeddings here, but it's fake
        # As for rotary_seq_len, these rotary embeddings are pretty small/cheap in memory,
        # so let's just over-compute them, but assert fail if we ever reach that amount.
        # In the future we can dynamically grow the cache, for now it's fine.
        self.rotary_seq_len = config.sequence_len * 10 # 10X over-compute should be enough, TODO make nicer?
        head_dim = config.n_embd // config.n_head
        cos, sin = self._precompute_rotary_embeddings(self.rotary_seq_len, head_dim)
        self.register_buffer("cos", cos, persistent=False) # persistent=False means it's not saved to the checkpoint
        self.register_buffer("sin", sin, persistent=False)
3. 注意力机制(Attention)

**自注意力(Self-Attention)**让模型理解序列内部关系:

51:110:nanochat/gpt.py 复制代码
class CausalSelfAttention(nn.Module):
    def __init__(self, config, layer_idx):
        super().__init__()
        self.layer_idx = layer_idx
        self.n_head = config.n_head
        self.n_kv_head = config.n_kv_head
        self.n_embd = config.n_embd
        self.head_dim = self.n_embd // self.n_head
        assert self.n_embd % self.n_head == 0
        assert self.n_kv_head <= self.n_head and self.n_head % self.n_kv_head == 0
        self.c_q = nn.Linear(self.n_embd, self.n_head * self.head_dim, bias=False)
        self.c_k = nn.Linear(self.n_embd, self.n_kv_head * self.head_dim, bias=False)
        self.c_v = nn.Linear(self.n_embd, self.n_kv_head * self.head_dim, bias=False)
        self.c_proj = nn.Linear(self.n_embd, self.n_embd, bias=False)

    def forward(self, x, cos_sin, kv_cache):
        B, T, C = x.size()

        # Project the input to get queries, keys, and values
        q = self.c_q(x).view(B, T, self.n_head, self.head_dim)
        k = self.c_k(x).view(B, T, self.n_kv_head, self.head_dim)
        v = self.c_v(x).view(B, T, self.n_kv_head, self.head_dim)

        # Apply Rotary Embeddings to queries and keys to get relative positional encoding
        cos, sin = cos_sin
        q, k = apply_rotary_emb(q, cos, sin), apply_rotary_emb(k, cos, sin) # QK rotary embedding
        q, k = norm(q), norm(k) # QK norm
        q, k, v = q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2) # make head be batch dim, i.e. (B, T, H, D) -> (B, H, T, D)

        # Apply KV cache: insert current k,v into cache, get the full view so far
        if kv_cache is not None:
            k, v = kv_cache.insert_kv(self.layer_idx, k, v)
        Tq = q.size(2) # number of queries in this forward pass
        Tk = k.size(2) # number of keys/values in total (in the cache + current forward pass)

        # Attention: queries attend to keys/values autoregressively. A few cases to handle:
        enable_gqa = self.n_head != self.n_kv_head # Group Query Attention (GQA): duplicate key/value heads to match query heads if desired
        if kv_cache is None or Tq == Tk:
            # During training (no KV cache), attend as usual with causal attention
            # And even if there is KV cache, we can still use this simple version when Tq == Tk
            y = F.scaled_dot_product_attention(q, k, v, is_causal=True, enable_gqa=enable_gqa)
        elif Tq == 1:
            # During inference but with a single query in this forward pass:
            # The query has to attend to all the keys/values in the cache
            y = F.scaled_dot_product_attention(q, k, v, is_causal=False, enable_gqa=enable_gqa)
        else:
            # During inference AND we have a chunk of queries in this forward pass:
            # First, each query attends to all the cached keys/values (i.e. full prefix)
            attn_mask = torch.zeros((Tq, Tk), dtype=torch.bool, device=q.device) # True = keep, False = mask
            prefix_len = Tk - Tq
            if prefix_len > 0: # can't be negative but could be zero
                attn_mask[:, :prefix_len] = True
            # Then, causal attention within this chunk
            attn_mask[:, prefix_len:] = torch.tril(torch.ones((Tq, Tq), dtype=torch.bool, device=q.device))
            y = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask, enable_gqa=enable_gqa)

        # Re-assemble the heads side by side and project back to residual stream
        y = y.transpose(1, 2).contiguous().view(B, T, -1)
        y = self.c_proj(y)
        return y

关键点

  • Q(Query):查询向量,寻找信息
  • K(Key):键向量,提供信息位置
  • V(Value):值向量,实际信息内容
  • 因果掩码(Causal Mask):确保只能看到当前位置之前的token
4. 位置编码

Rotary Position Embedding (RoPE)

41:49:nanochat/gpt.py 复制代码
def apply_rotary_emb(x, cos, sin):
    assert x.ndim == 4  # multihead attention
    d = x.shape[3] // 2
    x1, x2 = x[..., :d], x[..., d:] # split up last time into two halves
    y1 = x1 * cos + x2 * sin # rotate pairs of dims
    y2 = x1 * (-sin) + x2 * cos
    out = torch.cat([y1, y2], 3) # re-assemble
    out = out.to(x.dtype) # ensure input/output dtypes match
    return out

RoPE通过旋转向量的方式编码位置信息,比传统的绝对位置嵌入更优雅。


nanochat架构概览

整体流程

复制代码
1. 数据下载 → 2. 训练分词器 → 3. 预训练 → 4. 中训练 → 5. SFT微调 → 6. 评估/部署

目录结构

复制代码
nanochat/
├── nanochat/              # 核心代码
│   ├── gpt.py            # GPT模型定义
│   ├── engine.py         # 推理引擎(KV Cache)
│   ├── tokenizer.py      # BPE分词器
│   ├── dataloader.py     # 数据加载
│   ├── adamw.py          # AdamW优化器
│   ├── muon.py           # Muon优化器
│   └── execution.py      # Python代码执行工具
├── scripts/              # 训练脚本
│   ├── base_train.py    # 预训练
│   ├── mid_train.py     # 中训练
│   ├── chat_sft.py      # SFT训练
│   └── chat_rl.py       # RL训练
└── tasks/               # 评估任务

核心组件详解

1. GPT模型

模型配置

26:34:nanochat/gpt.py 复制代码
@dataclass
class GPTConfig:
    sequence_len: int = 1024
    vocab_size: int = 50304
    n_layer: int = 12
    n_head: int = 6 # number of query heads
    n_kv_head: int = 6 # number of key/value heads (MQA)
    n_embd: int = 768

Transformer Block

126:135:nanochat/gpt.py 复制代码
class Block(nn.Module):
    def __init__(self, config, layer_idx):
        super().__init__()
        self.attn = CausalSelfAttention(config, layer_idx)
        self.mlp = MLP(config)

    def forward(self, x, cos_sin, kv_cache):
        x = x + self.attn(norm(x), cos_sin, kv_cache)
        x = x + self.mlp(norm(x))
        return x

前馈网络

113:123:nanochat/gpt.py 复制代码
class MLP(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.c_fc = nn.Linear(config.n_embd, 4 * config.n_embd, bias=False)
        self.c_proj = nn.Linear(4 * config.n_embd, config.n_embd, bias=False)

    def forward(self, x):
        x = self.c_fc(x)
        x = F.relu(x).square()
        x = self.c_proj(x)
        return x

关键特性

  • 无偏置:所有线性层不使用偏置
  • relu^2 激活relu(x)²,较标准ReLU表现更好
  • 残差连接x = x + f(x)
  • 层归一化:使用RMSNorm
  • 嵌入与输出头不共享权重 :untied weights(wtelm_head 独立)

2. 推理引擎(KV Cache)

为什么需要KV Cache?

在推理时,之前计算的key-value可以缓存,避免重复计算:

82:150:nanochat/engine.py 复制代码
class KVCache:
    """
    Works hand-in-hand with the GPT model to maintain the KV cache.
    Note that the .pos advances automatically after the last layer of the Transformer inserts.
    """

    def __init__(self, batch_size, num_heads, seq_len, head_dim, num_layers):
        # Each of K/V is of shape (B, H, T, D) and we have one per layer of the Transformer.
        self.kv_shape = (num_layers, 2, batch_size, num_heads, seq_len, head_dim)
        self.kv_cache = None
        self.pos = 0 # current position in time in the cache

    def reset(self):
        self.pos = 0

    def get_pos(self):
        return self.pos

    def prefill(self, other):
        """
        Prefill given another KV cache. Optionally expand along batch dim.
        This is used when we do batch 1 prefill and then want to generate
        multiple samples in parallel from there.
        """
        # 1) validate the shapes
        assert self.kv_cache is None, "Cannot prefill a non-empty KV cache"
        assert other.kv_cache is not None, "Cannot prefill with a None KV cache"
        for ix, (dim1, dim2) in enumerate(zip(self.kv_shape, other.kv_shape)):
            if ix in [0, 1, 3, 5]:
                # num_layers, batch_size, num_heads, head_dim must match
                assert dim1 == dim2, f"Batch dim mismatch: {dim1} != {dim2}"
            elif ix == 2:
                # batch_size can be expanded
                assert dim1 == dim2 or dim2 == 1, f"Batch dim mismatch: {dim1} != {dim2}"
            elif ix == 4:
                # seq_len: self must be longer than other
                assert dim1 >= dim2, f"Seq len mismatch: {dim1} < {dim2}"
        # 2) initialize the cache
        dtype, device = other.kv_cache.dtype, other.kv_cache.device
        self.kv_cache = torch.empty(self.kv_shape, dtype=dtype, device=device)
        # 3) copy the data over
        self.kv_cache[:, :, :, :, :other.pos, :] = other.kv_cache
        # 4) update the pos
        self.pos = other.pos

    def insert_kv(self, layer_idx, k, v):
        # Lazy initialize the cache here because we need to know the dtype/device
        if self.kv_cache is None:
            self.kv_cache = torch.empty(self.kv_shape, dtype=k.dtype, device=k.device)
        # Insert new keys/values to the cache and return the full cache so far
        B, H, T_add, D = k.size()
        t0, t1 = self.pos, self.pos + T_add
        # Dynamically grow the cache if needed
        if t1 > self.kv_cache.size(4):
            t_needed = t1 + 1024 # as much as we need plus buffer of 1024
            t_needed = (t_needed + 1023) & ~1023 # then round up to the nearest multiple of 1024
            current_shape = list(self.kv_cache.shape)
            current_shape[4] = t_needed
            self.kv_cache.resize_(current_shape)
        # Insert k, v into the cache
        self.kv_cache[layer_idx, 0, :, :, t0:t1] = k
        self.kv_cache[layer_idx, 1, :, :, t0:t1] = v
        # Return the full cached keys/values up to current position (as a view)
        key_view = self.kv_cache[layer_idx, 0, :, :, :t1]
        value_view = self.kv_cache[layer_idx, 1, :, :, :t1]
        # Increment pos after the last layer of the Transformer processes
        if layer_idx == self.kv_cache.size(0) - 1:
            self.pos = t1
        return key_view, value_view

3. 数据加载

分布式数据加载器

9:49:nanochat/dataloader.py 复制代码
def tokenizing_distributed_data_loader(B, T, split, tokenizer_threads=4, tokenizer_batch_size=128, device="cuda"):
    """Stream pretraining text from parquet files, tokenize, yield training batches."""
    assert split in ["train", "val"], "split must be 'train' or 'val'"
    ddp, ddp_rank, ddp_local_rank, ddp_world_size = get_dist_info()
    needed_tokens = B * T + 1 # +1 is because we also need the target at the last token
    # get the tokenizer and the bos token
    tokenizer = get_tokenizer()
    bos_token = tokenizer.get_bos_token_id()
    # scratch buffer holds the tokens for one iteration
    token_buffer = deque() # we stream tokens on the right and pop from the left

    # infinite iterator over document batches
    def document_batches():
        while True:
            # batch will iterate in group size of the parquet files, usually e.g. 1024 rows
            for batch in parquets_iter_batched(split=split, start=ddp_rank, step=ddp_world_size):
                # for the tokenizer we might want to go in usually smaller batches, e.g. 128 rows
                for i in range(0, len(batch), tokenizer_batch_size):
                    yield batch[i:i+tokenizer_batch_size]
    batches = document_batches()

    batch_index = 0
    while True:
        # Accumulate enough tokens for one iteration before yielding.
        while len(token_buffer) < needed_tokens:
            doc_batch = next(batches)
            token_lists = tokenizer.encode(doc_batch, prepend=bos_token, num_threads=tokenizer_threads)
            for tokens in token_lists:
                token_buffer.extend(tokens)
            batch_index += 1
        # Move tokens from the deque into the scratch buffer
        tokens = [token_buffer.popleft() for _ in range(needed_tokens)]
        # CUDA supports memory pinning for faster transfers between CPU and GPU:
        scratch = torch.tensor(tokens, dtype=torch.int64, pin_memory=(device == "cuda"))
        # Create the inputs/targets as 1D tensors
        inputs_cpu = scratch[:-1].to(dtype=torch.int32)
        targets_cpu = scratch[1:]
        # Reshape to 2D and move to GPU async
        inputs = inputs_cpu.view(B, T).to(device=device, dtype=torch.int32, non_blocking=True)
        targets = targets_cpu.view(B, T).to(device=device, dtype=torch.int64, non_blocking=True)
        yield inputs, targets

关键点

  • 滑动窗口:将长文本切分成固定长度序列
  • 异步传输 :使用non_blocking=True加速数据传输
  • 内存固定pin_memory加速CPU到GPU传输

训练流程

阶段1:预训练(Base Training)

目标:在大规模文本上学习语言模型

核心代码

107:182:scripts/base_train.py 复制代码
# Initialize the Model
model_config_kwargs = dict(sequence_len=max_seq_len, vocab_size=vocab_size, n_layer=num_layers, n_head=num_heads, n_kv_head=num_kv_heads, n_embd=model_dim)
with torch.device("meta"):
    model_config = GPTConfig(**model_config_kwargs)
    model = GPT(model_config)
model.to_empty(device=device)
model.init_weights()
orig_model = model # original, uncompiled model, for saving raw model state_dict
model = torch.compile(model, dynamic=False) # TODO: dynamic True/False think through
num_params = sum(p.numel() for p in model.parameters())
print0(f"Number of parameters: {num_params:,}")
num_flops_per_token = model.estimate_flops()
print0(f"Estimated FLOPs per token: {num_flops_per_token:e}")

# Calculate number of iterations. Either it is given, or from target flops, or from target data:param ratio (in that order)
assert num_iterations > 0 or target_param_data_ratio > 0 or target_flops > 0
if num_iterations > 0:
    print0(f"Using user-provided number of iterations: {num_iterations:,}")
elif target_flops > 0:
    # calculate the number of iterations from the target flops
    num_iterations = round(target_flops / (num_flops_per_token * total_batch_size))
    print0(f"Calculated number of iterations from target FLOPs: {num_iterations:,}")
elif target_param_data_ratio > 0:
    # calculate the number of iterations from the target param data ratio
    target_tokens = target_param_data_ratio * num_params
    num_iterations = target_tokens // total_batch_size
    print0(f"Calculated number of iterations from target data:param ratio: {num_iterations:,}")
else:
    raise ValueError("No training horizon specified")
total_tokens = total_batch_size * num_iterations
print0(f"Total number of training tokens: {total_tokens:,}")
print0(f"Tokens : Params ratio: {total_batch_size * num_iterations / num_params:.2f}") # Chinchilla is ~20
print0(f"Total training FLOPs estimate: {num_flops_per_token * total_tokens:e}")

# Initialize the Optimizer (Muon for Linear layers, AdamW for embedding and lm_head)
optimizers = model.setup_optimizers(unembedding_lr=unembedding_lr, embedding_lr=embedding_lr, matrix_lr=matrix_lr, weight_decay=weight_decay)
adamw_optimizer, muon_optimizer = optimizers

# Initialize the DataLoaders for train/val
base_dir = get_base_dir()
tokens_dir = os.path.join(base_dir, "tokenized_data")
train_loader = tokenizing_distributed_data_loader(device_batch_size, max_seq_len, split="train", device=device)
build_val_loader = lambda: tokenizing_distributed_data_loader(device_batch_size, max_seq_len, split="val", device=device)
x, y = next(train_loader) # kick off load of the very first batch of data

关键超参数

  • Chinchilla定律:数据量 ≈ 20 × 参数量
  • 学习率:分层设置
  • 批次大小:总批次大小 = device_batch_size × world_size × grad_accum_steps

阶段2:中训练(Mid Training)

目标:引入对话格式、工具使用、选择题

混合训练数据

98:106:scripts/mid_train.py 复制代码
train_dataset = TaskMixture([
    SmolTalk(split="train"), # 460K rows of general conversations
    MMLU(subset="auxiliary_train", split="train"), # 100K rows of multiple choice problems drawn from ARC, MC_TEST, OBQA, RACE
    GSM8K(subset="main", split="train"), # 8K rows teaching simple math and (calculator) tool use
    CustomJSON(filepath=identity_conversations_filepath), # 1000 rows of synthetic identity conversations
    CustomJSON(filepath=identity_conversations_filepath), # let's do 2 epochs of these
    SimpleSpelling(size=200000, split="train"), # 200K rows of Simple Spelling (e.g. spell the word 'apple')
    SpellingBee(size=80000, split="train"), # 80K rows of Spelling Bee (e.g. how many 'r' are in 'strawberry'?)
]) # total: 460K + 100K + 8K + 200K + 80K = 848K rows

阶段3:监督微调(SFT)

目标:进一步优化对话能力和任务表现

84:92:scripts/chat_sft.py 复制代码
train_ds = TaskMixture([
    ARC(subset="ARC-Easy", split="train"), # 2.3K rows
    ARC(subset="ARC-Challenge", split="train"), # 1.1K rows
    GSM8K(subset="main", split="train"), # 8K rows
    SmolTalk(split="train", stop=10_000), # 10K rows of smoltalk
    CustomJSON(filepath=identity_conversations_filepath), # 1K rows of synthetic identity conversations
    SimpleSpelling(size=300, split="train"), # 300 rows of Simple Spelling (e.g. spell the word 'apple')
    SpellingBee(size=300, split="train"), # 300 rows of Spelling Bee (e.g. how many 'r' are in 'strawberry'?)
]) # 2.3K +  שחumbles + 8K + 10K + 1K + 0.3K + 0.3K = 23K rows

对话渲染

258:342:nanochat/tokenizer.py 复制代码
def render_conversation(self, conversation, max_tokens=2048):
    """
    Tokenize a single Chat conversation (which we call a "doc" or "document" here).
    Returns:
    - ids: list[int] is a list of token ids of this rendered conversation
    - mask: list[int] of same length, mask = 1 for tokens that the Assistant is expected to train on.
    """
    # ids, masks that we will return and a helper function to help build them up.
    ids, mask = [], []
    def add_tokens(token_ids, mask_val):
        if isinstance(token_ids, int):
            token_ids = [token_ids]
        ids.extend(token_ids)
        mask.extend([mask_val] * len(token_ids))

    # sometimes the first message is a system message...
    # => just merge it with the second (user) message
    if conversation["messages"][0]["role"] == "system":
        # some conversation surgery is necessary here for now...
        conversation = copy.deepcopy(conversation) # avoid mutating the original
        messages = conversation["messages"]
        assert messages[1]["role"] == "user", "System message must be followed by a user message"
        messages[1]["content"] = messages[0]["content"] + "\n\n" + messages[1]["content"]
        messages = messages[1:]
    else:
        messages = conversation["messages"]
    assert len(messages) >= 1, f"Conversation has less than 1 message: {messages}"

    # fetch all the special tokens we need
    bos = self.get_bos_token_id()
    user_start, user_end = self.encode_special("<|user_start|>"), self.encode_special("<|user_end|>")
    assistant_start, assistant_end = self.encode_special("<|assistant_start|>"), self.encode_special("<|assistant_end|>")
    python_start, python_end = self.encode_special("<|python_start|>"), self.encode_special("<|python_end|>")
    output_start, output_end = self.encode_special("<|output_start|>"), self.encode_special("<|output_end|>")

    # now we can tokenize the conversation
    add_tokens(bos, 0)
    for i, message in enumerate(messages):

        # some sanity checking here around assumptions, to prevent footguns
        must_be_from = "user" if i % 2 == 0 else "assistant"
        assert message["role"] == must_be_from, f"Message {i} is from {message['role']} but should be from {must_be_from}"

        # content can be either a simple string or a list of parts (e.g. containing tool calls)
        content = message["content"]

        if message["role"] == "user":
            assert isinstance(content, str), "User messages are simply expected to be strings"
            value_ids = self.encode(content)
            add_tokens(user_start, 0)
            add_tokens(value_ids, 0)
            add_tokens(user_end, 0)
        elif message["role"] == "assistant":
            add_tokens(assistant_start, 0)
            if isinstance(content, str):
                # simple string => simply add the tokens
                value_ids = self.encode(content)
                add_tokens(value_ids, 1)
            elif isinstance(content, list):
                for part in content:
                    value_ids = self.encode(part["text"])
                    if part["type"] == "text":
                        # string part => simply add the tokens
                        add_tokens(value_ids, 1)
                    elif part["type"] == "python":
                        # python tool call => add the tokens inside <|python_start|> and <|python_end|>
                        add_tokens(python_start, 1)
                        add_tokens(value_ids, 1)
                        add_tokens(python_end, 1)
                    elif part["type"] == "python_output":
                        # python output => add the tokens inside <|output_start|> and <|output_end|>
                        # none of these tokens are supervised because the tokens come from Python at test time
                        add_tokens(output_start, 0)
                        add_tokens(value_ids, 0)
                        add_tokens(output_end, 0)
                    else:
                        raise ValueError(f"Unknown part type: {part['type']}")
            else:
                raise ValueError(f"Unknown content type: {type(content)}")
            add_tokens(assistant_end, 1)

    # truncate to max_tokens tokens MAX (helps prevent OOMs)
    ids = ids[:max_tokens]
    mask = mask[:max_tokens]
    return ids, mask

关键点

  • mask机制:只对assistant回复计算损失
  • 特殊token:区分用户消息、助手消息、工具调用

关键技术细节

1. 优化器

混合优化器:对不同参数使用不同优化器

213:242:nanochat/gpt.py 复制代码
def setup_optimizers(self, unembedding_lr=0.004, embedding_lr=0.2, matrix_lr=0.02, weight_decay=0.0):
    model_dim = self.config.n_embd
    ddp, rank, local_rank, world_size = get_dist_info()
    # Separate out all parameters into 3 groups (matrix, embedding, lm_head)
    matrix_params = list(self.transformer.h.parameters())
    embedding_params = list(self.transformer.wte.parameters())
    lm_head_params = list(self.lm_head.parameters())
    assert len(list(self.parameters())) == len(matrix_params) + len(embedding_params) + len(lm_head_params)
    # Create the AdamW optimizer for the embedding and lm_head
    # Scale the LR for the AdamW parameters by ∝1/√dmodel (having tuned the LRs for 768 dim model)
    dmodel_lr_scale = (model_dim / 768) ** -0.5
    if rank == 0:
        print(f"Scaling the LR for the AdamW parameters ∝1/√({model_dim}/768) = {dmodel_lr_scale:.6f}")
    adam_groups = [
        dict(params=lm_head_params, lr=unembedding_lr * dmodel_lr_scale),
        dict(params=embedding_params, lr=embedding_lr * dmodel_lr_scale),
    ]
    adamw_kwargs = dict(betas=(0.8, 0.95), eps=1e-10, weight_decay=weight_decay)
    AdamWFactory = DistAdamW if ddp else partial(torch.optim.AdamW, fused=True)
    adamw_optimizer = AdamWFactory(adam_groups, **adamw_kwargs)
    # Create the Muon optimizer for the linear layers
    muon_kwargs = dict(lr=matrix_lr, momentum=0.95)
    MuonFactory = DistMuon if ddp else Muon
    muon_optimizer = MuonFactory(matrix_params, **muon_kwargs)
    # Combine them the two optimizers into one list
    optimizers = [adamw_optimizer, muon_optimizer]
    for opt in optimizers:
        for group in opt.param_groups:
            group["initial_lr"] = group["lr"]
    return optimizers

Muon优化器:针对线性层,使用Newton-Schulz正交化:

10:36:nanochat/muon.py 复制代码
@torch.compile
def zeropower_via_newtonschulz5(G: Tensor, steps: int) -> Tensor:
    """
    Newton-Schulz iteration to compute the zeroth power / orthogonalization of G. We opt to use a
    quintic iteration whose coefficients are selected to maximize the slope at zero. For the purpose
    of minimizing steps, it turns out to be empirically effective to keep increasing the slope at
    zero even beyond the point where the iteration no longer converges all the way to one everywhere
    on the interval. This iteration therefore does not produce UV^T but rather something like US'V^T
    where S' is diagonal with S_{ii}' ~ Uniform(0.5, 1.5), which turns out not to hurt model
    performance at all relative to UV^T, where USV^T = G is the SVD.
    """
    assert G.ndim >= 2 # batched Muon implementation by @scottjmaddox, and put into practice in the record by @YouJiacheng
    a, b, c = (3.4445, -4.7750,  2.0315)
    X = G.bfloat16()
    if G.size(-2) > G.size(-1):
        X = X.mT

    # Ensure spectral norm is at most 1
    X = X / (X.norm(dim=(-2, -1), keepdim=True) + 1e-7)
    # Perform the NS iterations
    for _ in range(steps):
        A = X @ X.mT
        B = b * A + c * A @ A # quintic computation strategy adapted from suggestion by @jxbz, @leloykun, and @YouJiacheng
        X = a * X + B @ X

    if G.size(-2) > G.size(-1):
        X = X.mT
    return X

2. 分布式训练

ZeRO-2式分片

10:76:nanochat/adamw.py 复制代码
class DistAdamW(torch.optim.Optimizer):
    """
    Distributed AdamW optimizer.
    In the style of ZeRO-2, i.e. sharded optimizer states and gradient reduction
    """
    def __init__(self, param_groups, lr: float = 1e-3, betas: tuple[float, float] = (0.9, 0.999), eps: float = 1e-8, weight_decay: float = 0.01):
        defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay)
        super().__init__(param_groups, defaults)

    @torch.compile
    @torch.no_grad()
    def step(self):
        rank = dist.get_rank()
        world_size = dist.get_world_size()
        reduce_scatter_futures: list[torch.Future] = []
        all_reduce_futures: list[torch.Future] = []
        grad_slices = []
        for group in self.param_groups:
            params: list[Tensor] = group["params"]
            for base_i in range(len(params)):
                grad = params[base_i].grad
                rank_size = grad.shape[0] // world_size
                grad_slice = torch.empty_like(grad[:rank_size])
                reduce_scatter_futures.append(dist.reduce_scatter_tensor(grad_slice, grad, op=dist.ReduceOp.AVG, async_op=True).get_future())
                grad_slices.append(grad_slice)

        idx = 0
        for group in self.param_groups:
            beta1, beta2 = group['betas']
            eps = group['eps']
            wd = group['weight_decay']
            params = group['params']
            for base in range(len(params)):
                reduce_scatter_futures[idx].wait()
                p = params[base]
                rank_size = p.shape[0] // world_size
                p_slice = p[rank * rank_size:(rank + 1) * rank_size]
                lr = group['lr'] * getattr(p, "lr_mul", 1.0)
                state = self.state[p]
                g_slice = grad_slices[idx]
                # State init
                if not state:
                    state['step'] = torch.tensor(0, dtype=torch.int64, device=p.device)
                    state['exp_avg'] = torch.zeros_like(p_slice)
                    state['exp_avg_sq'] = torch.zeros_like(p_slice)
                exp_avg = state['exp_avg']
                exp_avg_sq = state['exp_avg_sq']
                state['step'] += 1
                t = state['step']
                # weight decay
                if wd != 0:
                    eff_weight_decay = lr * wd * getattr(p, "wd_mul", 1.0)
                    p_slice.mul_(1 - eff_weight_decay)
                # update running averages
                exp_avg.mul_(beta1).add_(g_slice, alpha=1 - beta1)
                exp_avg_sq.mul_(beta2).addcmul_(g_slice, g_slice, value=1 - beta2)
                # bias corrections
                bias1 = 1 - beta1 ** t
                bias2 = 1 - beta2 ** t
                # compute step
                denom = exp_avg_sq.sqrt().add_(eps)
                step_size = lr * (torch.sqrt(bias2) / bias1)
                update = exp_avg.div(denom).mul_(step_size)
                p_slice.add_(other=update, alpha=-1.0)
                idx += 1
                all_reduce_futures.append(dist.all_gather_into_tensor(p, p_slice, async_op=True).get_future())
        torch.futures.collect_all(all_reduce_futures).wait()

关键点

  • reduce_scatter:梯度平均
  • all_gather:参数同步
  • 异步通信:提高并行度

3. 工具使用(Tool Use)

Python执行工具

36:79:nanochat/engine.py 复制代码
def eval_with_timeout(formula, max_time=3):
    try:
        with timeout(max_time, formula):
            with warnings.catch_warnings():
                warnings.simplefilter("ignore", SyntaxWarning)
                return eval(formula)
    except Exception as e:
        signal.alarm(0)
        # print(f"Warning: Failed to eval {formula}, exception: {e}") # it's ok ignore wrong calculator usage
        return None

def use_calculator(expr):
    """
    Evaluate a Python expression safely.
    Supports both math expressions and string operations like .count()
    """
    # Remove commas from numbers
    expr = expr.replace(",", "")

    # Check if it's a pure math expression (old behavior)
    if all([x in "0123456789*+-/.() " for x in expr]):
        if "**" in expr:  # disallow power operator
            return None
        return eval_with_timeout(expr)

    # Check if it's a string operation we support
    # Allow: strings (single/double quotes), .count(), letters, numbers, spaces, parens
    allowed_chars = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789'\"()._ "
    if not all([x in allowed_chars for x in expr]):
        return None

    # Disallow dangerous patterns
    dangerous_patterns = ['__', 'import', 'exec', 'eval', 'compile', 'open', 'file',
                         'input', 'raw_input', 'globals', 'locals', 'vars', 'dir',
                         'getattr', 'setattr', 'delattr', 'hasattr']
    expr_lower = expr.lower()
    if any(pattern in expr_lower for pattern in dangerous_patterns):
        return None

    # Only allow .count() method for now (can expand later)
    if '.count(' not in expr:
        return None

    # Evaluate with timeout
    return eval_with_timeout(expr)

在生成中调用工具

271:287:nanochat/engine.py 复制代码
# Handle tool logic
if next_token == python_start:
    state.in_python_block = True
    state.python_expr_tokens = []
elif next_token == python_end and state.in_python_block:
    state.in_python_block = False
    if state.python_expr_tokens:
        expr = self.tokenizer.decode(state.python_expr_tokens)
        result = use_calculator(expr)
        if result is not None:
            result_tokens = self.tokenizer.encode(str(result))
            state.forced_tokens.append(output_start)
            state.forced_tokens.extend(result_tokens)
            state.forced_tokens.append(output_end)
    state.python_expr_tokens = []
elif state.in_python_block:
    state.python_expr_tokens.append(next_token)

运行流程

一键运行

1:137:speedrun.sh 复制代码
#!/bin/bash

# This script is the "Best ChatGPT clone that $100 can buy",
# It is designed to run in ~4 hours on 8XH100 node at $3/GPU/hour.

# 1) Example launch (simplest):
# bash speedrun.sh
# 2) Example launch in a screen session (because the run takes ~4 hours):
# screen -L -Logfile speedrun.log -S speedrun bash speedrun.sh
# 3) Example launch with wandb logging, but see below for setting up wandb first:
# WANDB_RUN=speedrun screen -L -Logfile speedrun.log -S speedrun bash speedrun.sh

# Default intermediate artifacts directory is in ~/.cache/nanochat
export OMP_NUM_THREADS=1
export NANOCHAT_BASE_DIR="$HOME/.cache/nanochat"
mkdir -p $NANOCHAT_BASE_DIR

# -----------------------------------------------------------------------------
# Python venv setup with uv

# install uv (if not already installed)
command -v uv &> /dev/null || curl -LsSf https://astral.sh/uv/install.sh | sh
# create a .venv local virtual environment (if it doesn't exist)
[ -d ".venv" ] || uv venv
# install the repo dependencies
uv sync --extra gpu
# activate venv so that `python` uses the project's venv instead of system python
source .venv/bin/activate

# -----------------------------------------------------------------------------
# wandb setup
# If you wish to use wandb for logging (it's nice!, recommended).
# 1) Make sure to first log in to wandb, e.g. run:
#    `wandb login`
# 2) Set the WANDB_RUN environment variable when running this script, e.g.:
#    `WANDB_RUN=d26 bash speedrun.sh`
if [ -z "$WANDB_RUN" ]; then
    # by default use "dummy" : it's handled as a special case, skips logging to wandb
    WANDB_RUN=dummy
fi

# -----------------------------------------------------------------------------
# During the course of the run, we will be writing markdown reports to the report/
# directory in the base dir. This command clears it out and writes a header section
# with a bunch of system info and a timestamp that marks the start of the run.
python -m nanochat.report reset

# -----------------------------------------------------------------------------
# Tokenizer

# Install Rust / Cargo
curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh -s -- -y
source "$HOME/.cargo/env"

# Build the rustbpe Tokenizer
uv run maturin develop --release --manifest-path rustbpe/Cargo.toml

# Download the first ~2B characters of pretraining dataset
# look at dev/repackage_data_reference.py for details on how this data was prepared
# each data shard is ~250M chars
# so we download 2e9 / 250e6 = 8 data shards at this point
# each shard is ~100MB of text (compressed), so this is about ~800MB of data on disk
python -m nanochat.dataset -n 8
# Immediately also kick off downloading more shards in the background while tokenizer trains
# See comment below for why 240 is the right number here
python -m nanochat.dataset -n 240 &
DATASET_DOWNLOAD_PID=$!
# train the tokenizer with vocab size 2**16 = 65536 on ~2B characters of data
python -m scripts.tok_train --max_chars=2000000000
# evaluate the tokenizer (report compression ratio etc.)
python -m scripts.tok_eval

# -----------------------------------------------------------------------------
# Base model (pretraining)

# Download the eval_bundle from s3 to evaluate CORE metric during training (~162MB)
EVAL_BUNDLE_URL=https://karpathy-public.s3.us-west-2.amazonaws.com/eval_bundle.zip
if [ ! -d "$NANOCHAT_BASE_DIR/eval_bundle" ]; then
    curl -L -o eval_bundle.zip $EVAL_BUNDLE_URL
    unzip -q eval_bundle.zip
    rm eval_bundle.zip
    mv eval_bundle $NANOCHAT_BASE_DIR
fi

# The d20 model is 561M parameters.
# Chinchilla says #tokens = 20X #params, so we need 561e6 * 20 = 11.2B tokens.
# Assume our tokenizer is 4.8 chars/token, this is 11.2B * 4.8 ~= 54B chars.
# At 250M chars/shard, this is 54B / 250M ~= 216 shards needed for pretraining.
# Round up to 240 for safety. At ~100MB/shard, this downloads ~24GB of data to disk.
# (The total number of shards available in the entire dataset is 1822.)
echo "Waiting for dataset download to complete..."
wait $DATASET_DOWNLOAD_PID

# pretrain the d20 model
torchrun --standalone --nproc_per_node=8 -m scripts.base_train -- --depth=20 --run=$WANDB_RUN
# evaluate the model on a larger chunk of train/val data and draw some samples
torchrun --standalone --nproc_per_node=8 -m scripts.base_loss
# evaluate the model on CORE tasks
torchrun --standalone --nproc_per_node=8 -m scripts.base_eval

# -----------------------------------------------------------------------------
# Midtraining (teach the model conversation special tokens, tool use, multiple choice)

# download 2.3MB of synthetic identity conversations to impart a personality to nanochat
# see dev/gen_sft_data.py for details on how this data was prepared and to get a sense of how you can easily tune it
curl -L -o $NANOCHAT_BASE_DIR/identity_conversations.jsonl https://karpathy-public.s3.us-west-2.amazonaws.com/identity_conversations.jsonl

# run midtraining and eval the model
torchrun --standalone --nproc_per_node=8 -m scripts.mid_train -- --run=$WANDB_RUN
torchrun --standalone --nproc_per_node=8 -m scripts.chat_eval -- -i mid

# -----------------------------------------------------------------------------
# Supervised Finetuning (domain adaptation to each sequence all by itself per row)

# train sft and re-eval right away (should see a small bump)
torchrun --standalone --nproc_per_node=8 -m scripts.chat_sft -- --run=$WANDB_RUN
torchrun --standalone --nproc_per_node=8 -m scripts.chat_eval -- -i sft

# chat with the model over CLI! Leave out the -p to chat interactively
# python -m scripts.chat_cli -p "Why is the sky blue?"

# even better, chat with your model over a pretty WebUI ChatGPT style
# python -m scripts.chat_web

Windows/WSL/PowerShell 提示

  • 推荐在 Linux/WSL 跑完整流程;Windows 原生适合小规模验证。
  • PowerShell 快速试跑(CPU/小模型示例):
powershell 复制代码
# 进入仓库根目录
cd E:\open_src2\nanochat

# 可选:构建 RustBPE(如需自训练分词器)
pip install maturin
maturin develop --release --manifest-path rustbpe/Cargo.toml

# 安装 Python 依赖(建议先装好 PyTorch,对应你的 CUDA/CPU 环境)
pip install -e .

# 运行一个极小训练以验证流程
python -m scripts.base_train --depth=4 --max_seq_len=512 --device_batch_size=1 --eval_tokens=512 --core_metric_every=-1 --total_batch_size=512 --num_iterations=20

# 启动 Web 聊天(如已有 checkpoint,可直接推理)
python -m scripts.chat_web

技术小结

核心技术要点

  1. BPE分词:平衡压缩率与计算效率
  2. Transformer架构:自注意力 + 前馈网络
  3. RoPE位置编码:Rotary Position Embedding
  4. 混合优化:Muon(矩阵) + AdamW(嵌入)
  5. 分布式训练:ZeRO-2分片
  6. KV Cache:加速推理
  7. 工具使用:Python代码执行

训练阶段

  1. 预训练:大规模文本学习语言模型
  2. 中训练:引入对话格式和特殊能力
  3. SFT:微调对话能力
  4. RL(可选):使用奖励模型优化

关键要点

  • 数据规模:遵循Chinchilla定律(20×参数量)
  • 计算效率torch.compile加速
  • 内存优化:梯度累积、混合精度
  • 代码质量:简洁、易读、可扩展

下一步

  1. 尝试修改模型架构
  2. 调整超参数
  3. 添加新的数据集
  4. 实现新的评估任务
  5. 探索强化学习

560,988,160(约5.61亿)这个参数总数是怎么计算出来的?

这实际上是基于Transformer模型的结构参数,严格、系统地统计模型各个层、各个模块的可训练参数数量。我们知道 nanochat 的以下结构信息:

  • 层数(20层)
  • embedding维度(model_dim=1280)
  • 词表大小(vocab_size=65536)
  • 注意力头数(num_heads=10)
  • 每步token数/每层前后维度关系......
  • 通常还有前向/后向投影,MLP中间扩张维度、输出层等

可以计算:

  • Embedding层:83,886,080
  • 输出层:83,886,080
  • Transformer主体:393,216,000
  • LayerNorm少量

总计:

83,886,080 + 83,886,080 + 393,216,000 = 560,988,160

560,988,160 就是:

  • 词嵌入(输入+输出)
  • 每层 attention+MLP
  • 总共20层 + 以上加起来的参数和。
  • 如果使用权重共享(tied embedding),输出层参数不会被重复计算。
  • 这里是最经典的GPT-style Transformer参数统计方式。不同实现如FeedForward维度变化、是否有bias、RMSNorm等会略有差异。

具体来说:

python 复制代码
# 模型的词表大小,表示一共有多少个不同的token(词/字/符号等),比如65536常用于中等多语模型
vocab_size = 65536

# Transformer模型的层数(即有多少个Transformer Block,一般层数越多,能力越强)
n_layer = 20

# 隐藏层维度(也称为模型宽度,通常等于embedding/hidden state/输出层维度)
model_dim = 1280

# MLP隐藏层维度,通常是model_dim的4倍(即 "Feed Forward Network" 的expansion比例)
mlp_dim = 5120  # 通常4倍于模型维度

# 一层Transformer中所有注意力子层(多头自注意力的Q/K/V/O)对应的线性投影参数总量
# Q: Query, K: Key, V: Value, O: Output projection,各自是一个 model_dim x model_dim 的weight,共4份
attention_proj = 4 * model_dim * model_dim  # 总共4个权重矩阵的参数量

# 一层Transformer中MLP部分的参数总量
# 由两个线性层组成,输入 model_dim -> 隐藏 mlp_dim -> 输出 model_dim
# 所以有两个参数矩阵:第一个 model_dim x mlp_dim,第二个 mlp_dim x model_dim
mlp_proj = 2 * model_dim * mlp_dim

# 每层Transformer Block的参数总量 = 注意力部分 + 前馈神经网络(MLP)部分
per_layer = attention_proj + mlp_proj

# 总参数量 = token embedding参数 + 输出层参数 + 每层block参数 * 层数
# embedding参数 = vocab_size x model_dim
# 输出层参数 = vocab_size x model_dim(通常与embedding层共用一套参数,但这里各自计入一次;如有权重共享应只加一次)
# transformer主体参数 = n_layer * per_layer
total = vocab_size * model_dim * 2 + n_layer * per_layer

# 输出总参数数量,即本模型(20层,1280宽度,65536词表)的参数量
print(total)  # 输出:560988160

在上面代码中提到:MLP隐藏层维度,通常是model_dim的4倍。那么:

什么是 MLP隐藏层?

MLP(全称:Multi-Layer Perceptron,多层感知机),在Transformer中的每个Block(层)都有一个MLP子模块,也称为前馈神经网络(Feed Forward Network, FFN)。

对于Transformer而言,它的结构通常是:

  1. 多头自注意力(Multi-Head Self Attention, MHSA)
  2. MLP/FFN模块

MLP一般包括2个全连接(Linear/Dense)层和一个激活函数(通常是GELU/Relu等):

  • 输入维度为 model_dim(隐状态维度,比如1280)
  • 经过第一个全连接层,升维到 mlp_hidden_dim
  • 经过激活函数(如GELU)
  • 再经过第二个全连接层,还原回原维度(mlp_hidden_dim → model_dim)

伪代码结构:

python 复制代码
def MLP(x):
    x = Linear1(x)  # model_dim → mlp_hidden_dim
    x = activation(x)
    x = Linear2(x)  # mlp_hidden_dim → model_dim
    return x

MLP隐藏层 ,就是指中间那一层"升维后的"维度,通常叫 mlp_hidden_dimmlp_dim

为什么 MLP隐藏层维度通常是 model_dim 的4倍?

行业标准和经验:

几乎所有主流的Transformer论文(原始Transformer、BERT、GPT、Llama 等)都采用 mlp_dim = 4 × model_dim 作为"最佳经验值"。

  • 例如,model_dim=1280(主干宽度),则mlp_hidden_dim=5120。
  • 如果model_dim=1024,则mlp_hidden_dim=4096。
为什么4倍?
  • 增加表示能力
    Transformer的自注意力模块更关注token间的信息交换,而MLP负责对每个token内部特征做复杂投影和非线性组合。
    • 扩大MLP宽度,相当于给每个token分配了更高容量的"单token特征处理器",提升每个token表征的复杂性。
  • 实证最优
    各种论文和大规模实验(如GPT-3、Llama设计)表明,4倍左右最能权衡能力提升和显存/效率。
    • 太小,网络表达能力不足。
    • 太大,参数和计算量暴涨,而性能提升边际效益变低。
  • 结构均衡
    1倍(即和model_dim同宽)效果较差,2-4倍提升显著,但大于4、8倍收益变小------"4"是个工程上经过验证的"甜蜜点"。
本质原因

Transformer的原理强调"全局建模"(通过Attention)和"局部特征非线性拓展(MLP)",两者黑盒功能不同。

设置较大MLP扩展,能使每个token的表达在全局混合后再做更复杂的加工,语言模型因此具备更强的语义抽象和记忆能力。


通常各类GPT/LLM的参数分布比例大致如下:

模型 Embedding Transformer Block (含Self-Attn/FFN) 输出层 备注
本模型(d20) 83.9M (~15%) 393.2M (~70%) 83.9M (~15%) 输入/输出占三成,大头在transformer
GPT-3 175B ~617M (<1%) ~99% ~617M (<1%) 层数/宽度极大,embedding占比变小
DeepSeek-V3 1.3B 262.1M (~20%) 1,168M 262.1M (~20%) 更大词表,embedding权重更高占比

我们以大家熟悉的 GPT-2/GPT-3/ChatGPT (对应OpenAI的GPT-3.5/4)、DeepSeek-v3 等主流开源模型进行参数量对比,同时对模型结构主要组成部分参数量进行总结:

模型 总参数数量 层数 维度(d_model) 词表大小 主要组成 Embedding参数 Transformer Block 参数 输出层参数 典型用途/备注
本模型(d20) 560,988,160 20 1280 65,536 GPT-风格 83.9M 393.2M 83.9M 微型研究/教材
GPT-2 Small 117,000,000 12 768 50,257 GPT-2 38.5M 78.9M 38.5M 入门/微型英文
GPT-2 Medium 345,000,000 24 1024 50,257 GPT-2 51.4M 242M 51.4M GPT-2 2/小型
GPT-2 Large 762,000,000 36 1280 50,257 GPT-2 64.2M 633M 64.2M 大号GPT-2
GPT-2 XL 1,542,000,000 48 1600 50,257 GPT-2 80.4M 1,382M 80.4M 超大GPT-2
GPT-3 125M 125,000,000 12 768 50,257 GPT-3 tiny 38.5M 78.9M 38.5M 类GPT-2 small
GPT-3 350M 350,000,000 24 1024 50,257 GPT-3 mini 51.4M 242M 51.4M 类GPT-2 medium
GPT-3 1.3B 1,300,000,000 24 2048 50,257 GPT-3 Small 103M 1,094M 103M 微型GPT-3
GPT-3 6.7B 6,700,000,000 32 4096 50,257 GPT-3 Medium 206M 6,288M 206M 小型GPT-3
GPT-3 175B 175,000,000,000 96 12,288 50,257 GPT-3 617M 173.8B 617M GPT-3 Flagship
DeepSeek-V2 1.3B 1,300,000,000 24 2048 69,376 DeepSeek V2 140M ≈1,020M 140M 支持多语种
DeepSeek-V3 1.3B 1,300,000,000 24 2048 128,384 DeepSeek V3 262.1M ≈1,168M 262.1M 支持更大多语词表
ChatGPT(GPT-3.5) ~6,000,000,000? ~96 ~6,000-12,288? 未公开 OpenAI服务 - - - API/服务大模型

下面是一篇详细的技术文章,系统讨论了Embedding、Transformer Block(含多头注意力和前馈网络)、输出层参数的意义,并给出了本模型(d20)及其他常见模型的参数结构对比。本文适合有一定机器学习或NLP基础的读者作为学习材料或技术参考。


由此可见:

1. Embedding与输出层

  • 小模型、词表大,embedding占比高。 本模型因为词表足够大,embedding和输出层各占到参数总数的15%左右。DeepSeek等支持极大词表(多语种、代码等),embedding和输出层参数占比就更高。
  • 大模型,embedding占比递减。 如GPT-3 175B时,embedding+输出层<2%,绝大多数参数在"骨干"部分。

2. Transformer Block

  • 是大模型参数攀升的根本所在。 层数、宽度提升明显推高参数水平,体现了大模型"主要参数在Block"的本质。
  • **表达能力集中于此。**Block主导了模型的泛化、记忆和推理能力。

3. 模型设计权衡

  • 想节省参数,建议优先压缩embedding/输出层(如共享embedding)。
  • 若追求能力,主要提升block部分规模更为划算,embedding增长仅带来性质改变(如更好多语种/更好rare token coverage)。
  • 假如模型是学生:
    "embedding增长"像是学生买了新字典,能查到更冷门的单词------他不一定理解词的意思,只是"不漏掉"意思。
    "Block变大"像是大脑神经元增多,逻辑推理、理解能力提高了------不单知道生词,还能举一反三、深刻推理。
    GPT-3、GPT-4、Llama等超大模型,参数主力集中在Block部分,即使词表数十万,embedding参数占比也很小
  • 输出层:多数情况下,输出层参数增大主要是解决token覆盖和输出精度问题,本身对模型"理解/推理抽象能力"提升有限。(理解/推理能力由Transformer block主导)
    如果希望支持更多token/语种/领域,必须增大输出层参数;在多语种/跨领域/代码通用大模型场景下,输出层参数是充分表达全词表概率分布不可省略的关键
  • "权重共享"在指的就是输入层(token embedding)和输出层(language modeling head)可以共用同一组权重矩阵。一个几十万词表的模型,embedding/输出各自存一遍参数,非常占空间;共享则直接省掉一半。输入和输出层都用同一组权重(embedding lookup和输出softmax共用)的实现方式:输出层权重等于词嵌入矩阵的转置。

大语言模型参数按Embedding、Transformer Block与输出层划分,清晰反映结构设计思路与资源需求。以本d20模型为例,其参数量与主流小~中型GPT类似,更大词表带来更高词嵌入/输出层参数,而大模型参数主力始终集中于transformer块。

这种参数结构认知,不仅有助于架构选型、资源预估,更有助于理解模型训练和压缩的底层逻辑。未来的多模态、超大规模模型,参数结构依然遵循这一经典"三分法",设计者可针对实际需求优化每一部分参数,提升性价比与能力表现。


常见问题(FAQ)

  • 需要多少 GPU? 最少1张可跑但很慢;推荐 8×H100 跑 $100 档 d20 约 4 小时。
  • 显存不够怎么办? 降低 --device_batch_size;必要时降低 --max_seq_len 或模型深度/宽度,脚本会用梯度累积弥补吞吐。
  • 分词器一定要训练吗? 可直接使用 tiktoken 推理;但自训练更贴合你的数据分布(压缩率更优)。
  • 如何确认启用了 GQA/MQA? 查看 GPTConfig.n_kv_headn_head 的关系;n_kv_head < n_head 即为 GQA/MQA 风格。
  • 为什么推理需要 KV Cache? 否则每步都会重算历史 token 的注意力;KV Cache 复用历史 K/V,只算新增部分。
  • 工具调用安全吗? 受限且有超时,仅允许算术与受限字符串 .count();包含危险片段会被拒绝。
  • 日志与指标怎么看? 设置 WANDB_RUN 使用 wandb 记录;同时会在 ~/.cache/nanochat/report/ 写入 markdown 报告。
  • 如何仅在 CPU 上试跑? 参考上文 PowerShell 样例;把 --depth--max_seq_len--device_batch_size--num_iterations 调小即可。
  • 如何用 WebUI 体验? 训练(或下载)checkpoint 后运行 python -m scripts.chat_web,浏览器访问输出地址。
  • 如何切换更大模型? 提高 --depth 并按 Chinchilla 比例增加数据分片;根据显存调小 --device_batch_size

参考资源

  • 原仓库https://github.com/karpathy/nanochat
  • Transformer论文:Attention Is All You Need
  • Chinchilla论文:Training Compute-Optimal Large Language Models
  • RoPE论文:RoFormer: Enhanced Transformer with Rotary Position Embedding

本教程基于nanochat代码库编写,适合AI初学者深入理解LLM实现。

相关推荐
兔兔爱学习兔兔爱学习4 小时前
ASR+MT+LLM+TTS 一体化实时翻译字幕系统
人工智能·自然语言处理·机器翻译
二向箔reverse4 小时前
用langchain搭建简单agent
人工智能·python·langchain
苦瓜汤补钙4 小时前
论文阅读——Segment Anything(Meta AI)——SAM
论文阅读·图像处理·人工智能·nlp·ai编程
会笑的小熊4 小时前
论文阅读笔记——自注意力机制
深度学习·计算机视觉·自然语言处理
共绩算力4 小时前
混元图像3.0开源原生多模态生图新篇章
人工智能·ai作画·共绩算力
搞科研的小刘选手4 小时前
【经济方向专题会议】第二届经济数据分析与人工智能国际学术会议 (EDAI 2025)
人工智能·机器学习·网络安全·大数据分析·经济·经济数据分析·绿色经济
六边形架构4 小时前
真相!Dify和n8n这两款LLM应用开发平台的最大区别,90%的人都不知道!
人工智能
敲代码的猴先生4 小时前
技术分享 | torch.profiler:利用探针收集模型执行信息的性能分析工具
人工智能·pytorch·经验分享·语言模型·性能优化