一、源码摘录
python
class TransformerBlock(nn.Module):
def __init__(self, layer_id: int, args: ModelArgs):
super().__init__()
self.n_heads = args.n_heads
self.dim = args.dim
self.head_dim = args.dim // args.n_heads
self.attention = Attention(args)
self.feed_forward = FeedForward(
dim=args.dim,
hidden_dim=4 * args.dim,
multiple_of=args.multiple_of,
ffn_dim_multiplier=args.ffn_dim_multiplier,
)
self.layer_id = layer_id
self.attention_norm = RMSNorm(args.dim, eps=args.norm_eps)
self.ffn_norm = RMSNorm(args.dim, eps=args.norm_eps)
def forward(
self,
x: torch.Tensor,
start_pos: int,
freqs_cis: torch.Tensor,
mask: Optional[torch.Tensor],
):
h = x + self.attention(self.attention_norm(x), start_pos, freqs_cis, mask)
out = h + self.feed_forward(self.ffn_norm(h))
return out
二、Transformer Block作用
这个TransformerBlock类的设计允许多个这样的块可以堆叠在一起,形成一个深度的Transformer网络。每个块的输出会被用作下一个块的输入,这样的设计使得Transformer能够处理非常复杂的序列建模任务。
三、代码注释

python
class TransformerBlock(nn.Module):
def __init__(self, layer_id: int, args: ModelArgs):
super().__init__()
self.n_heads = args.n_heads
self.dim = args.dim
self.head_dim = args.dim // self.n_heads
python
# 一个多头注意力模块,用于对输入执行自注意力操作。
# 这个模块会计算输入的每个元素与其他元素之间的相互关系,并将这些关系用于更新输入。
self.attention = Attention(args)

python
# 一个前馈神经网络模块,包含一个 SwiGLU 激活函数 和一个线性层。
self.feed_forward = FeedForward(
dim = args.dim,
hidden_dim = 4 * args.dim,
multiple_of=args.multiple_of,
ffn_dim_multiplier=args.ffn_dim_multiplier,
)
self.layer_id = layer_id

python
# RMS归一化层,对注意力的输出进行归一化。
self.attention_norm = RMSNorm(args.dim, eps=args.norm_eps)

python
# RMS归一化层,对前馈神经网络的输出进行归一化。
self.ffn_norm = RMSNorm(args.dim, eps=args.norm_eps)

python
def forward(
self,
x: torch.Tensor,
start_pos: int,
freqs_cis: torch.Tensor, # 旋转矩阵
mask: Optional[torch.Tensor],
):
python
# 残差连接
# 将注意力模块的输出与原始输入x相加,形成一个残差连接。这是一种常见的深度学习技术,
# 可以帮助减少训练深层网络时的梯度消失问题。
h = x + self.attention.forward(
self.attention_norm(x), # 对输入x进行归一化,然后将归一化的x传递给注意力模块。
start_pos, # 开始的位置
freqs_cis, # 频率
mask,
)

python
# 对结果h进行归一化,然后传递给前馈神经网络模块。
# 前馈神经网络模块对其输入进行进一步的转换,并将输出与h相加,形成另一个残差连接。
out = h + self.feed_forward.forward(self.ffn_norm(h))

python
# 这个out将被用作下一个Transformer块的输入
return out