Llama3.0论文学习笔记: The Llama 3 Herd of Models

1. 写在前面

今天分享Llama3.0的论文,2024.7月来自Meta的Llama团队,2025年1月DeepSeek R1出现之后,其风头显然已经盖住了Llama3,这时候整理Llama3感觉有点赶不上潮流了,但是我还是想整理下Llama3.0,原因是刚好春节的时候读了下原论文(起源是节后在公司做分享),另外一个是Llama3的技术比较成熟,很多大模型的语言部分都是Llama的架构,因此想做为入门大模型的第一个实践模型。

Llama3的技术报告也是长达94页,读完之后,还是有一种醍醐灌顶的感觉,里面还是写了非常多干货的,比如如何从零预训练一个大语言模型,后处理怎么做,海量的训练数据怎么清洗,如何用scaling laws定律指导模型的设计与训练等, 学习完之后,受益匪浅,这些也可以当作后面训练大模型的一些经验来用,所以想先用一篇笔记把Llama3的一些知识整理一下。

这篇论文介绍了Meta公司开发的新一代语言模型Llama 3,该模型支持多语言、代码、推理和工具使用。最大的模型是一个有405B参数的Dense Transformer架构(现在很多大的模型开始走MOE架构,比如DeepSeek系列), 最长能支持128K的上下文。 在很多任务上都和GPT-4的效果差不多。

下面用一个表来提炼论文内容:

模型 Llama 3.1 405B
算力 3.8 ∗ 1 0 25 3.8 * 10^{25} 3.8∗1025FLOPS
数据量 15.6 15.6 15.6 T(llama 2的50倍)
支持上下文 128 128 128K(影响RAG的使用)
特点 多语言、代码、推理、工具使用
训练架构 Dense Transformer, 没有用MOE, 提升稳定性
训练方法 Pre-train (初始预训练、上下文预训练, 退火训练) Post train(没有采用RLHF(基于人类反馈的强化学习),用的SFT(监督微调), RS(拒绝采样), DPO(直接偏好优化))
模型版本 Pre-trained Llama3(8B、70B、405B) Post-trained Llama3 Instruct(8B、70B、405B) Llama Guard 3 (输入输出方面加了一些安全的策略)

Ok, let's go!

大纲如下:

  • [1. 写在前面](#1. 写在前面)
  • [2. What](#2. What)
  • [3. Why](#3. Why)
  • [4. How](#4. How)
    • [4.1 预训练](#4.1 预训练)
      • [4.1.1 预训练数据](#4.1.1 预训练数据)
        • [4.1.1.1 数据清洗](#4.1.1.1 数据清洗)
        • [4.1.1.2 数据混合](#4.1.1.2 数据混合)
        • [4.1.1.3 退火数据](#4.1.1.3 退火数据)
      • [4.1.2 模型架构](#4.1.2 模型架构)
        • [4.1.2.1 Scaling Laws](#4.1.2.1 Scaling Laws)
      • [4.1.3 训练的基础配置](#4.1.3 训练的基础配置)
      • [4.1.4 训练丹方](#4.1.4 训练丹方)
    • [4.2 后训练](#4.2 后训练)
      • [4.2.1 奖励模型](#4.2.1 奖励模型)
      • [4.2.2 监督微调(SFT)](#4.2.2 监督微调(SFT))
      • [4.2.3 DPO训练优化](#4.2.3 DPO训练优化)
      • [4.2.4 模型平均](#4.2.4 模型平均)
      • [4.2.5 能力](#4.2.5 能力)
  • [5. 应用](#5. 应用)

2. What

基础模型是通用的语言、视觉、语音或其他模态的模型,用于支持各种各样的AI任务。

现代基础模型的发展包含两个阶段:

  1. 预训练阶段: 此阶段中,模型使用 诸如下一个词预测或标题生成等直接任务在大规模数据上训练(GPT1,2,3)
  2. 后训练阶段: 此阶段,模型被调整以遵循指令、与人类偏好保持一致, 并提升特定的能力,比如编码和推理(Instruct GPT)

Llama3 模型群是一套新的基础语言模型,该模型支持多语言、编程、推理和工具使用。最大的模型有405B参数, 在128K token的山下文窗口中处理信息。

开发高质量的基础模型,有3个关键因素:数据、规模、和复杂度管理。在Llama模型更新迭代过程,注重在3个关键因素上进行发力:

  1. 数据: 相比于早期的Llama模型,提升了预训练和后训练的数据量和质量, 为预训练数据开发了更细致的训练过程,后训练采用了更严格的数据过滤方法。 Llama 3是15T的多语言语料库进行的预训练, Llama 2是1.8T的语料库进行的训练。
  2. 规模 :Llama3系列模型的参数规模远大于Llama 2, 最大的一个模型参数405B,计算量 3.8 ∗ 1 0 25 3.8*10^{25} 3.8∗1025 FLOPs, 是Lamma 2的50倍。
  3. 复杂度管理:用稠密的Transformer架构而不是MOE架构,后处理采用常规的监督学习微调(SFT),拒绝采样(rejection sampling, RS), 直接偏好优化(direct preference optimization, DPO), 相对简单,较稳定,容易扩展

3. Why

那么Llama3为啥会这么火呢? 首先看下Llama 3模型的表现

  • 多领域任务表现突出 :在多个基准测试中, Llama 3展示了强大的性能。 在常识推理,知识问答,阅读理解,数学推理,编码等任务,不同参数规模的Llama 3都有出色的表现。 Llama 3 405B的模型在多个任务上与GPT4表现相当
    • MMLU: 各种考试里面的多选题,考察的是知识面,各种知识记住了就会比较猛,模型越大可能越占优势, 近几年大模型可能overfit了

    • IFEval: 看模型对各种指令的一个理解能力

      llama模型在数学,代码,推理等方面的能力很强

  • 多语言能力 :支持至少8种语言,在多语言任务上表现良好
  • 工具使用与多模态扩展:能使用多种工具,如搜索引擎,pyhon解释器和数学计算引擎等,零样本工具使用任务上表现出色。 通过组合方法为模型添加图像,视频和语音等能力
  • 开源
  • 支持128K上下文

整体来看,70B其实就够用, 各个任务上,8B和70B差的有点大,但70B和405B就没有那么大差距了。但405B的模型训练和推理成本会比较大,70B是一个性价比较高的模型。

4. How

Llama 3的模型架构

开发主要包括两个阶段:

  1. 预训练阶段: 首先将一个大型的、多语言的文本语料库转换为离散的token,在得到的数据上执行下一个token预测的任务预训练一个大型语言模型(LLM)。在语言模型预训练阶段,模型学习语言的结构,并从它正在 "阅读" 的文本中获得大量关于世界的知识。
  2. 后训练阶段:预训练的语言模型对语言有丰富的理解,但它还没有遵循指令或以我们期望的方式行事。在几轮中将模型与人类反馈对齐,每轮都涉及对指令调整数据和直接偏好优化的监督微调。

由此产生的模型具有丰富的功能集。他们可以用至少八种语言回答问题,编写高质量的代码,解决复杂的推理问题,并使用开箱即用或zero-shot的方式使用工具。

接下来,还加了图像、视频、语音等能力。

  • 多模态编码器预训练:在大量图像 - 文本对上训练图像编码器。这教会模型视觉内容和自然语言中对该内容的描述之间的关系
  • 视觉适配器训练:训练一个适配器,将预训练的图像编码器集成到预训练的语言模型中。
    • 适配器由一系列交叉注意力层组成,这些层将图像编码器表示输入到语言模型中。
    • 适配器在文本 - 图像对上进行训练。使图像表示与语言表示对齐。在适配器训练期间,更新图像编码器的参数,冻结语言模型参数。
    • 我们还在成对的视频 - 文本数据上训练图像适配器之上的视频适配器。这使模型能够跨帧聚合信息。
  • 语音适配器训练:通过适配器将语音编码器集成到模型中,该适配器将语音编码转换为可以直接馈送到微调语言模型中的令牌表示。适配器和编码器的参数在监督微调阶段联合更新,以实现高质量的语音理解

下面是每个步骤的详细过程。

4.1 预训练

预训练涉及步骤:

  1. 大规模语料库的管理和过滤
  2. 模型架构的开发和确定模型大小以及模型用多少数据的scaling laws找出来
  3. 大规模高效预训练技术的开发
  4. 预训练流程的原则

4.1.1 预训练数据

Llama 3 成功很大的一个原因是Data方面的工作做了很多,清洗,构造,配比等。 数据源 : 2023年底之前的网上的数据源

4.1.1.1 数据清洗

使用了多个去重的策略以及数据清理策略, 提高数据质量

  1. PII和安全过滤

    1. 非法网站、敏感网站
    2. 大量personally identifiable informationPII(个人隐私)的网站
    3. Meta安全标准被列为有害的域
    4. 成人内容
  2. 文本的提取和清洗

    1. 网页数据 -> html解析器, 抽取文章,去除广告,导航栏之类的信息
    2. 保留图片的文本属性
    3. 保留数学内容(OCR识别)和代码内容
    4. 去除标记符
    5. Markdown有害,也去除 ? 模型回答的公式,图片等都是markdown, 训练如果不用markdown的话,感觉不太合理
  3. 数据去重:

    1. url层面: 同一个url可能爬了多次,保留最新的
    2. 文档层面: 相似的文档只保留一份, 采用min-hash技术(求两篇文章的Jaccard距离)
    3. 行级别的去重: 3000w个文档里面,每一行提取,如果一行的文字出现超过6次,移除掉。目的: 删除网站的残留模板,导航,警告等
  4. 启发式过滤:删除额外的低质量文档、异常值和重复过多的文档

    1. 用重复的 n-gram 覆盖率来删除包含重复内容(如日志记录或错误消息), 上面行过滤可能过滤不掉
    2. 脏话统计, 有太多脏字的文档去掉
    3. 与训练语料库分布相比,统计每篇文章token分布 Kullback-Leibler KL散度来过滤掉包含过多异常值token的文档(如果某篇文章里面的token分布和其他文章相差的很远,认为是异常token)
  5. 基于模型的质量筛选: 训练一些模型分类器,给处理之后的文档进行打分 或 打上tag,根据实验结果进行更细的过滤

    1. 分类模型: fasttext模型, 识别是否被维基百科引用,如果是,质量高
    2. Llama 2标注的网页数据(提示词工程),训练DistilledRoberta模型(Bert的简化版本), 做二分类, 直接判断质量高低
    3. 上面两个分类模型的结果取并集,两个模型只要一个判定质量高,则保留
  6. 代码和推理的数据单独抽取处理:

    1. 构建了提取代码和数学相关网页的特定领域管道。
    2. 代码和推理分类器都是在 Llama 2 注释的 Web 数据上训练的 DistilRoberta 模型。与上面提到的通用质量分类器不同,对包含数学推导、STEM 区域中推理和与自然语言交错的代码的目标网页进行提示调整
  7. 多语言数据处理(非英语)

    1. 分类器,识别成176种语言
    2. 文档去重,行去重
    3. 特定语言的启发式方法和基于模型的过滤器删除低质量文档
    4. 基于Llama 2的多语言分类器对多语言文档质量排序,确保高质量的内容得到优先处理
4.1.1.2 数据混合

数据来自不同数据源, 每个数据源的数据遵循一个配比。

原则: 高质量的数据多用,低质量的数据少用

决定数据混合的方法是: 知识的分类和scaling law的实验

  • 知识的分类: 开发了一个分类器来对网络数据中包含的信息类型进行分类(类似于打一个tag),以更有效地确定数据组合。我们使用这个分类器对网络上表现过度的数据类别进行下采样,例如艺术和娱乐。
  • Scaling law: 为了确定最佳数据混合,我们执行缩放定律实验,在该实验中,在数据混合上训练几个小模型,并用它来预测大模型在该混合上的性能。对于不同的数据混合,我们多次重复此过程以选择新的数据混合候选。随后,我们在该候选数据混合上训练一个更大的模型,并在几个关键基准上评估该模型的性能。

最终的数据组合包含大约 50% 的对应于一般知识的token、25% 的数学和推理token、17% 的代码token和 8% 的多语言token

4.1.1.3 退火数据

根据经验,我们发现对少量高质量代码和数学数据进行退火可以提高关键基准测试上预训练模型的性能。

退火的意思是说: 对少量高质量的代码、数学数据(比如学术的数据集,评测榜的评测集)单独抽出来, 在经过大量网页数据训练完之后, 再把模型的学习率打开, 再用这些高质量的数据再训练一下(类似于考试之前背一下考题)。 如果这些数据直接和大量网页数据混合起来, 基本上就看不见了。

退火的策略还可以帮助去评估一个训练集的好坏:

  • 假设已经有一个在DataA上训练到一半的模型, 现在要评估DataB对模型是否有用
  • 70%的DataA + 30%的DataB 混合成一个新的DataC数据集,采用退火的方式再训练下模型,看模型效果是否有提升

4.1.2 模型架构

Llama3采用的稠密的,标准的Transformer架构(解码器),和Llama2基本保持一致。

405B这个规模的dense模型在如今MoE的潮流中显得有些"复古"。Meta对此给出了解释:不做成MoE模型是因为要追求能力的最大化(通常来说,相同总参数量下dense模型还是比MoE要强一些的),同时使用标准的Transformer模型可以让训练更加稳定,毕竟这样大规模的训练成本巨大,如果中间训炸了还是比较麻烦的。包括在post-training中使用的supervised finetuning(SFT),rejection sampling(RS),and direct preference optimization(DPO)都是经受住了许多考验,证明有效的方案。看起来这里在路线的选择上,Meta倾向于保守一些。 当然,还有内幕是说,当时训练MOE架构,但没训练收敛,所以没放出来。

和2017年的transformer块有点小区别,比如位置编码采用了一种循环编码的方式, LayerNorm换成了RMSNorm, ReLU激活函数换成了SwiGLU等。

性能的提升主要是来自于数据的质量,数量和多样性, 训练规模变大, 对模型结构并没有做fancy的调整。

小的修改如下:

  1. 使用分组查询注意力(GQA)和 8 个键值头,以提高推理速度并减少解码期间键值缓存的大小。理解Attention:从起源到MHA,MQA和GQA

    1. GQA的原理:主要基于对传统多头注意力机制(Multi-Head Attention,MHA)的改进,旨在在保持模型性能的同时,提升推理速度和减少内存占用

    2. 在传统的 MHA 中,每个头都有独立的查询(Query)、键(Key)和值(Value)线性投影。对于输入序列,会为每个头分别计算注意力,这虽然能让模型从不同表示子空间捕捉信息,但随着头数增多,计算量和内存需求大幅增加。例如,在处理长序列时,多头注意力计算的复杂度较高,会成为模型推理速度和内存使用效率的瓶颈。

    3. GQA 引入了分组的概念,在保持查询头数量不变的情况下,将键值头的数量进行分组。每个查询头可以访问所有组的键值信息,但在计算注意力时,不同查询头共享部分键值头的计算结果。这意味着,虽然查询头可以从多个不同的表示子空间获取信息,但键值头的计算被共享,从而减少了计算量和内存占用

      python 复制代码
      # 假设开始的输入 (1, 256, 64)
      
      # 计算得到Q, K, V
      # 8个Q, 2个K, 2个V
      query = torch.randn(1, 256, 8, 64)
      key = torch.randn(1, 256, 2, 64)
      value = torch.randn(1, 256, 2, 64)
      
      # Q分成四组
      num_head_groups = query.shape[2] // key.shape[2]
      
      # 维度交换,方便后面分组和计算
      query = rearrange(query, "b n h d -> b h n d")  # query:从(1, 256, 8, 64)变为(1, 8, 256, 64)。
      key = rearrange(key, "b s h d -> b h s d") # key:从(1, 256, 2, 64)变为(1, 2, 256, 64)
      value = rearrange(value, "b s h d -> b h s d")  # value:从(1, 256, 2, 64)变为(1, 2, 256, 64)
      
      # query:从(1, 8, 256, 64)变为(1, 4, 2, 256, 64)。
      # 这里将原来 8 个头的查询张量按照每组 2 个头分成了 4 组,增加了一个分组维度g。
      query = rearrange(query, "b (h g) n d -> b g h n d", g=num_head_groups) 
      
      # 计算分数
      # 经过einsum计算后维度为(1, 2, 256, 256)
      # 计算过程中,对于查询的每一组(4 组),每组内的 2 个头会与键的 2 个头进行矩阵乘法,最终通过求和等操作得到注意力分数。这里einsum操作会沿着最后一个维度(d)进行矩阵乘法,然后对分组维度g进行求和等操作,所以最终结果的头数量和键的头数量一致,为 2。
      scores = einsum(query, key, "b g h n d, b h s d -> b h n s") 
      
      scale = query.size(-1) ** 0.5
      attention = F.softmax(scores / scale, dim=-1) # (1, 2, 256, 256)
      out = einsum(attention, value, "b h n s, b h s d -> b h n d")  # (1, 2, 256, 64)
      out = rearrange(out, "b h n d -> b n h d")  # (1, 256, 2, 64)
      
      # 后两维reshape,再接FFN即可
      # 合并后两个维度
      merged_output = rearrange(gqa_output, 'b n h d -> b n (h d)')  # (1, 256, 128)
      
      # FFN
  2. 使用注意力掩码来防止同一序列中不同文档之间的self-attention。 即一个序列可能是多篇文档组成,计算self-att的时候,每篇文档只需要和内部的token计算attention即可。这种变化在标准预训练中影响有限,但发现它在非常长序列的持续预训练中很重要。

  3. 使用带有 128K token的词汇表。token词汇表将来自 tiktoken3 分词器的 100K token与 28K 额外token相结合,以更好地支持非英语语言,支持更多种类的词元输入。

  4. 提高了压缩率(1个token对应的字符个数), 从llama 2的3.17/token -> 3.94/token, 意味着同样的文章,有更短的上下文, 对长文本的支持会更好,对RAG的支持会更好。

  5. 调整了旋转位置编码(RoPE)的基础频率,以更好地处理长序列输入。把RoPE的base frequency增大到500,000,按《Effective long-context scaling of foundation models》的结果,这个数值足够支持32,768长度的窗口了, 位置编码可以参考理解LLM位置编码:RoPE

4.1.2.1 Scaling Laws

Scaling Laws(缩放定律)是描述大型语言模型性能与模型规模(参数数量)、训练数据量、计算资源消耗之间量化关系的规律。

一定范围内,模型性能(以损失函数衡量)随参数数量、训练数据量、计算资源增加而提升,通常呈幂律下降关系。但参数和数据量增加到一定程度,性能提升会变缓。

应用:

  • 辅助模型设计,根据资源和目标预测所需参数和数据量。
  • 指导训练资源分配,确定不同阶段资源投入及参数与数据扩充比例。
  • 作为性能评估基准,判断模型设计合理性与提升空间。

论文中通过scaling laws去确定模型的大小, 一般是拿小模型上的表现来推测下在大模型上的表现

论文指出之前用scaling law去指导模型大小存在的问题:

  • 现有的扩展定律通常只预测下一个标记的预测损失,而不是特定的基准性能。
  • 缩放定律可能具有噪声和不可靠性,因为它们是基于以较小的计算预算进行的预训练运行开发的

论文中实施了一种两阶段方法来制定准确预测下游基准性能的扩展定律:

  1. 首先在计算最优模型对下游任务(可以是任意关心的特定任务)的负对数似然(预测下一个token的loss误差)与训练 FLOPS 之间建立相关性。
  2. 接下来,我们将下游任务的负对数似然与任务准确性相关联(这样就和模型解耦开了),同时使用缩放定律模型和用更高计算 FLOP 训练的旧模型。在此步骤中,我们专门利用了 Llama 2 系列模型的训练结果。

下面是预测的结果:

左图: 不同算力规模下, 不同参数的模型,训练token的大小与损失的曲线对比图(每个点是一个特定的模型大小),结论如下:

  • 同一算力的条件下, 模型参数越小, 可训练的tokens越大, 此时的验证损失先变小后增大 ,有个平衡点。 下面给出了一个推导, 来得出了模型算力,token数和模型尺寸3者的一个关系。结论: F = 6 N D F=6ND F=6ND

  • 算力不断翻倍的情况下, 验证loss整体是一个不断下降的结果, 最终假设token数量和算力之间符合这样的一个关系:

  • 模型变得比较大的时候, training tokens与验证损失的变化曲线相对平衡了,意思是最佳的模型大小或最佳样本大小相对来说, 大一点小一点,没多大关系,不需要那么精确

有了上面的结论, 对模型的尺寸指导有啥意义呢? 根据左图, 可以找到不同算力下面的最佳的平衡点, 根据最佳平衡点, 拟合等式,求出未知数 A A A和 a a a( A A A和 a a a的拟合,不同的模型不一样,跟数据和模型超参有关系,整体符合这个规律,需要不停的实验)

有了token数N和算力C之后,根据之前的公式 C = 6 N M C = 6NM C=6NM, 就能大概推算出我的模型参数量大概是多少,确定大概能训练多大的模型。

  1. 不断调整算力预算,每一次预算固定之后, 在小模型上进行实验, 把token和模型的平衡点找到,最后得到一条平衡曲线, 根据平衡曲线把A和a拟合出来。
  2. 根据最终的算力预算,计算出能使用的token数量。
  3. 根据token数量, 估算出训练的模型大小。

最后估算出Llama3的405B模型的大小。上面这个思想很重要

右边图是算力和训练数据之间的关系,大概能估算出, 我有这么多的算力,大概需要多少训练数据才能使得模型收敛。

接下来,是在一个特定的benchmark上(推理上一个评测集ARC challenge)的一个应用

左边图是用了几个scaling law的小模型, 可视化了算力与验证loss的一个趋势, 结论:

  • 算力越大,模型越大, loss线性下降。
  • 通过小模型去预测405B模型的算力与loss的关系,预测出的scaling law 与真实的405B的差距还是很小的
    右边图是验证loss与准确率之间的一个关系图:
  • 依然是用了llama3的几个小模型, Llama 2的模型,预测出了405B模型的loss与准确率的关系, 相差也不大。 并且后面开始收敛了,再增大模型提升效果不明显了。

通过上面的实验还可以得出:

  • 不同的benchmark的曲线变化不一样, 论文里面认为针对特定benchmark做曲线拟合很有必要

4.1.3 训练的基础配置

这块简单过一下, 一般没有资源支持这种大规模训练。

  • 计算资源: 16K张 H100, 80G Nvlink, 一机8卡
  • 存储: 240PB的SSD, 7000台机器存储
  • 网络设计,负载均衡等
  • 训练时候的各种并行方式
    • 张量并行TP:个别权重的张量分割成不同设备的多个块
    • 流水线并行PP: 通过层降模型垂直划分成不同阶段,在不同设备并行处理完整模型流水线的不同阶段
    • 数据并行DP: 对模型、优化器和梯度分片,数据分片,在多个GPU上处理
    • 上下文并行CP: 对输入上下文分成段, 减少长序列输入的内存瓶颈

4.1.4 训练丹方

这个还是很重要的。主要说了Llama3是怎么预训练出来的。

主要有3步:

  1. 初始的预训练
    1. A d a m W AdamW AdamW算法,对Llama3 405B预训练,峰值学习率 8 ∗ 1 0 − 5 8*10^{-5} 8∗10−5, 前8000steps预热(学习率0线性到峰值), 后面cos下降到 8 ∗ 1 0 − 7 8*10^{-7} 8∗10−7(12000 steps)
    2. 训练早期使用较小的 batch size 来提高训练稳定性,随后增加 batch size 以提高效率
      1. 初始: 4M个token的batch size, 长度4096
      2. 预训练252M token之后, batch size 变成 8M个token, 序列长度8192
      3. 预训练2.87 T token之后, batch size 变成 16M 的token, 序列长度8192(目前大部分语言模型支持的序列长度是8K的原因,就是训练的时候用的8K的序列训练的)
    3. 调整数据组合:训练期间对训练前数据组合进行了一些调整,以提高模型在特定下游任务上的性能,在预训练期间增加了非英语数据的百分比,以提高 Llama 3 的多语言性能。我们还对数学数据进行上采样以提高模型的数学推理性能,在预训练的后期阶段添加了更新的 Web 数据,以提高模型的知识截止率,并对后来被确定为质量较低的预训练数据的子集进行下采样
  2. 长上下文的预训练
    1. 在预训练的最后阶段,在长序列上进行训练,以支持高达 128K 令牌的上下文窗口
    2. 之前不会对长序列进行训练,因为自我注意层中的计算在序列长度上呈二次方增长
    3. 逐步增加支持的上下文长度,进行预训练,直到模型成功适应增加的上下文长度。
    4. 基于两个原则去慢慢增大上下文窗口
      1. 模型在短期评估中的性能是否完全恢复
      2. 模型完美地解决了该长度的 "大海捞针" 任务(从很长的序列里面提取关键信息)。
    5. 在 Llama 3 405B 预训练中,我们从最初的 8K 上下文窗口开始,到最终的 128K 上下文窗口结束,分六个阶段逐渐增加上下文长度。这个长上下文预训练阶段使用大约 800B 训练token执行。
  3. 退火训练(类似于拿评测集拟合下)
    1. 最后 40M token进行预训练期间,我们将学习率线性退火为 0,保持 128K token的上下文长度
    2. 在这个退火阶段,我们还调整了数据组合配比,增加高质量数据数学、代码、逻辑内容的影响(对非常高质量的数据源进行上采样)
    3. 最后,我们计算退火过程中模型检查点的平均值 ,以产生最终的预训练模型(最后的模型参数,是后期多个模型的参数平均值得到的,不是使用的最后一次检查点的模型, 竞赛刷磅的常用技巧)。 Polyak 平均(也称为 Polyak - Ruppert 平均)的核心思想是在模型训练过程中,对多个不同时间点(通常是在训练的后期阶段,即退火过程中)保存的模型参数进行加权平均,而不是直接使用最后一个训练得到的模型参数。这种方法有助于平滑模型在训练过程中的波动,减少噪声的影响,从而提高模型的泛化性能。

这里面能get到的一些经验总结:

  1. 数据清洗方面的一些经验,去重,配比等
  2. 使用退火来发现有价值的预训练数据
  3. 长文本的curriculum learning,逐步扩展
  4. 通过scaling law把FLOPs和下游任务效果关联起来,但是这个成本比较高,一般机构直接用结果就行了
  5. 基于和下游任务效果关联的scaling law选择data mix,同样是大力出奇迹,all you need is money
  6. checkpoint average,和苹果用到的model soup类似,是个值得关注的技巧

4.2 后训练

步骤:

  1. 人工标注数据训练Reward Model,用来评价一个<Prompt, answer>的数据质量
  2. 用预训练的模型生成若干个回答
  3. RM模型对回答进行拒绝采样,RM对生成的内容给予质量打分,选择得分最高的保留,作为SFT数据,其他的丢掉
  4. 上述SFT数据,再加上专门的增强代码、数学、逻辑能力的SFT数据混合,调整预训练模型得到SFT模型
  5. DPO训练调整LLM参数

上面完成了一个迭代轮次的Post-Training, 总共6轮迭代。

4.2.1 奖励模型

reward model(RM)是post-training中的一个重要部分。训练一个奖励模型(Reward Model):在预训练检查点的基础上,利用人类标注的偏好数据来训练奖励模型。

用预训练的语言模型权重作为初始化,再添加一个额外的线性层,将模型的输出映射为一个标量奖励值,该值代表了人类对整个文本序列(prompt + resp1 + resp2 + resp3)所体现的响应偏好程度。

和Llama-2相比,这次RM的一个变化是移除了训练时加入的margin term(用于把chosen和rejected response区分得更开),因为随着模型规模的增大,加入margin term收益越来越小了。

另一方面,同Llama-2一样,preference data中只有区分度比较大的数据对用于训练RM。

数据上,除了常规的chosen和rejected response之外,还引入了第三种 -- "edited response",即在chosen的基础上通过(人工)编辑,进一步提升这条response的质量。这样每条ranking sample就可能有3条response(edited > chosen > rejected)。

训练的时候,prompt和对应的多条随机打乱的response拼接在一起训练(prompt + resp_1 + resp_2 + resp_3),这和通常的做法,即每个response都拼接prompt有些不同(prompt + resp_1, prompt + resp_2, prompt + resp_3)。从结果上来看,都拼接到一起在accuracy上没有什么损失,而训练效率更高。

训练损失一般是采用对比损失, 好的resp的分数要比差的resp的分数高一个阈值。

推理的时候, 输入是<prompt, resp>, 模型会其打分, 表示符合人类偏好的程度。

4.2.2 监督微调(SFT)

训练好的RM模型会用于拒绝采样, 对生成的结果过滤。得到的高质量数据会和其他来源的SFT数据一起用于微调模型。

SFT数据主要有这几个来源:

  • 人工收集的prompt,以及对应的通过拒绝采样得到的response
  • 特定领域的合成数据(后面capacities部分会讲到)
  • 少量人类真实数据

处理细节:

  1. RS(拒绝采样)阶段,每个prompt会从"最新的chat模型中"采样K个回复(一般10-30个), 然后RM模型选出最佳回复
  2. 在靠后轮次的post-training中,RS引入了控制风格、格式、语气等特性的system prompt已更精细控制数据质量
  3. 不同领域(代码,推理,工具使用等)可能采用不同prompt

    SFT训练的时候, l r = 1 ∗ 1 0 − 5 lr=1*10^{-5} lr=1∗10−5, steps: 8.5 k − 9 k 8.5k-9k 8.5k−9k。 共使用了近2kw个QA对。

数据处理和质量控制:

  • post-training的前几轮中,研究人员发现数据中混入一些包含过量表情,感叹号值类的数据,因此,用专门规则对低质量的pattern进行清洗。
  • model-base的方法过滤低质量数据
    • 话题分类: llama 3 8B做粗粒度 & 细粒度的领域分类
    • 质量打分:
      • RM模型用于识别高质量回复,只有RM得分在前四分之一的数据认为高质量
      • 基于llama 3 checkpoint, 使用特定prompt(不同领域不同)进行多方面打分,得分高的认为高质量
      • 上面两者取并集
    • 难度打分:用了Instag和Llama模型打分两种方式衡量数据难度
    • 语义去重:用Roberta对对话聚类,每个类别中按quality score * difficulty进行排序, 然后只保留"和已选的高质量数据相似度小于阈值"的样本

4.2.3 DPO训练优化

DPO 是一种用于使模型的输出与用户偏好保持一致的方法。在后训练中,DPO 训练精选最新能力批次的数据。

在DPO阶段,会用在上一轮post-training得到的最佳模型收集偏好数据对,这样能使得偏好数据的分布和强化学习时的policy model更一致。

除了DPO以外,Meta也尝试了一些on-policy的方案,如PPO。但是相对来说,DPO消耗更少的计算资源,并且效果也更好,特别是在instruction following的能力上,所以还是选择在post-training使用DPO。

具体步骤如下:

  1. 数据准备:在预训练检查点的基础上,利用人类标注的偏好数据来训练直接模型。人类标注的结果会产生选中、拒绝和 "编辑后"(选中的响应经过编辑以使其更好)的响应,偏好顺序为编辑后 > 选中 > 拒绝,收集人类偏好数据,即包含选中响应和拒绝响应的数据集,但不需要对数据进行用于训练奖励模型的额外处理。
  2. 直接优化:直接使用偏好数据对模型进行优化,通过最小化特定的损失函数来更新模型参数,实现过程相对简洁。
    DOP训练中,使用的 L R = 1 ∗ 1 0 − 5 , b e t a = 0.1 LR=1*10^{-5}, beta=0.1 LR=1∗10−5,beta=0.1

训练中做了一些不同于标准做法的改动:

  1. 在DPO损失中, 把特殊token,比如起止符屏蔽,不用于计算loss,因为使用这些token计算loss会使得模型在生成时,出现如复读机或者在不合适的地方截断的情况。这可能就是因为chosen repsponse和rejected response同时包含的这些特殊token,让模型在训练时要同时增大和较小它们的likelihood,导致冲突。
  2. 除了DPO常用的loss, meta额外加入了NLL损失,模型生成的文本在语法、语义上更加准确和连贯,同时也能满足人类偏好的要求。

4.2.4 模型平均

在RM、SFT和DPO阶段,分别把"用不同版本的数据和超参训练得到模型"进行平均,以获得最终模型。

4.2.5 能力

在不同的具体领域上,Meta分别有一套方法,来提升对应的能力。能力这块不详细整理了,感兴趣的可以参考原论文或这篇文章

5. 应用

应用这块还没来得及仔细研究,可以先通过Ollma的方式把模型下载下来,然后直接跑起来。

https://ollama.com/,下载ollama(Ollama 是一个开源的大型语言模型服务工具,旨在帮助用户快速在本地运行大模型)

python 复制代码
# linux
curl -fsSL https://ollama.com/install.sh | sh

# 下载模型
ollama run llama3:8b

参考

相关推荐
188_djh4 小时前
# 10分钟了解DeepSeek,保姆级部署DeepSeek到WPS,实现AI赋能
人工智能·大语言模型·wps·ai技术·ai应用·deepseek·ai知识
DeepDriving2 天前
纯新手教程:用llama.cpp本地部署DeepSeek蒸馏模型
大语言模型·deepseek
风起晨曦2 天前
LLaMa Factory 安装
llama
运维开发王义杰2 天前
AI: Unsloth + Llama 3 微调实践,基于Colab
人工智能·llama
风起晨曦2 天前
(LLaMa Factory)大模型训练方法--预训练(Qwen2-0.5B)
llama
风起晨曦2 天前
(LLaMa Factory)大模型训练方法--监督微调(Qwen2-0.5B)
llama
子诚之2 天前
大模型Deepseek的使用_基于阿里云百炼和Chatbox
大语言模型
mygodalien3 天前
Win7编译GPU版llama.cpp部署deepseek-r1等大模型记录
人工智能·机器学习·chatgpt·llama
shandianchengzi3 天前
【BUG】LLM|Ubuntu 用 ollama 部署 DeepSeek 但没输出,llama 有输出
ubuntu·llm·bug·llama·ollama·deepseek