llama源码学习·model.py[6]TransformerBlock类

一、源码摘录

python 复制代码
class TransformerBlock(nn.Module):
    def __init__(self, layer_id: int, args: ModelArgs):
        super().__init__()
        self.n_heads = args.n_heads
        self.dim = args.dim
        self.head_dim = args.dim // args.n_heads
        self.attention = Attention(args)
        self.feed_forward = FeedForward(
            dim=args.dim,
            hidden_dim=4 * args.dim,
            multiple_of=args.multiple_of,
            ffn_dim_multiplier=args.ffn_dim_multiplier,
        )
        self.layer_id = layer_id
        self.attention_norm = RMSNorm(args.dim, eps=args.norm_eps)
        self.ffn_norm = RMSNorm(args.dim, eps=args.norm_eps)
 
    def forward(
        self,
        x: torch.Tensor,
        start_pos: int,
        freqs_cis: torch.Tensor,
        mask: Optional[torch.Tensor],
    ):
        h = x + self.attention(self.attention_norm(x), start_pos, freqs_cis, mask)
        out = h + self.feed_forward(self.ffn_norm(h))
        return out

二、Transformer Block作用

这个TransformerBlock类的设计允许多个这样的块可以堆叠在一起,形成一个深度的Transformer网络。每个块的输出会被用作下一个块的输入,这样的设计使得Transformer能够处理非常复杂的序列建模任务。

三、代码注释

python 复制代码
class TransformerBlock(nn.Module):
    def __init__(self, layer_id: int, args: ModelArgs):
        super().__init__()
        self.n_heads = args.n_heads
        self.dim = args.dim
        self.head_dim = args.dim // self.n_heads
python 复制代码
        # 一个多头注意力模块,用于对输入执行自注意力操作。
        # 这个模块会计算输入的每个元素与其他元素之间的相互关系,并将这些关系用于更新输入。
        self.attention = Attention(args)
python 复制代码
        # 一个前馈神经网络模块,包含一个 SwiGLU 激活函数 和一个线性层。
        self.feed_forward = FeedForward(
            dim = args.dim,
            hidden_dim = 4 * args.dim,
            multiple_of=args.multiple_of,
            ffn_dim_multiplier=args.ffn_dim_multiplier,
        )
        self.layer_id = layer_id
python 复制代码
        # RMS归一化层,对注意力的输出进行归一化。
        self.attention_norm = RMSNorm(args.dim, eps=args.norm_eps)
python 复制代码
        # RMS归一化层,对前馈神经网络的输出进行归一化。
        self.ffn_norm = RMSNorm(args.dim, eps=args.norm_eps)
python 复制代码
    def forward(
        self,
        x: torch.Tensor,
        start_pos: int,
        freqs_cis: torch.Tensor, # 旋转矩阵
        mask: Optional[torch.Tensor],
    ):
python 复制代码
        # 残差连接
        # 将注意力模块的输出与原始输入x相加,形成一个残差连接。这是一种常见的深度学习技术,
        # 可以帮助减少训练深层网络时的梯度消失问题。
        h = x + self.attention.forward(
            self.attention_norm(x), # 对输入x进行归一化,然后将归一化的x传递给注意力模块。
            start_pos, # 开始的位置
            freqs_cis, # 频率
            mask,
        )
python 复制代码
        # 对结果h进行归一化,然后传递给前馈神经网络模块。
        # 前馈神经网络模块对其输入进行进一步的转换,并将输出与h相加,形成另一个残差连接。
        out = h + self.feed_forward.forward(self.ffn_norm(h))
python 复制代码
        # 这个out将被用作下一个Transformer块的输入
        return out
相关推荐
DigitalOcean3 天前
DigitalOcean Gradient AI 推理云平台原生集成 LlamaIndex
llama
西岸行者9 天前
学习笔记:SKILLS 能帮助更好的vibe coding
笔记·学习
悠哉悠哉愿意9 天前
【单片机学习笔记】串口、超声波、NE555的同时使用
笔记·单片机·学习
别催小唐敲代码9 天前
嵌入式学习路线
学习
毛小茛9 天前
计算机系统概论——校验码
学习
babe小鑫9 天前
大专经济信息管理专业学习数据分析的必要性
学习·数据挖掘·数据分析
winfreedoms9 天前
ROS2知识大白话
笔记·学习·ros2
在这habit之下9 天前
Linux Virtual Server(LVS)学习总结
linux·学习·lvs
我想我不够好。9 天前
2026.2.25监控学习
学习
im_AMBER9 天前
Leetcode 127 删除有序数组中的重复项 | 删除有序数组中的重复项 II
数据结构·学习·算法·leetcode