聊聊 从源码来看ChatGLM-6B的模型结构

基于ChatGLM-6B第一版,要注意还有ChatGLM2-6B以及ChatGLM3-6B

概述

ChatGLM是transformer架构的神经网络模型,因此从transformer结构入手,分析其源码结构。

transformer结构:

转载请备注出处:https://www.cnblogs.com/zhiyong-ITNote/

位置编码

ChatGLM-6B的位置编码采用的旋转位置编码(RoPB)实现。其源码:

python 复制代码
class RotaryEmbedding(torch.nn.Module):
    def __init__(self, dim, base=10000, precision=torch.half, learnable=False):
        super().__init__()
        inv_freq = 1. / (base ** (torch.arange(0, dim, 2).float() / dim))
        inv_freq = inv_freq.half()
        self.learnable = learnable
        if learnable:
            self.inv_freq = torch.nn.Parameter(inv_freq)
            self.max_seq_len_cached = None
        else:
            self.register_buffer('inv_freq', inv_freq)
            self.max_seq_len_cached = None
            self.cos_cached = None
            self.sin_cached = None
        self.precision = precision

    def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys,
                              error_msgs):
        pass

    def forward(self, x, seq_dim=1, seq_len=None):
        if seq_len is None:
            seq_len = x.shape[seq_dim]
        if self.max_seq_len_cached is None or (seq_len > self.max_seq_len_cached):
            self.max_seq_len_cached = None if self.learnable else seq_len
            t = torch.arange(seq_len, device=x.device, dtype=self.inv_freq.dtype)
            freqs = torch.einsum('i,j->ij', t, self.inv_freq)
            # Different from paper, but it uses a different permutation in order to obtain the same calculation
            emb = torch.cat((freqs, freqs), dim=-1).to(x.device)
            if self.precision == torch.bfloat16:
                emb = emb.float()

            # [sx, 1 (b * np), hn]
            cos_cached = emb.cos()[:, None, :]
            sin_cached = emb.sin()[:, None, :]
            if self.precision == torch.bfloat16:
                cos_cached = cos_cached.bfloat16()
                sin_cached = sin_cached.bfloat16()
            if self.learnable:
                return cos_cached, sin_cached
            self.cos_cached, self.sin_cached = cos_cached, sin_cached
        return self.cos_cached[:seq_len, ...], self.sin_cached[:seq_len, ...]

    def _apply(self, fn):
        if self.cos_cached is not None:
            self.cos_cached = fn(self.cos_cached)
        if self.sin_cached is not None:
            self.sin_cached = fn(self.sin_cached)
        return super()._apply(fn)

## 转载请备注出处:https://www.cnblogs.com/zhiyong-ITNote/

激活函数

ChatGLM-6B采用的激活函数是GeLU(高斯误差线性单元),其源码:

python 复制代码
@torch.jit.script
def gelu_impl(x):
    """OpenAI's gelu implementation."""
    return 0.5 * x * (1.0 + torch.tanh(0.7978845608028654 * x *
                                       (1.0 + 0.044715 * x * x)))


def gelu(x):
    return gelu_impl(x)

编码器-解码器(encoder-decoder)

接下来就是编码器解码器结构,如何抓住模型源头来分析?可以从transformers的API入手:

python 复制代码
from transformers import AutoTokenizer, AutoModel
tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True)
model = AutoModel.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True).half().to("cuda:1").eval()

print(mode)

## 转载请备注出处:https://www.cnblogs.com/zhiyong-ITNote/

输出:

复制代码
ChatGLMForConditionalGeneration(
  (transformer): ChatGLMModel(
    (word_embeddings): Embedding(130528, 4096)
    (layers): ModuleList(
      (0-27): 28 x GLMBlock(
        (input_layernorm): LayerNorm((4096,), eps=1e-05, elementwise_affine=True)
        (attention): SelfAttention(
          (rotary_emb): RotaryEmbedding()
          (query_key_value): Linear(in_features=4096, out_features=12288, bias=True)
          (dense): Linear(in_features=4096, out_features=4096, bias=True)
        )
        (post_attention_layernorm): LayerNorm((4096,), eps=1e-05, elementwise_affine=True)
        (mlp): GLU(
          (dense_h_to_4h): Linear(in_features=4096, out_features=16384, bias=True)
          (dense_4h_to_h): Linear(in_features=16384, out_features=4096, bias=True)
        )
      )
    )
    (final_layernorm): LayerNorm((4096,), eps=1e-05, elementwise_affine=True)
  )
  (lm_head): Linear(in_features=4096, out_features=130528, bias=False)
)

从脑图的角度来梳理下其结构

其结构图表示如下:

将结构图与最开始的transformer结构图对比来看,两者还是比较符合的。

官方源码中标注了编码器与解码器是一体的,只需要配置参数即可切换为解码器。如下:

转载请备注出处:https://www.cnblogs.com/zhiyong-ITNote/

相关推荐
小超同学你好16 小时前
Transformer 27. Vision Transformer(ViT):把图像当作「词序列」的编码器
人工智能·深度学习·transformer
高洁0121 小时前
计算机视觉实战:图像去噪模型训练与应用
人工智能·python·深度学习·机器学习·transformer
j_xxx404_1 天前
大语言模型 (LLM) 零基础入门:核心原理、训练机制与能力全解
人工智能·ai·transformer
<-->1 天前
Megatron(全称 Megatron-LM,由 NVIDIA 开发)和 DeepSpeed(由 Microsoft 开发)
人工智能·pytorch·python·深度学习·transformer
melonbo2 天前
RNN LSTM seq2seq 注意力机制 Transformer ,演化路径
rnn·lstm·transformer
爱编程的小吴2 天前
PyTorch+Transformer大模型入门到精通:LLM训练、推理、量化、部署全攻略
人工智能·pytorch·transformer
AI医影跨模态组学2 天前
Eur Radiol(IF=4.7)山西医科大学第一医院核磁影像科王效春等团队:基于Transformer增强型卷积神经网络的多中心MRI评估膀胱癌肌层浸润
人工智能·深度学习·论文·transformer·医学·医学影像
YuanDaima20482 天前
大语言模型生命周期全链路解析:从架构基石到高效推理
开发语言·人工智能·python·语言模型·架构·transformer
code_pgf2 天前
HLE测评LLM
transformer
code_pgf2 天前
LLM高难度测评体系-Humanity’s Last Exam(HLE)及与其它测评对比
transformer