Llama开源代码详细解读(3)

expand_mask模块

python 复制代码
def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
    """
    Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
    """
    bsz, src_len = mask.size()
    tgt_len = tgt_len if tgt_len is not None else src_len

    expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)

    inverted_mask = 1.0 - expanded_mask

    return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)

这个函数接收一个张量mask,并将其变换为特定的形状。输入三个参数分别为:mask:大小为[bsz, seq_len]。dtype:数据类型。tgt_len:目标序列长度。以下是函数的运行方式。

获取mask参数

python 复制代码
  bsz, src_len = mask.size()
  • .size()函数获取了mask张量的行数、列数,即bsz,src_len。

确定目标序列长度

python 复制代码
tgt_len = tgt_len if tgt_len is not None else src_len
  • 如果tgt_len没有被指定,则赋值为src_len。

扩展掩码

python 复制代码
expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)
  • mask[:, None, None, :]将mask的维度从[bsz, seq_len]拓展为[bsz,1,1,seq_len],expand(bsz, 1, tgt_len, src_len)将拓展后的矩阵继续拓展为[bsz,1,tgt_len,src_len),to(dtype)转换为指定的数据类型。

生成反转掩码

python 复制代码
inverted_mask = 1.0 - expanded_mask

将掩码中0和1的位置互换。

填充反转掩码

python 复制代码
return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)

inverted_mask.to(torch.bool)将反转掩码转换为布尔类型

masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)用dtype这个类型的最小值填充其中为true的位置。

返回一个经过填充处理的反转掩码张量,形状为 [bsz, 1, tgt_len, src_len],数据类型为 dtype。

RMSNorm归一化模块

python 复制代码
class LlamaRMSNorm(nn.Module):
    def __init__(self, hidden_size, eps=1e-6):
        # eps一个很小的数,用于避免除零错误
        """
        LlamaRMSNorm is equivalent to T5LayerNorm
        """
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        # weight是一个可训练的参数,初始化为一个大小为 hidden_size 的全 1 张量
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        input_dtype = hidden_states.dtype
        # 存储hidden_states的数据类型
        hidden_states = hidden_states.to(torch.float32)
        variance = hidden_states.pow(2).mean(-1, keepdim=True)
        # hidden_states一般是batch size * sequence length * hidden size,这里的mean是按最后一维取平均
        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
        return (self.weight * hidden_states).to(input_dtype)

这段代码定义了一个RMSNorm类,用于实现归一化。

构造函数__init__()

python 复制代码
def __init__(self, hidden_size, eps=1e-6):
    super().__init__()
    self.weight = nn.Parameter(torch.ones(hidden_size))
    self.variance_epsilon = eps
  • hidden_size表示隐藏层的维度大小,eps表示一个很小的数,用于防止除零错误,self.weight表示为一个可训练的权重参数,初始化为一个和hidden_size大小一致的全1张量。

前向传播forward()

python 复制代码
def forward(self, hidden_states):
    input_dtype = hidden_states.dtype
    hidden_states = hidden_states.to(torch.float32)
    variance = hidden_states.pow(2).mean(-1, keepdim=True)
    hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
    return (self.weight * hidden_states).to(input_dtype)
  • input_dtype储存隐藏层的数据类型
  • hidden_states = hidden_states.to(torch.float32)将数据类型转换为bf32,确保计算的稳定性
  • variance = hidden_states.pow(2).mean(-1, keepdim=True)计算隐藏层在最后一层的方差,具体步骤是对每个元素平方再取平均。
  • hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)计算反向标准差(1/标准差),并与hidden_state元素逐个相乘,从而实现归一化。
  • return (self.weight * hidden_states).to(input_dtype)将归一化后的隐藏层乘以权重,并转换回输入的数据类型。
相关推荐
南宫理的日知录11 分钟前
99、Python并发编程:多线程的问题、临界资源以及同步机制
开发语言·python·学习·编程学习
coberup21 分钟前
django Forbidden (403)错误解决方法
python·django·403错误
wangyue430 分钟前
c# 深度模型入门
深度学习
川石课堂软件测试43 分钟前
性能测试|docker容器下搭建JMeter+Grafana+Influxdb监控可视化平台
运维·javascript·深度学习·jmeter·docker·容器·grafana
985小水博一枚呀1 小时前
【深度学习滑坡制图|论文解读3】基于融合CNN-Transformer网络和深度迁移学习的遥感影像滑坡制图方法
人工智能·深度学习·神经网络·cnn·transformer
龙哥说跨境1 小时前
如何利用指纹浏览器爬虫绕过Cloudflare的防护?
服务器·网络·python·网络爬虫
985小水博一枚呀1 小时前
【深度学习滑坡制图|论文解读2】基于融合CNN-Transformer网络和深度迁移学习的遥感影像滑坡制图方法
人工智能·深度学习·神经网络·cnn·transformer·迁移学习
小白学大数据1 小时前
正则表达式在Kotlin中的应用:提取图片链接
开发语言·python·selenium·正则表达式·kotlin
flashman9111 小时前
python在word中插入图片
python·microsoft·自动化·word
菜鸟的人工智能之路1 小时前
桑基图在医学数据分析中的更复杂应用示例
python·数据分析·健康医疗