llama源码学习·model.py[6]TransformerBlock类

一、源码摘录

python 复制代码
class TransformerBlock(nn.Module):
    def __init__(self, layer_id: int, args: ModelArgs):
        super().__init__()
        self.n_heads = args.n_heads
        self.dim = args.dim
        self.head_dim = args.dim // args.n_heads
        self.attention = Attention(args)
        self.feed_forward = FeedForward(
            dim=args.dim,
            hidden_dim=4 * args.dim,
            multiple_of=args.multiple_of,
            ffn_dim_multiplier=args.ffn_dim_multiplier,
        )
        self.layer_id = layer_id
        self.attention_norm = RMSNorm(args.dim, eps=args.norm_eps)
        self.ffn_norm = RMSNorm(args.dim, eps=args.norm_eps)
 
    def forward(
        self,
        x: torch.Tensor,
        start_pos: int,
        freqs_cis: torch.Tensor,
        mask: Optional[torch.Tensor],
    ):
        h = x + self.attention(self.attention_norm(x), start_pos, freqs_cis, mask)
        out = h + self.feed_forward(self.ffn_norm(h))
        return out

二、Transformer Block作用

这个TransformerBlock类的设计允许多个这样的块可以堆叠在一起,形成一个深度的Transformer网络。每个块的输出会被用作下一个块的输入,这样的设计使得Transformer能够处理非常复杂的序列建模任务。

三、代码注释

python 复制代码
class TransformerBlock(nn.Module):
    def __init__(self, layer_id: int, args: ModelArgs):
        super().__init__()
        self.n_heads = args.n_heads
        self.dim = args.dim
        self.head_dim = args.dim // self.n_heads
python 复制代码
        # 一个多头注意力模块,用于对输入执行自注意力操作。
        # 这个模块会计算输入的每个元素与其他元素之间的相互关系,并将这些关系用于更新输入。
        self.attention = Attention(args)
python 复制代码
        # 一个前馈神经网络模块,包含一个 SwiGLU 激活函数 和一个线性层。
        self.feed_forward = FeedForward(
            dim = args.dim,
            hidden_dim = 4 * args.dim,
            multiple_of=args.multiple_of,
            ffn_dim_multiplier=args.ffn_dim_multiplier,
        )
        self.layer_id = layer_id
python 复制代码
        # RMS归一化层,对注意力的输出进行归一化。
        self.attention_norm = RMSNorm(args.dim, eps=args.norm_eps)
python 复制代码
        # RMS归一化层,对前馈神经网络的输出进行归一化。
        self.ffn_norm = RMSNorm(args.dim, eps=args.norm_eps)
python 复制代码
    def forward(
        self,
        x: torch.Tensor,
        start_pos: int,
        freqs_cis: torch.Tensor, # 旋转矩阵
        mask: Optional[torch.Tensor],
    ):
python 复制代码
        # 残差连接
        # 将注意力模块的输出与原始输入x相加,形成一个残差连接。这是一种常见的深度学习技术,
        # 可以帮助减少训练深层网络时的梯度消失问题。
        h = x + self.attention.forward(
            self.attention_norm(x), # 对输入x进行归一化,然后将归一化的x传递给注意力模块。
            start_pos, # 开始的位置
            freqs_cis, # 频率
            mask,
        )
python 复制代码
        # 对结果h进行归一化,然后传递给前馈神经网络模块。
        # 前馈神经网络模块对其输入进行进一步的转换,并将输出与h相加,形成另一个残差连接。
        out = h + self.feed_forward.forward(self.ffn_norm(h))
python 复制代码
        # 这个out将被用作下一个Transformer块的输入
        return out
相关推荐
song_ly00144 分钟前
深入理解软件测试覆盖率:从概念到实践
笔记·学习·测试
DIY机器人工房1 小时前
[6-2] 定时器定时中断&定时器外部时钟 江协科技学习笔记(41个知识点)
笔记·stm32·单片机·学习·江协科技
海尔辛2 小时前
学习黑客5 分钟小白弄懂Windows Desktop GUI
windows·学习
烟雨迷3 小时前
Linux环境基础开发工具的使用(yum、vim、gcc、g++、gdb、make/Makefile)
linux·服务器·学习·编辑器·vim
@十八子德月生4 小时前
8天Python从入门到精通【itheima】-1~5
大数据·开发语言·python·学习
Clockwiseee5 小时前
文件上传总结
运维·服务器·学习·文件上传
苜柠6 小时前
Wpf学习片段
学习
欢乐熊嵌入式编程6 小时前
智能手表固件升级 OTA 策略文档初稿
嵌入式硬件·学习·智能手表
起床学FPGA6 小时前
异步FIFO的学习
学习·fpga开发
依年南台7 小时前
搭建大数据学习的平台
大数据·学习