【llm对话系统】大模型 Llama 源码分析之并行训练方案

1. 引言

训练大型语言模型 (LLM) 需要巨大的计算资源和内存。为了高效地训练这些模型,我们需要采用各种并行策略,将计算和数据分布到多个 GPU 或设备上。Llama 作为当前最流行的开源大模型之一,其训练代码中采用了多种并行技术。本文将深入 Llama 的训练代码,分析其并行训练方案,主要关注参数并行部分结构参数共享

2. 并行训练策略概述

常见的并行训练策略包括:

  • 数据并行 (Data Parallelism, DP):将数据分成多个 batch,每个 GPU 处理一个 batch,所有 GPU 使用相同的模型副本。
  • 模型并行 (Model Parallelism, MP):将模型分成多个部分,每个 GPU 负责模型的一部分。
  • 流水线并行 (Pipeline Parallelism, PP):将模型的不同层分配到不同的 GPU 上,形成一个流水线。
  • 张量并行 (Tensor Parallelism, TP):将模型的张量 (例如,权重矩阵) 分片到多个 GPU 上。
  • 序列并行 (Sequence Parallelism, SP): 将序列长度分片到多个 GPU 上。

Llama 主要采用了数据并行张量并行 ,以及一些结构参数共享的优化。

3. Llama 中的参数并行方案

Llama 使用了 ZeRO (Zero Redundancy Optimizer) 技术,这是一种强大的内存优化方法,它结合了数据并行和模型并行。ZeRO 的核心思想是将模型状态 (权重、梯度和优化器状态) 分片到多个 GPU 上,从而减少每个 GPU 的内存占用。

ZeRO 有三个阶段:

  • ZeRO-1 (Optimizer State Partitioning):将优化器状态 (例如,Adam 的动量和方差) 分片。
  • ZeRO-2 (Gradient Partitioning):在 ZeRO-1 的基础上,将梯度也分片。
  • ZeRO-3 (Parameter Partitioning):在 ZeRO-2 的基础上,将模型参数也分片。

Llama 主要使用了 ZeRO-3,将模型参数、梯度和优化器状态都分片到多个 GPU 上。

3.1 参数分片计算

在 Llama 的训练代码中, 以 torch.distributed.fsdp 库为例 (Fully Sharded Data Parallel, FSDP),它实现了 ZeRO-3 的功能。

以下是一个简化的 FSDP 参数分片示例:

python 复制代码
import torch
import torch.distributed as dist
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp.wrap import (
    transformer_auto_wrap_policy,
    enable_wrap,
    wrap,
)
import functools

# 假设我们有一个简单的 Transformer 模型
class TransformerLayer(torch.nn.Module):
    def __init__(self, hidden_dim):
        super().__init__()
        self.linear1 = torch.nn.Linear(hidden_dim, 4 * hidden_dim)
        self.linear2 = torch.nn.Linear(4 * hidden_dim, hidden_dim)

    def forward(self, x):
        x = self.linear1(x)
        x = torch.nn.functional.relu(x)
        x = self.linear2(x)
        return x

# 初始化分布式环境
dist.init_process_group("nccl")
rank = dist.get_rank()
world_size = dist.get_world_size()
device = torch.device(f"cuda:{rank}")

# 模型和优化器
hidden_dim = 768
model = TransformerLayer(hidden_dim).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

# 使用 FSDP 包装模型
# 使用自动包装策略
auto_wrap_policy = functools.partial(
    transformer_auto_wrap_policy,
    transformer_layer_cls={TransformerLayer,},
)
model = FSDP(model, fsdp_auto_wrap_policy=auto_wrap_policy, device_id=device)

# 模拟训练数据
x = torch.randn(1, 10, hidden_dim, device=device)

# 前向传播
y = model(x)

# 反向传播
loss = y.sum()
loss.backward()

# 优化器更新
optimizer.step()

# 清空梯度
optimizer.zero_grad()

print(f"Rank {rank}: 训练完成")

代码解释:

  1. 初始化分布式环境dist.init_process_group("nccl") 初始化分布式训练环境,使用 NCCL 后端。
  2. 模型和优化器:定义一个简单的 Transformer 层和 Adam 优化器。
  3. FSDP 包装 :使用 FullyShardedDataParallel 包装模型。这会将模型参数分片到多个 GPU 上。
    • transformer_auto_wrap_policy 会自动将模型的每一层都用 FSDP 包装起来。
  4. 前向和反向传播:执行模型的前向和反向传播。
  5. 优化器更新 :执行优化器的 step 方法。
  6. 清空梯度:清空梯度以进行下一次迭代。

运行方式:

你需要使用 torchrun(或 torch.distributed.launch)来启动这个脚本,例如:

bash 复制代码
torchrun --nproc_per_node=4 fsdp_example.py

这将使用 4 个 GPU 来训练模型。

原理说明:

  • 参数分片 :在 model = FSDP(model, ...) 这一步,模型参数被分片到 4 个 GPU 上。每个 GPU 只存储一部分参数。
  • All-gather:在前向传播过程中,当需要完整的参数进行计算时(例如,矩阵乘法),FSDP 会自动执行 all-gather 操作,将所有 GPU 上的参数片段收集起来,组成完整的参数。
  • Reduce-scatter:在反向传播过程中,梯度也是分片的。计算完梯度后,FSDP 会执行 reduce-scatter 操作,将每个参数的梯度片段 reduce 到对应的 GPU 上。
  • 优化器更新:每个 GPU 使用自己分片到的参数和梯度来更新优化器状态。

通过这种方式,FSDP 显著减少了每个 GPU 的内存占用,使得训练更大的模型成为可能。

4. Llama 中的部分结构参数共享

除了参数分片,Llama 还采用了一些结构参数共享的优化,以进一步减少内存占用和提高训练效率。

例如在 Transformer 的多头注意力 (Multi-Head Attention) 机制中,不同 head 的 query, key, value 矩阵的计算通常是独立的。Llama 通过共享 key 和 value 矩阵,减少了参数量和计算量。更具体地说,llama使用了分组注意力机制(Grouped-Query Attention)。

4.1 分组注意力 (Grouped-Query Attention, GQA)

GQA 介于标准的多头注意力 (MHA) 和 Multi-Query Attention (MQA) 之间。

  • MHA: 每个 head 都有独立的 Q, K, V 矩阵。
  • MQA: 所有 head 共享 K, V 矩阵,只有 Q 矩阵是独立的。
  • GQA: 将 head 分成多个组,每个组内的 head 共享 K, V 矩阵。

例如:

假设我们有 8 个 head,可以将它们分成 4 个组,每个组 2 个 head。这样,我们就只有 4 个 K 矩阵和 4 个 V 矩阵,而不是 8 个。

代码示例 (简化版)

python 复制代码
import torch
import torch.nn as nn

class GroupedQueryAttention(nn.Module):
    def __init__(self, embed_dim, num_heads, num_groups):
        super().__init__()
        self.num_heads = num_heads
        self.num_groups = num_groups
        self.head_dim = embed_dim // num_heads
        self.group_size = num_heads // num_groups

        self.q_proj = nn.Linear(embed_dim, embed_dim)
        # 共享 K, V 矩阵
        self.k_proj = nn.Linear(embed_dim, num_groups * self.head_dim)
        self.v_proj = nn.Linear(embed_dim, num_groups * self.head_dim)
        self.out_proj = nn.Linear(embed_dim, embed_dim)

    def forward(self, x):
        batch_size, seq_len, embed_dim = x.shape

        # 计算 Q, K, V
        q = self.q_proj(x).view(batch_size, seq_len, self.num_heads, self.head_dim)
        k = self.k_proj(x).view(batch_size, seq_len, self.num_groups, self.head_dim)
        v = self.v_proj(x).view(batch_size, seq_len, self.num_groups, self.head_dim)

        # 将 head 分组
        q = q.view(batch_size, seq_len, self.num_groups, self.group_size, self.head_dim)

        # 计算注意力
        attn_scores = torch.einsum("bqlgd,bklhd->bqgkh", q, k) / (self.head_dim ** 0.5)
        attn_probs = attn_scores.softmax(dim=-1)
        attn_output = torch.einsum("bqgkh,bklhd->bqlgd", attn_probs, v)

        # 拼接 head
        attn_output = attn_output.reshape(batch_size, seq_len, embed_dim)

        # 输出投影
        output = self.out_proj(attn_output)

        return output

# 示例
embed_dim = 768
num_heads = 8
num_groups = 4
model = GroupedQueryAttention(embed_dim, num_heads, num_groups)
x = torch.randn(1, 10, embed_dim)
y = model(x)
print(y.shape)

代码解释:

  1. num_groups:将 head 分成多少个组。
  2. k_projv_proj:只输出 num_groups 个 head 的 K 和 V 矩阵。
  3. q 分成 num_groups 个组,每个组 group_size 个 head。
  4. 计算注意力时,每个组内的 head 共享 K 和 V 矩阵。

GQA 的优势:

  • 减少参数量:K 和 V 矩阵的数量减少了。
  • 减少计算量:计算注意力时,每个 head 需要处理的 K 和 V 数量减少了。
  • 性能接近 MHA:实验表明,GQA 的性能接近 MHA,明显优于 MQA。

llama的GQA实现在llama/model.py文件中,class Attention(nn.Module) 类下的forward函数中,更具体地,体现在self.num_headsself.num_kv_heads的参数上, 分别控制querykvhead数量,num_kv_heads小于num_heads

相关推荐
知识鱼丸1 小时前
自定义数据集 使用tensorflow框架实现逻辑回归并保存模型,然后保存模型后再加载模型进行预测
人工智能
憨猪在度假2 小时前
Rk3588芯片介绍(含数据手册)
人工智能
西猫雷婶3 小时前
python学opencv|读取图像(五十二)使用cv.matchTemplate()函数实现最佳图像匹配
人工智能·python·opencv·计算机视觉
2301_793069823 小时前
OpenCV 图像旋转
人工智能·opencv·计算机视觉
纠结哥_Shrek3 小时前
基于最近邻数据进行分类
人工智能·分类·数据挖掘
kakaZhui4 小时前
【llm对话系统】大模型 Llama 源码分析之 Flash Attention
人工智能·chatgpt·aigc·llama
Melancholy 啊4 小时前
细说机器学习算法之ROC曲线用于模型评估
人工智能·python·算法·机器学习·数据挖掘
爱研究的小牛5 小时前
Deepseek技术浅析(二):大语言模型
人工智能·机器学习·语言模型·自然语言处理·aigc
编程武士5 小时前
OpenCV 版本不兼容导致的问题
人工智能·opencv·计算机视觉