llama源码学习·model.py[6]TransformerBlock类

一、源码摘录

python 复制代码
class TransformerBlock(nn.Module):
    def __init__(self, layer_id: int, args: ModelArgs):
        super().__init__()
        self.n_heads = args.n_heads
        self.dim = args.dim
        self.head_dim = args.dim // args.n_heads
        self.attention = Attention(args)
        self.feed_forward = FeedForward(
            dim=args.dim,
            hidden_dim=4 * args.dim,
            multiple_of=args.multiple_of,
            ffn_dim_multiplier=args.ffn_dim_multiplier,
        )
        self.layer_id = layer_id
        self.attention_norm = RMSNorm(args.dim, eps=args.norm_eps)
        self.ffn_norm = RMSNorm(args.dim, eps=args.norm_eps)
 
    def forward(
        self,
        x: torch.Tensor,
        start_pos: int,
        freqs_cis: torch.Tensor,
        mask: Optional[torch.Tensor],
    ):
        h = x + self.attention(self.attention_norm(x), start_pos, freqs_cis, mask)
        out = h + self.feed_forward(self.ffn_norm(h))
        return out

二、Transformer Block作用

这个TransformerBlock类的设计允许多个这样的块可以堆叠在一起,形成一个深度的Transformer网络。每个块的输出会被用作下一个块的输入,这样的设计使得Transformer能够处理非常复杂的序列建模任务。

三、代码注释

python 复制代码
class TransformerBlock(nn.Module):
    def __init__(self, layer_id: int, args: ModelArgs):
        super().__init__()
        self.n_heads = args.n_heads
        self.dim = args.dim
        self.head_dim = args.dim // self.n_heads
python 复制代码
        # 一个多头注意力模块,用于对输入执行自注意力操作。
        # 这个模块会计算输入的每个元素与其他元素之间的相互关系,并将这些关系用于更新输入。
        self.attention = Attention(args)
python 复制代码
        # 一个前馈神经网络模块,包含一个 SwiGLU 激活函数 和一个线性层。
        self.feed_forward = FeedForward(
            dim = args.dim,
            hidden_dim = 4 * args.dim,
            multiple_of=args.multiple_of,
            ffn_dim_multiplier=args.ffn_dim_multiplier,
        )
        self.layer_id = layer_id
python 复制代码
        # RMS归一化层,对注意力的输出进行归一化。
        self.attention_norm = RMSNorm(args.dim, eps=args.norm_eps)
python 复制代码
        # RMS归一化层,对前馈神经网络的输出进行归一化。
        self.ffn_norm = RMSNorm(args.dim, eps=args.norm_eps)
python 复制代码
    def forward(
        self,
        x: torch.Tensor,
        start_pos: int,
        freqs_cis: torch.Tensor, # 旋转矩阵
        mask: Optional[torch.Tensor],
    ):
python 复制代码
        # 残差连接
        # 将注意力模块的输出与原始输入x相加,形成一个残差连接。这是一种常见的深度学习技术,
        # 可以帮助减少训练深层网络时的梯度消失问题。
        h = x + self.attention.forward(
            self.attention_norm(x), # 对输入x进行归一化,然后将归一化的x传递给注意力模块。
            start_pos, # 开始的位置
            freqs_cis, # 频率
            mask,
        )
python 复制代码
        # 对结果h进行归一化,然后传递给前馈神经网络模块。
        # 前馈神经网络模块对其输入进行进一步的转换,并将输出与h相加,形成另一个残差连接。
        out = h + self.feed_forward.forward(self.ffn_norm(h))
python 复制代码
        # 这个out将被用作下一个Transformer块的输入
        return out
相关推荐
阿阳微客4 小时前
Steam 搬砖项目深度拆解:从抵触到真香的转型之路
前端·笔记·学习·游戏
Chef_Chen9 小时前
从0开始学习R语言--Day18--分类变量关联性检验
学习
键盘敲没电9 小时前
【IOS】GCD学习
学习·ios·objective-c·xcode
海的诗篇_10 小时前
前端开发面试题总结-JavaScript篇(一)
开发语言·前端·javascript·学习·面试
AgilityBaby10 小时前
UE5 2D角色PaperZD插件动画状态机学习笔记
笔记·学习·ue5
AgilityBaby10 小时前
UE5 创建2D角色帧动画学习笔记
笔记·学习·ue5
武昌库里写JAVA11 小时前
iview Switch Tabs TabPane 使用提示Maximum call stack size exceeded堆栈溢出
java·开发语言·spring boot·学习·课程设计
一弓虽12 小时前
git 学习
git·学习
Moonnnn.14 小时前
【单片机期末】串行口循环缓冲区发送
笔记·单片机·嵌入式硬件·学习
viperrrrrrrrrr715 小时前
大数据学习(131)-Hive数据分析函数总结
大数据·hive·学习