llama源码学习·model.py[6]TransformerBlock类

一、源码摘录

python 复制代码
class TransformerBlock(nn.Module):
    def __init__(self, layer_id: int, args: ModelArgs):
        super().__init__()
        self.n_heads = args.n_heads
        self.dim = args.dim
        self.head_dim = args.dim // args.n_heads
        self.attention = Attention(args)
        self.feed_forward = FeedForward(
            dim=args.dim,
            hidden_dim=4 * args.dim,
            multiple_of=args.multiple_of,
            ffn_dim_multiplier=args.ffn_dim_multiplier,
        )
        self.layer_id = layer_id
        self.attention_norm = RMSNorm(args.dim, eps=args.norm_eps)
        self.ffn_norm = RMSNorm(args.dim, eps=args.norm_eps)
 
    def forward(
        self,
        x: torch.Tensor,
        start_pos: int,
        freqs_cis: torch.Tensor,
        mask: Optional[torch.Tensor],
    ):
        h = x + self.attention(self.attention_norm(x), start_pos, freqs_cis, mask)
        out = h + self.feed_forward(self.ffn_norm(h))
        return out

二、Transformer Block作用

这个TransformerBlock类的设计允许多个这样的块可以堆叠在一起,形成一个深度的Transformer网络。每个块的输出会被用作下一个块的输入,这样的设计使得Transformer能够处理非常复杂的序列建模任务。

三、代码注释

python 复制代码
class TransformerBlock(nn.Module):
    def __init__(self, layer_id: int, args: ModelArgs):
        super().__init__()
        self.n_heads = args.n_heads
        self.dim = args.dim
        self.head_dim = args.dim // self.n_heads
python 复制代码
        # 一个多头注意力模块,用于对输入执行自注意力操作。
        # 这个模块会计算输入的每个元素与其他元素之间的相互关系,并将这些关系用于更新输入。
        self.attention = Attention(args)
python 复制代码
        # 一个前馈神经网络模块,包含一个 SwiGLU 激活函数 和一个线性层。
        self.feed_forward = FeedForward(
            dim = args.dim,
            hidden_dim = 4 * args.dim,
            multiple_of=args.multiple_of,
            ffn_dim_multiplier=args.ffn_dim_multiplier,
        )
        self.layer_id = layer_id
python 复制代码
        # RMS归一化层,对注意力的输出进行归一化。
        self.attention_norm = RMSNorm(args.dim, eps=args.norm_eps)
python 复制代码
        # RMS归一化层,对前馈神经网络的输出进行归一化。
        self.ffn_norm = RMSNorm(args.dim, eps=args.norm_eps)
python 复制代码
    def forward(
        self,
        x: torch.Tensor,
        start_pos: int,
        freqs_cis: torch.Tensor, # 旋转矩阵
        mask: Optional[torch.Tensor],
    ):
python 复制代码
        # 残差连接
        # 将注意力模块的输出与原始输入x相加,形成一个残差连接。这是一种常见的深度学习技术,
        # 可以帮助减少训练深层网络时的梯度消失问题。
        h = x + self.attention.forward(
            self.attention_norm(x), # 对输入x进行归一化,然后将归一化的x传递给注意力模块。
            start_pos, # 开始的位置
            freqs_cis, # 频率
            mask,
        )
python 复制代码
        # 对结果h进行归一化,然后传递给前馈神经网络模块。
        # 前馈神经网络模块对其输入进行进一步的转换,并将输出与h相加,形成另一个残差连接。
        out = h + self.feed_forward.forward(self.ffn_norm(h))
python 复制代码
        # 这个out将被用作下一个Transformer块的输入
        return out
相关推荐
red_redemption3 小时前
自由学习记录(88)
学习
百分百题库APP7 小时前
江苏安全员 A 证 “安全生产管理” 核心考点
学习·考试·题库·考证
go&Python8 小时前
检索模型与RAG
开发语言·python·llama
霜绛12 小时前
Unity笔记(六)——Mathf、三角函数、坐标系、向量
笔记·学习·unity·游戏引擎
long31613 小时前
代理设计模式
java·学习·程序人生·设计模式·代理模式
MThinker13 小时前
14.examples\01-Micropython-Basics\demo_yield.py 加强版
python·学习·智能硬件·micropython·canmv·k230
月盈缺13 小时前
学习嵌入式的第二十五天——哈希表和内核链表
学习·链表·散列表
好奇龙猫14 小时前
日语学习-日语知识点小记-构建基础-JLPT-N3阶段(19):文法复习+单词第7回1
学习
ts码农15 小时前
blazor 学习笔记--vscode debug
笔记·vscode·学习
牛奶yu茶15 小时前
Python学习笔记之(二)变量和简单的数据类型
笔记·python·学习