LLaMA: Open and Efficient Foundation Language Models 论文阅读

目录

摘要

Introduction

Approach

预训练数据

结构

RMSNorm

SwiGLU

RoPE

Optimizer

高效训练的方法

主要结果

指令微调

模型风险评估

Conclusion

[LLaMA 实现全部代码](#LLaMA 实现全部代码)


Introduction

作为背景,讨论了是否更大的模型训练更大的参数就会有更好的训练效果。

这里引入了缩放定律scaling laws,作者认为缩放定律忽略了推理的成本,相比于训练更快的模型,作者认为应该选择推理更快的模型,因此提出小的 LLM 配大数据训练更好,因为小 LLM 推理更友好。

Approach

预训练数据

LLaMa 预训练数据大约包含 1.4T tokens,对于绝大部分的训练数据,训练期间只使用一次。

下图展示了 LLaMa 预训练数据的占比:

结构

LLaMA 为 decoder-only 结构,和之前其他模型相比最大的3个改进:

  • 对每个Transformer子层的输入使用 RMSNorm 归一化函数进行归一化,而不是对输出进行归一化。
  • 用 SwiGLU 激活函数替换 ReLU 非线性,以提高性能。
  • 删除了绝对位置嵌入,使用 RoPE 旋转位置嵌入。

RMSNorm

可以增强训练稳定性。

代码实现:

python 复制代码
class RMSNorm(torch.nn.Module):
    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def _norm(self, x):
        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)

    def forward(self, x):
        output = self._norm(x.float()).type_as(x)
        return output * self.weight

SwiGLU

可以提高模型性能。

LLaMa 使用 SwiGLU 激活函数替换 ReLU 非线性以提高性能,SwiGLU 激活函数结合了 Swish 激活函数和 GLU(Gated Linear Unit)的机制。

RoPE

可以更好地建模长序列数据。

来源于苏剑林大神,不直接将位置信息作为向量与词向量相加,而是在注意力机制的 Query(查询)和 Key(键)计算时,将输入向量视为复数,在复数平面上进行旋转,通过旋转操作将位置信息融入输入嵌入中。

优点:

旋转操作可以通过复数操作简化,计算复杂度低

可以捕捉相对位置关系

对长序列友好

LLaMA中 RoPE 的实现代码:

python 复制代码
def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0):  
    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))  
    t = torch.arange(end, device=freqs.device)  # type: ignore  
    freqs = torch.outer(t, freqs).float()  # type: ignore  
    freqs_cis = torch.polar(torch.ones_like(freqs), freqs)  # complex64  
    return freqs_cis  
  
  
def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor):  
    ndim = x.ndim  
    assert 0 <= 1 < ndim  
    assert freqs_cis.shape == (x.shape[1], x.shape[-1])  
    shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]  
    return freqs_cis.view(*shape)  
  
  
def apply_rotary_emb(  
    xq: torch.Tensor,  
    xk: torch.Tensor,  
    freqs_cis: torch.Tensor,  
) -> Tuple[torch.Tensor, torch.Tensor]:  
    xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))  
    xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))  
    freqs_cis = reshape_for_broadcast(freqs_cis, xq_)  
    xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3)  
    xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3)  
    return xq_out.type_as(xq), xk_out.type_as(xk)

Optimizer

使用 AdamW 优化器,超参数如下:β1 = 0.9,β2 = 0.95,最终学习率等于最大学习率的10 %。

如图为7B、13B、33B 和 65B tokens 的模型训练损失:

高效训练的方法

  1. 内存与速度优化:用 xformers 库实现高效因果多头注意力,手动实现 Transformer 层反向传播,选择性保存高计算成本激活值以减少重计算。
  2. 分布式策略:模型并行 + 序列并行减少内存占用;重叠激活计算与 GPU 间通信。
  3. 训练效率:65B 模型用 2048 张 A100(80GB)GPU,1.4T tokens 训练耗时约 21 天。

主要结果

该模块给出 LLaMA 在常识推理、闭卷问答、阅读理解6项具体任务中,其表现优于或匹敌主流模型。

指令微调

该部分指出小样本微调可以进一步提高了模型对指令的follow能力。下图比较了大小适中的模型在MMLU上有指令微调和无指令微调的情况:

模型风险评估

该部分分为Bias, Toxicity and Misinformation,即毒性、偏见、真实性三方面展开,核心结论如下:

  1. 毒性:用 RealToxicityPrompts 评估,模型规模越大毒性越强,如 65B 模型 "基础提示""尊重提示" 毒性得分(0.128、0.141)高于 7B 模型(0.106、0.081),与 OPT 等开源模型趋势一致;
  2. 偏见:通过 CrowS-Pairs 和 WinoGender 评估,65B 模型平均偏见得分(66.6%)略低于 GPT-3、OPT-175B,但宗教领域偏见突出(79.0%,超 OPT 10%);且对 "their/them" 代词共指分辨率高于 "her/she""his/he","gotcha" 案例错误率高,存在性别偏见;
  3. 真实性:用 TruthfulQA 评估,65B 模型 "真实回答""真实且有用" 占比(57%、53%)高于 GPT-3(28%、25%),但仍有幻觉风险。

Conclusion

强调仅用公开可用数据训练大模型就能达到最先进性能,无需依赖专有数据集,这也是本文的标题所在:开源高效的大语言模型。

LLaMA 实现全部代码

python 复制代码
# Copyright (c) Meta Platforms, Inc. and affiliates.  
# This software may be used and distributed according to the terms of the GNU General Public License version 3.  
  
from typing import Optional, Tuple  
from dataclasses import dataclass  
import math  
  
import torch  
from torch import nn  
import torch.nn.functional as F  
  
import fairscale.nn.model_parallel.initialize as fs_init  
from fairscale.nn.model_parallel.layers import (  
    ParallelEmbedding,  
    RowParallelLinear,  
    ColumnParallelLinear,  
)  
  
  
@dataclass  
class ModelArgs:  
    dim: int = 512  
    n_layers: int = 8  
    n_heads: int = 8  
    vocab_size: int = -1  # defined later by tokenizer  
    multiple_of: int = 256  # make SwiGLU hidden layer size multiple of large power of 2  
    norm_eps: float = 1e-5  
  
    max_batch_size: int = 32  
    max_seq_len: int = 2048  
  
  
class RMSNorm(torch.nn.Module):  
    def __init__(self, dim: int, eps: float = 1e-6):  
        super().__init__()  
        self.eps = eps  
        self.weight = nn.Parameter(torch.ones(dim))  
  
    def _norm(self, x):  
        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)  
  
    def forward(self, x):  
        output = self._norm(x.float()).type_as(x)  
        return output * self.weight  
  
  
def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0):  
    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))  
    t = torch.arange(end, device=freqs.device)  # type: ignore  
    freqs = torch.outer(t, freqs).float()  # type: ignore  
    freqs_cis = torch.polar(torch.ones_like(freqs), freqs)  # complex64  
    return freqs_cis  
  
  
def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor):  
    ndim = x.ndim  
    assert 0 <= 1 < ndim  
    assert freqs_cis.shape == (x.shape[1], x.shape[-1])  
    shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]  
    return freqs_cis.view(*shape)  
  
  
def apply_rotary_emb(  
    xq: torch.Tensor,  
    xk: torch.Tensor,  
    freqs_cis: torch.Tensor,  
) -> Tuple[torch.Tensor, torch.Tensor]:  
    xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))  
    xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))  
    freqs_cis = reshape_for_broadcast(freqs_cis, xq_)  
    xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3)  
    xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3)  
    return xq_out.type_as(xq), xk_out.type_as(xk)  
  
  
class Attention(nn.Module):  
    def __init__(self, args: ModelArgs):  
        super().__init__()  
  
        self.n_local_heads = args.n_heads // fs_init.get_model_parallel_world_size()  
        self.head_dim = args.dim // args.n_heads  
  
        self.wq = ColumnParallelLinear(  
            args.dim,  
            args.n_heads * self.head_dim,  
            bias=False,  
            gather_output=False,  
            init_method=lambda x: x,  
        )  
        self.wk = ColumnParallelLinear(  
            args.dim,  
            args.n_heads * self.head_dim,  
            bias=False,  
            gather_output=False,  
            init_method=lambda x: x,  
        )  
        self.wv = ColumnParallelLinear(  
            args.dim,  
            args.n_heads * self.head_dim,  
            bias=False,  
            gather_output=False,  
            init_method=lambda x: x,  
        )  
        self.wo = RowParallelLinear(  
            args.n_heads * self.head_dim,  
            args.dim,  
            bias=False,  
            input_is_parallel=True,  
            init_method=lambda x: x,  
        )  
  
        self.cache_k = torch.zeros(  
            (args.max_batch_size, args.max_seq_len, self.n_local_heads, self.head_dim)  
        ).cuda()  
        self.cache_v = torch.zeros(  
            (args.max_batch_size, args.max_seq_len, self.n_local_heads, self.head_dim)  
        ).cuda()  
  
    def forward(self, x: torch.Tensor, start_pos: int, freqs_cis: torch.Tensor, mask: Optional[torch.Tensor]):  
        bsz, seqlen, _ = x.shape  
        xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)  
  
        xq = xq.view(bsz, seqlen, self.n_local_heads, self.head_dim)  
        xk = xk.view(bsz, seqlen, self.n_local_heads, self.head_dim)  
        xv = xv.view(bsz, seqlen, self.n_local_heads, self.head_dim)  
  
        xq, xk = apply_rotary_emb(xq, xk, freqs_cis=freqs_cis)  
  
        self.cache_k = self.cache_k.to(xq)  
        self.cache_v = self.cache_v.to(xq)  
  
        self.cache_k[:bsz, start_pos : start_pos + seqlen] = xk  
        self.cache_v[:bsz, start_pos : start_pos + seqlen] = xv  
  
        keys = self.cache_k[:bsz, : start_pos + seqlen]  
        values = self.cache_v[:bsz, : start_pos + seqlen]  
  
        xq = xq.transpose(1, 2)  
        keys = keys.transpose(1, 2)  
        values = values.transpose(1, 2)  
        scores = torch.matmul(xq, keys.transpose(2, 3)) / math.sqrt(self.head_dim)  
        if mask is not None:  
            scores = scores + mask  # (bs, n_local_heads, slen, cache_len + slen)  
        scores = F.softmax(scores.float(), dim=-1).type_as(xq)  
        output = torch.matmul(scores, values)  # (bs, n_local_heads, slen, head_dim)  
        output = output.transpose(  
            1, 2  
        ).contiguous().view(bsz, seqlen, -1)  
  
        return self.wo(output)  
  
  
class FeedForward(nn.Module):  
    def __init__(  
        self,  
        dim: int,  
        hidden_dim: int,  
        multiple_of: int,  
    ):  
        super().__init__()  
        hidden_dim = int(2 * hidden_dim / 3)  
        hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)  
  
        self.w1 = ColumnParallelLinear(  
            dim, hidden_dim, bias=False, gather_output=False, init_method=lambda x: x  
        )  
        self.w2 = RowParallelLinear(  
            hidden_dim, dim, bias=False, input_is_parallel=True, init_method=lambda x: x  
        )  
        self.w3 = ColumnParallelLinear(  
            dim, hidden_dim, bias=False, gather_output=False, init_method=lambda x: x  
        )  
  
    def forward(self, x):  
        return self.w2(F.silu(self.w1(x)) * self.w3(x))  
  
  
class TransformerBlock(nn.Module):  
    def __init__(self, layer_id: int, args: ModelArgs):  
        super().__init__()  
        self.n_heads = args.n_heads  
        self.dim = args.dim  
        self.head_dim = args.dim // args.n_heads  
        self.attention = Attention(args)  
        self.feed_forward = FeedForward(  
            dim=args.dim, hidden_dim=4 * args.dim, multiple_of=args.multiple_of  
        )  
        self.layer_id = layer_id  
        self.attention_norm = RMSNorm(args.dim, eps=args.norm_eps)  
        self.ffn_norm = RMSNorm(args.dim, eps=args.norm_eps)  
  
    def forward(self, x: torch.Tensor, start_pos: int, freqs_cis: torch.Tensor, mask: Optional[torch.Tensor]):  
        h = x + self.attention.forward(self.attention_norm(x), start_pos, freqs_cis, mask)  
        out = h + self.feed_forward.forward(self.ffn_norm(h))  
        return out  
  
  
class Transformer(nn.Module):  
    def __init__(self, params: ModelArgs):  
        super().__init__()  
        self.params = params  
        self.vocab_size = params.vocab_size  
        self.n_layers = params.n_layers  
  
        self.tok_embeddings = ParallelEmbedding(  
            params.vocab_size, params.dim, init_method=lambda x: x  
        )  
  
        self.layers = torch.nn.ModuleList()  
        for layer_id in range(params.n_layers):  
            self.layers.append(TransformerBlock(layer_id, params))  
  
        self.norm = RMSNorm(params.dim, eps=params.norm_eps)  
        self.output = ColumnParallelLinear(  
            params.dim, params.vocab_size, bias=False, init_method=lambda x: x  
        )  
  
        self.freqs_cis = precompute_freqs_cis(  
            self.params.dim // self.params.n_heads, self.params.max_seq_len * 2  
        )  
  
    @torch.inference_mode()  
    def forward(self, tokens: torch.Tensor, start_pos: int):  
        _bsz, seqlen = tokens.shape  
        h = self.tok_embeddings(tokens)  
        self.freqs_cis = self.freqs_cis.to(h.device)  
        freqs_cis = self.freqs_cis[start_pos : start_pos + seqlen]  
  
        mask = None  
        if seqlen > 1:  
            mask = torch.full((1, 1, seqlen, seqlen), float("-inf"), device=tokens.device)  
            mask = torch.triu(mask, diagonal=start_pos + 1).type_as(h)  
  
        for layer in self.layers:  
            h = layer(h, start_pos, freqs_cis, mask)  
        h = self.norm(h)  
        output = self.output(h[:, -1, :])  # only compute last logits  
        return output.float()
相关推荐
小鸡吃米…5 小时前
机器学习 - K - 中心聚类
人工智能·机器学习·聚类
沈浩(种子思维作者)6 小时前
真的能精准医疗吗?癌症能提前发现吗?
人工智能·python·网络安全·健康医疗·量子计算
njsgcs7 小时前
ue python二次开发启动教程+ 导入fbx到指定文件夹
开发语言·python·unreal engine·ue
io_T_T7 小时前
迭代器 iteration、iter 与 多线程 concurrent 交叉实践(详细)
python
华研前沿标杆游学7 小时前
2026年走进洛阳格力工厂参观游学
python
Carl_奕然7 小时前
【数据挖掘】数据挖掘必会技能之:A/B测试
人工智能·python·数据挖掘·数据分析
AI小怪兽7 小时前
基于YOLOv13的汽车零件分割系统(Python源码+数据集+Pyside6界面)
开发语言·python·yolo·无人机
wszy18098 小时前
新文章标签:让用户一眼发现最新内容
java·python·harmonyos
Eric.Lee20218 小时前
python实现 mp4转gif文件
开发语言·python·手势识别·手势交互·手势建模·xr混合现实
EntyIU8 小时前
python开发中虚拟环境配置
开发语言·python