llama源码学习·model.py[6]TransformerBlock类

一、源码摘录

python 复制代码
class TransformerBlock(nn.Module):
    def __init__(self, layer_id: int, args: ModelArgs):
        super().__init__()
        self.n_heads = args.n_heads
        self.dim = args.dim
        self.head_dim = args.dim // args.n_heads
        self.attention = Attention(args)
        self.feed_forward = FeedForward(
            dim=args.dim,
            hidden_dim=4 * args.dim,
            multiple_of=args.multiple_of,
            ffn_dim_multiplier=args.ffn_dim_multiplier,
        )
        self.layer_id = layer_id
        self.attention_norm = RMSNorm(args.dim, eps=args.norm_eps)
        self.ffn_norm = RMSNorm(args.dim, eps=args.norm_eps)
 
    def forward(
        self,
        x: torch.Tensor,
        start_pos: int,
        freqs_cis: torch.Tensor,
        mask: Optional[torch.Tensor],
    ):
        h = x + self.attention(self.attention_norm(x), start_pos, freqs_cis, mask)
        out = h + self.feed_forward(self.ffn_norm(h))
        return out

二、Transformer Block作用

这个TransformerBlock类的设计允许多个这样的块可以堆叠在一起,形成一个深度的Transformer网络。每个块的输出会被用作下一个块的输入,这样的设计使得Transformer能够处理非常复杂的序列建模任务。

三、代码注释

python 复制代码
class TransformerBlock(nn.Module):
    def __init__(self, layer_id: int, args: ModelArgs):
        super().__init__()
        self.n_heads = args.n_heads
        self.dim = args.dim
        self.head_dim = args.dim // self.n_heads
python 复制代码
        # 一个多头注意力模块,用于对输入执行自注意力操作。
        # 这个模块会计算输入的每个元素与其他元素之间的相互关系,并将这些关系用于更新输入。
        self.attention = Attention(args)
python 复制代码
        # 一个前馈神经网络模块,包含一个 SwiGLU 激活函数 和一个线性层。
        self.feed_forward = FeedForward(
            dim = args.dim,
            hidden_dim = 4 * args.dim,
            multiple_of=args.multiple_of,
            ffn_dim_multiplier=args.ffn_dim_multiplier,
        )
        self.layer_id = layer_id
python 复制代码
        # RMS归一化层,对注意力的输出进行归一化。
        self.attention_norm = RMSNorm(args.dim, eps=args.norm_eps)
python 复制代码
        # RMS归一化层,对前馈神经网络的输出进行归一化。
        self.ffn_norm = RMSNorm(args.dim, eps=args.norm_eps)
python 复制代码
    def forward(
        self,
        x: torch.Tensor,
        start_pos: int,
        freqs_cis: torch.Tensor, # 旋转矩阵
        mask: Optional[torch.Tensor],
    ):
python 复制代码
        # 残差连接
        # 将注意力模块的输出与原始输入x相加,形成一个残差连接。这是一种常见的深度学习技术,
        # 可以帮助减少训练深层网络时的梯度消失问题。
        h = x + self.attention.forward(
            self.attention_norm(x), # 对输入x进行归一化,然后将归一化的x传递给注意力模块。
            start_pos, # 开始的位置
            freqs_cis, # 频率
            mask,
        )
python 复制代码
        # 对结果h进行归一化,然后传递给前馈神经网络模块。
        # 前馈神经网络模块对其输入进行进一步的转换,并将输出与h相加,形成另一个残差连接。
        out = h + self.feed_forward.forward(self.ffn_norm(h))
python 复制代码
        # 这个out将被用作下一个Transformer块的输入
        return out
相关推荐
charlie1145141911 小时前
使用 Poetry + VS Code 创建你的第一个 Flask 工程
开发语言·笔记·后端·python·学习·flask·教程
菥菥爱嘻嘻1 小时前
langchain学习-RAG+prompt+OutPutParse
学习·langchain·prompt
Jing_jing_X1 小时前
ChatGPT 四种模式:普通对话、推理思考、深度研究、学习模式有什么区别?
人工智能·学习·chatgpt
AA陈超1 小时前
Lyra项目中的输入系统
c++·笔记·学习·游戏·ue5·lyra
AA陈超1 小时前
ASC学习笔记0027:直接设置属性的基础值,而不会影响当前正在生效的任何修饰符(Modifiers)
c++·笔记·学习·ue5·虚幻引擎
doubao362 小时前
如何在海量文献中高效筛选有价值信息
人工智能·学习·自然语言处理·aigc·ai工具·ai检索
烤麻辣烫2 小时前
AI(新手)
人工智能·学习·机器学习·ai编程
青衫码上行2 小时前
【Java Web学习 | 第14篇】JavaScript(8) -正则表达式
java·前端·javascript·学习·正则表达式
哲Zheᗜe༘3 小时前
学习Ansible Playbook 核心语法
网络·学习·ansible
('-')3 小时前
《从根上理解MySQL是怎样运行的》第三章学习笔记
笔记·学习·mysql