【HuggingFace Transformers】LlamaModel源码解析

LlamaModel源码解析

  • [1. LlamaModel 介绍](#1. LlamaModel 介绍)
  • [2. LlamaModel类 源码解析](#2. LlamaModel类 源码解析)
  • [3. 4维因果注意力掩码生成](#3. 4维因果注意力掩码生成)

1. LlamaModel 介绍

LlamaModel 是一个基于 Transformer 架构的解码器模型,用于自然语言处理任务。它是 Meta 的 LLaMA (Large Language Model Meta AI) 系列的一部分,设计用于生成任务和自回归文本生成。它通过解码器层位置编码归一化层来处理输入序列,并提供了对缓存和注意力机制的支持。它在大规模自然语言生成任务中表现出色,并能够处理复杂的序列依赖关系。其结构如下:

2. LlamaModel类 源码解析

源码地址:transformers/src/transformers/models/llama/modeling_llama.py

python 复制代码
# -*- coding: utf-8 -*-
# @time: 2024/8/28 14:36

import torch

from typing import List, Optional, Tuple, Union
from torch import nn
from transformers import LlamaPreTrainedModel, LlamaConfig, Cache, DynamicCache, StaticCache
from transformers.modeling_attn_mask_utils import AttentionMaskConverter
from transformers.modeling_outputs import BaseModelOutputWithPast
from transformers.models.llama.modeling_llama import LlamaDecoderLayer, LlamaRMSNorm, LlamaRotaryEmbedding, LLAMA_START_DOCSTRING, LLAMA_INPUTS_DOCSTRING
from transformers.utils import logging, add_start_docstrings, add_start_docstrings_to_model_forward

logger = logging.get_logger(__name__)


@add_start_docstrings(
    "The bare LLaMA Model outputting raw hidden-states without any specific head on top.",
    LLAMA_START_DOCSTRING,
)
class LlamaModel(LlamaPreTrainedModel):
    """
    Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`LlamaDecoderLayer`]

    Args:
        config: LlamaConfig
    """

    def __init__(self, config: LlamaConfig):
        super().__init__(config)
        self.padding_idx = config.pad_token_id  # 设置 padding token 的索引
        self.vocab_size = config.vocab_size  # 设置词汇表的大小

        # 1. 定义嵌入层:将输入的 token 转换为隐状态向量。它的大小为 vocab_size x hidden_size
        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
        # 2. 定义解码层:使用 nn.ModuleList 定义了一系列的 LlamaDecoderLayer,解码层的数量由 config.num_hidden_layers 决定。
        self.layers = nn.ModuleList(
            [LlamaDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
        )
        # 3. 定义规范化层:使用 LlamaRMSNorm 进行层归一化处理
        self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        # 4. 定义旋转嵌入:使用 LlamaRotaryEmbedding 实现旋转嵌入,用于改进注意力机制中的位置编码
        self.rotary_emb = LlamaRotaryEmbedding(config=config)

        # 梯度检查点:用于在训练过程中节省内存的功能,默认为 False。
        self.gradient_checkpointing = False

        # Initialize weights and apply final processing
        self.post_init()  # 初始化权重并进行最终处理

    def get_input_embeddings(self):
        return self.embed_tokens

    def set_input_embeddings(self, value):
        self.embed_tokens = value

    @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
    def forward(
        self,
        input_ids: torch.LongTensor = None,  # 输入的 token ID
        attention_mask: Optional[torch.Tensor] = None,  # 注意力掩码
        position_ids: Optional[torch.LongTensor] = None,  # 位置 ID
        past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,  # 缓存的 key-value 对
        inputs_embeds: Optional[torch.FloatTensor] = None,  # 输入嵌入
        use_cache: Optional[bool] = None,  # 是否使用缓存
        output_attentions: Optional[bool] = None,  # 是否输出注意力权重
        output_hidden_states: Optional[bool] = None,  # 是否输出隐藏状态
        return_dict: Optional[bool] = None,  # 是否返回字典类型的输出
        cache_position: Optional[torch.LongTensor] = None,  # 缓存位置
    ) -> Union[Tuple, BaseModelOutputWithPast]:

        # -----------------------------1. 初始化一系列输入变量,用于 decoder_layer 的前向传播计算-------------------------------
        # 初始化 output_attentions / output_hidden_states / use_cache / return_dict
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # 输入验证:确保 input_ids 和 inputs_embeds 不能同时被指定,但必须指定其中之一。
        if (input_ids is None) ^ (inputs_embeds is not None):
            raise ValueError(
                "You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one"
            )

        # 如果训练过程中使用梯度检查点且使用缓存,则发出警告,并禁用缓存
        if self.gradient_checkpointing and self.training and use_cache:
            logger.warning_once(
                "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
            )
            use_cache = False

        # 嵌入计算:根据 input_ids 计算 inputs_embeds,如果已经提供 inputs_embeds,则使用该值。
        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)

        # 如果使用旧的缓存格式,将其转换为新的格式
        return_legacy_cache = False
        if (
            use_cache and not isinstance(past_key_values, Cache) and not self.training
        ):  # kept for BC (non `Cache` `past_key_values` inputs)
            return_legacy_cache = True
            past_key_values = DynamicCache.from_legacy_cache(past_key_values)
            logger.warning_once(
                "We detected that you are passing `past_key_values` as a tuple and this is deprecated and will be removed in v4.43. "
                "Please use an appropriate `Cache` class (https://huggingface.co/docs/transformers/internal/generation_utils#transformers.Cache)"
            )

        # 如果没有指定缓存位置,则根据已处理的 token 数量设置缓存位置
        if cache_position is None:
            past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
            cache_position = torch.arange(
                past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
            )

        # 如果没有指定位置 ID,则使用缓存位置作为位置 ID
        if position_ids is None:
            position_ids = cache_position.unsqueeze(0)

        # 更新因果掩码,用于确保解码器只看见当前时间步之前的 token
        causal_mask = self._update_causal_mask(
            attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
        )

        hidden_states = inputs_embeds

        # create position embeddings to be shared across the decoder layers
        # 位置编码:生成位置编码,用于结合输入嵌入进行旋转嵌入
        position_embeddings = self.rotary_emb(hidden_states, position_ids)

        # -----------------------------2. 初始化一系列输出变量,用于保存 decoder_layer 前向传播计算的输出结果-------------------------------
        # decoder layers
        all_hidden_states = () if output_hidden_states else None
        all_self_attns = () if output_attentions else None
        next_decoder_cache = None

        # -----------------------------3. 依次通过每一层,执行 decoder_layer 前向传播计算,同时更新相应的变量值-------------------------------
        # 遍历每一层解码器层,并将输入和注意力掩码传递给每一层
        for decoder_layer in self.layers:
            if output_hidden_states:
                all_hidden_states += (hidden_states,)

            # 如果使用梯度检查点,则使用特殊方法处理解码器层
            if self.gradient_checkpointing and self.training:
                layer_outputs = self._gradient_checkpointing_func(
                    decoder_layer.__call__,
                    hidden_states,
                    causal_mask,
                    position_ids,
                    past_key_values,
                    output_attentions,
                    use_cache,
                    cache_position,
                    position_embeddings,
                )
            else:
                # 否则,直接调用解码器层的前向传播方法
                layer_outputs = decoder_layer(
                    hidden_states,
                    attention_mask=causal_mask,
                    position_ids=position_ids,
                    past_key_value=past_key_values,
                    output_attentions=output_attentions,
                    use_cache=use_cache,
                    cache_position=cache_position,
                    position_embeddings=position_embeddings,
                )

            # 更新隐藏状态
            hidden_states = layer_outputs[0]

            # 如果使用缓存,则更新缓存
            if use_cache:
                next_decoder_cache = layer_outputs[2 if output_attentions else 1]

            # 如果输出注意力权重,则将其添加到所有注意力权重的列表中
            if output_attentions:
                all_self_attns += (layer_outputs[1],)

        # -----------------------------4. 获取相应的输出变量,根据条件进行处理后返回结果-------------------------------
        # 最后一层的隐藏状态经过归一化处理
        hidden_states = self.norm(hidden_states)

        # add hidden states from the last decoder layer
        # 如果输出隐藏状态,则将最终隐藏状态添加到元组中
        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        # 根据是否使用缓存,返回下一个缓存
        next_cache = next_decoder_cache if use_cache else None
        if return_legacy_cache:
            next_cache = next_cache.to_legacy_cache()

        # 输出处理
        # 如果不返回字典,则返回元组类型的输出
        if not return_dict:
            return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
        # 返回 BaseModelOutputWithPast 类型的字典,包括最后的隐藏状态、缓存、隐藏状态和注意力权重
        return BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=next_cache,
            hidden_states=all_hidden_states,
            attentions=all_self_attns,
        )

    # 更新因果掩码方法,确保模型只能看到当前时间步之前的 token
    def _update_causal_mask(
        self,
        attention_mask: torch.Tensor,
        input_tensor: torch.Tensor,
        cache_position: torch.Tensor,
        past_key_values: Cache,
        output_attentions: bool,
    ):
        # TODO: 自 torch==2.2.0 以来,传递给模型的 `attention_mask` 在生成过程中是 2D 的,并且长度是动态的,即使使用了静态 KV 缓存。
        # 这会导致 torch.compile 在每个解码步骤中重新捕获 cudagraphs,因为形状是动态的,这是非常慢的。
        # 一种解决方法是使用 `@torch.compiler.disable`,但这会阻止使用 `fullgraph=True`。
        # 更多背景信息可以参考 https://github.com/huggingface/transformers/pull/29114

        # --------------------------flash_attention_2 注意力和 sdpa 注意力的配置判断-------------------------------
        # 如果配置的注意力实现是 "flash_attention_2"
        if self.config._attn_implementation == "flash_attention_2":
            # 如果提供了 attention_mask 且其中包含 0.0,则直接返回 attention_mask;否则返回 None,表示不需要额外处理
            if attention_mask is not None and 0.0 in attention_mask:
                return attention_mask
            return None

        # 对于 SDPA(Scaled Dot-Product Attention),我们将依赖它的 `is_causal` 参数而不是 `attn_mask` 参数,以便分派到 Flash Attention 2 实现。这种特性与静态缓存不兼容,因为 SDPA 无法推断出注意力掩码。
        past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0  # 获取已经看到的 token 数量
        using_static_cache = isinstance(past_key_values, StaticCache)  # 检查是否使用静态缓存

        # 当 output_attentions 为 True 时,SDPA 实现的前向传播方法会调用 eager(迫切)实现的前向传播方法
        if self.config._attn_implementation == "sdpa" and not using_static_cache and not output_attentions:
            # 检查是否可以忽略 SDPA 的因果掩码
            if AttentionMaskConverter._ignore_causal_mask_sdpa(
                attention_mask,
                inputs_embeds=input_tensor,
                past_key_values_length=past_seen_tokens,
                is_training=self.training,
            ):
                return None

        # -------------------------初始化一系列输入变量,用于 4d_causal_attention_mask 的计算--------------------------
        dtype, device = input_tensor.dtype, input_tensor.device  # 获取输入张量的 dtype 和设备信息
        min_dtype = torch.finfo(dtype).min  # 获取 dtype 的最小值,用于填充掩码
        sequence_length = input_tensor.shape[1]  # 获取序列长度

        # 如果使用静态缓存,target_length为缓存中已看到的最大长度
        if using_static_cache:
            target_length = past_key_values.get_max_length()
        else:
            # 否则target_length为注意力掩码的最后一个维度长度,或者已看到的 token 数量加上当前序列长度再加 1
            target_length = (
                attention_mask.shape[-1]
                if isinstance(attention_mask, torch.Tensor)
                else past_seen_tokens + sequence_length + 1
            )

        # In case the provided `attention` mask is 2D, we generate a causal mask here (4D).
        # 提供的 `attention_mask` 是 2D 的,我们在这里生成一个因果掩码(4D 的)。
        causal_mask = _prepare_4d_causal_attention_mask_with_cache_position(
            attention_mask,
            sequence_length=sequence_length,
            target_length=target_length,
            dtype=dtype,
            device=device,
            min_dtype=min_dtype,
            cache_position=cache_position,
            batch_size=input_tensor.shape[0],
        )

        # --------------------输出结果 causal_mask 的进一步操作(可选)----------------------
        # 如果配置的注意力实现是 "sdpa",且 `attention_mask` 存在,并且设备类型为 CUDA 且不输出注意力权重
        if (
            self.config._attn_implementation == "sdpa"
            and attention_mask is not None
            and attention_mask.device.type == "cuda"
            and not output_attentions
        ):
            # 在因果掩码中完全掩盖的行中,使所有 token 可见,例如使用左填充时的相关第一行。这是为了适应 F.scaled_dot_product_attention 的内存高效路径。详情请参考:https://github.com/pytorch/pytorch/issues/110213
            causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)

        # ---------------------返回最终的因果掩码------------------------------------------
        return causal_mask

3. 4维因果注意力掩码生成

_prepare_4d_causal_attention_mask_with_cache_position 函数用于生成一个4维的因果注意力掩码(causal attention mask),适用于生成任务中的自回归解码器。这个掩码有助于确保模型在生成序列时仅能基于当前和之前的 token,而不查看未来的 token。以下是对该函数的源码解释:

python 复制代码
# -*- coding: utf-8 -*-
# @time: 2024/8/28 14:36

# 生成4D的因果注意力掩码方法
def _prepare_4d_causal_attention_mask_with_cache_position(
    attention_mask: torch.Tensor,
    sequence_length: int,
    target_length: int,
    dtype: torch.dtype,
    device: torch.device,
    min_dtype: float,
    cache_position: torch.Tensor,
    batch_size: int,
):
    """
    Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
    `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.

    Args:
        attention_mask (`torch.Tensor`):
            A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape `(batch_size, 1, query_length, key_value_length)`.
        sequence_length (`int`):
            The sequence length being processed.
        target_length (`int`):
            The target length: when generating with static cache, the mask should be as long as the static cache, to account for the 0 padding, the part of the cache that is not filled yet.
        dtype (`torch.dtype`):
            The dtype to use for the 4D attention mask.
        device (`torch.device`):
            The device to plcae the 4D attention mask on.
        min_dtype (`float`):
            The minimum value representable with the dtype `dtype`.
        cache_position (`torch.Tensor`):
            Indices depicting the position of the input sequence tokens in the sequence.
        batch_size (`torch.Tensor`):
            Batch size.
    """
    # 1. 检查掩码维度:如果输入的 attention_mask 是4维的,直接使用它作为 causal_mask,因为它已经是所需的形式。
    if attention_mask is not None and attention_mask.dim() == 4:
        # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
        causal_mask = attention_mask
    else:
        # 2. 生成默认4D掩码:创建一个全0的2D掩码,形状为 (sequence_length, target_length),并用 min_dtype 填充。
        causal_mask = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device)
        # 如果 sequence_length 不等于1,将掩码的上三角部分设置为 min_dtype,以创建因果掩码(上三角矩阵,确保每个位置只能关注自己及之前的位置)。
        if sequence_length != 1:
            causal_mask = torch.triu(causal_mask, diagonal=1)
        # 3. 调整掩码以考虑缓存位置:根据 cache_position 计算掩码的位置,causal_mask 的值将根据 cache_position 进行调整,以便正确处理缓存。
        causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
        # 4. 扩展掩码以适应批处理:将掩码扩展到4维,并适应批处理大小 (batch_size),最终的形状为 (batch_size, 1, sequence_length, target_length)。
        causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
        # 5. 融合外部注意力掩码:如果提供了外部的 attention_mask,则将其与生成的 causal_mask 结合。通过掩码位置设置正确的填充,以确保只关注有效位置。
        if attention_mask is not None:
            causal_mask = causal_mask.clone()  # copy to contiguous memory for in-place edit
            mask_length = attention_mask.shape[-1]
            padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :]
            padding_mask = padding_mask == 0
            causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
                padding_mask, min_dtype
            )

    return causal_mask
相关推荐
GIOTTO情2 分钟前
媒介宣发的技术革命:Infoseek如何用AI重构企业传播全链路
大数据·人工智能·重构
阿里云大数据AI技术11 分钟前
云栖实录 | 从多模态数据到 Physical AI,PAI 助力客户快速启动 Physical AI 实践
人工智能
小关会打代码18 分钟前
计算机视觉进阶教学之颜色识别
人工智能·计算机视觉
IT小哥哥呀24 分钟前
基于深度学习的数字图像分类实验与分析
人工智能·深度学习·分类
机器之心1 小时前
VAE时代终结?谢赛宁团队「RAE」登场,表征自编码器或成DiT训练新基石
人工智能·openai
机器之心1 小时前
Sutton判定「LLM是死胡同」后,新访谈揭示AI困境
人工智能·openai
大模型真好玩1 小时前
低代码Agent开发框架使用指南(四)—Coze大模型和插件参数配置最佳实践
人工智能·agent·coze
jerryinwuhan1 小时前
基于大语言模型(LLM)的城市时间、空间与情感交织分析:面向智能城市的情感动态预测与空间优化
人工智能·语言模型·自然语言处理
落雪财神意1 小时前
股指10月想法
大数据·人工智能·金融·区块链·期股
中杯可乐多加冰1 小时前
无代码开发实践|基于业务流能力快速开发市场监管系统,实现投诉处理快速响应
人工智能·低代码