Let's talk about the ChatGLM-6B model structure, seen from the source code

This post is based on the first-generation ChatGLM-6B; note that ChatGLM2-6B and ChatGLM3-6B also exist.

Overview

ChatGLM is a neural network model built on the transformer architecture, so we start from the transformer structure and work through its source code.

The transformer structure:

Positional Encoding

ChatGLM-6B implements its positional encoding with Rotary Position Embedding (RoPE). The source code:

```python
import torch

class RotaryEmbedding(torch.nn.Module):
    def __init__(self, dim, base=10000, precision=torch.half, learnable=False):
        super().__init__()
        inv_freq = 1. / (base ** (torch.arange(0, dim, 2).float() / dim))
        inv_freq = inv_freq.half()
        self.learnable = learnable
        if learnable:
            self.inv_freq = torch.nn.Parameter(inv_freq)
            self.max_seq_len_cached = None
        else:
            self.register_buffer('inv_freq', inv_freq)
            self.max_seq_len_cached = None
            self.cos_cached = None
            self.sin_cached = None
        self.precision = precision

    def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys,
                              error_msgs):
        pass

    def forward(self, x, seq_dim=1, seq_len=None):
        if seq_len is None:
            seq_len = x.shape[seq_dim]
        if self.max_seq_len_cached is None or (seq_len > self.max_seq_len_cached):
            self.max_seq_len_cached = None if self.learnable else seq_len
            t = torch.arange(seq_len, device=x.device, dtype=self.inv_freq.dtype)
            freqs = torch.einsum('i,j->ij', t, self.inv_freq)
            # Different from paper, but it uses a different permutation in order to obtain the same calculation
            emb = torch.cat((freqs, freqs), dim=-1).to(x.device)
            if self.precision == torch.bfloat16:
                emb = emb.float()

            # [sx, 1 (b * np), hn]
            cos_cached = emb.cos()[:, None, :]
            sin_cached = emb.sin()[:, None, :]
            if self.precision == torch.bfloat16:
                cos_cached = cos_cached.bfloat16()
                sin_cached = sin_cached.bfloat16()
            if self.learnable:
                return cos_cached, sin_cached
            self.cos_cached, self.sin_cached = cos_cached, sin_cached
        return self.cos_cached[:seq_len, ...], self.sin_cached[:seq_len, ...]

    def _apply(self, fn):
        if self.cos_cached is not None:
            self.cos_cached = fn(self.cos_cached)
        if self.sin_cached is not None:
            self.sin_cached = fn(self.sin_cached)
        return super()._apply(fn)
```
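To see how these cached cos/sin tables get used, here is a minimal application sketch. It is not code from the ChatGLM-6B repo: the helper names, the float32 precision and the [seq_len, batch, head_dim] layout are assumptions for illustration. Each query/key vector is rotated pairwise by the cached angles, which is how relative position ends up encoded in the attention scores.

```python
import torch

def rotate_half(x):
    # Swap the two halves of the last dimension, negating the second half.
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)

def apply_rotary(q, k, cos, sin):
    # q, k: [seq_len, batch, head_dim]; cos, sin: [seq_len, 1, head_dim], broadcast over batch.
    return q * cos + rotate_half(q) * sin, k * cos + rotate_half(k) * sin

seq_len, batch, head_dim = 16, 2, 64
# Same table construction as RotaryEmbedding.forward above, kept in float32 for a CPU demo.
inv_freq = 1. / (10000 ** (torch.arange(0, head_dim, 2).float() / head_dim))
t = torch.arange(seq_len).float()
freqs = torch.einsum('i,j->ij', t, inv_freq)      # [seq_len, head_dim // 2]
emb = torch.cat((freqs, freqs), dim=-1)           # [seq_len, head_dim]
cos, sin = emb.cos()[:, None, :], emb.sin()[:, None, :]

q = torch.randn(seq_len, batch, head_dim)
k = torch.randn(seq_len, batch, head_dim)
q_rot, k_rot = apply_rotary(q, k, cos, sin)
print(q_rot.shape, k_rot.shape)                   # torch.Size([16, 2, 64]) twice
```

The actual model applies this kind of rotation inside SelfAttention, and on top of it uses a 2D position-id scheme, but the pairwise-rotation idea is the same.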

Activation Function

The activation function ChatGLM-6B uses is GELU (the Gaussian Error Linear Unit). The source code:

```python
import torch

@torch.jit.script
def gelu_impl(x):
    """OpenAI's gelu implementation."""
    return 0.5 * x * (1.0 + torch.tanh(0.7978845608028654 * x *
                                       (1.0 + 0.044715 * x * x)))


def gelu(x):
    return gelu_impl(x)
```
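As a quick sanity check (illustrative only, not from the repo): the constant 0.7978845608... is √(2/π), so this is the standard tanh approximation of GELU and should agree with PyTorch's built-in version to within floating-point error. The `approximate='tanh'` argument assumes a reasonably recent PyTorch (1.12 or later).

```python
import torch

x = torch.linspace(-3.0, 3.0, steps=7)
print(gelu(x))                                           # the implementation quoted above
print(torch.nn.functional.gelu(x, approximate='tanh'))   # PyTorch's tanh-approximated GELU
```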

Encoder-Decoder

Next up is the encoder-decoder structure. How do we get hold of the model itself to analyze it? We can start from the transformers API:

```python
from transformers import AutoTokenizer, AutoModel
tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True)
model = AutoModel.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True).half().to("cuda:1").eval()

print(model)
```

Output:

```
ChatGLMForConditionalGeneration(
  (transformer): ChatGLMModel(
    (word_embeddings): Embedding(130528, 4096)
    (layers): ModuleList(
      (0-27): 28 x GLMBlock(
        (input_layernorm): LayerNorm((4096,), eps=1e-05, elementwise_affine=True)
        (attention): SelfAttention(
          (rotary_emb): RotaryEmbedding()
          (query_key_value): Linear(in_features=4096, out_features=12288, bias=True)
          (dense): Linear(in_features=4096, out_features=4096, bias=True)
        )
        (post_attention_layernorm): LayerNorm((4096,), eps=1e-05, elementwise_affine=True)
        (mlp): GLU(
          (dense_h_to_4h): Linear(in_features=4096, out_features=16384, bias=True)
          (dense_4h_to_h): Linear(in_features=16384, out_features=4096, bias=True)
        )
      )
    )
    (final_layernorm): LayerNorm((4096,), eps=1e-05, elementwise_affine=True)
  )
  (lm_head): Linear(in_features=4096, out_features=130528, bias=False)
)
```
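A few numbers in this printout are worth decoding (quick arithmetic, not code from the repo; the head count is an assumption taken from the published model configuration): query_key_value packs Q, K and V into a single linear layer, and the GLU block expands the hidden size by a factor of four.

```python
# Reading the module tree above (illustrative arithmetic only):
hidden_size = 4096
assert 3 * hidden_size == 12288     # query_key_value out_features = Q, K, V stacked
assert 4 * hidden_size == 16384     # dense_h_to_4h: the GLU inner dimension
num_attention_heads = 32            # assumption: value from the ChatGLM-6B config
print(hidden_size // num_attention_heads)   # 128 dimensions per attention head
```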

Let's lay this structure out as a mind map.

The resulting structure diagram:

Comparing this structure diagram with the transformer diagram at the beginning, the two line up fairly well.

The official source code notes that the encoder and decoder are one and the same module; switching to decoder behavior is just a matter of configuration.
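One way to picture a single stack playing both roles is the GLM-style prefix mask. The sketch below is not the repo's mask-building code (the function name and the boolean convention are assumptions); it just shows the idea: prompt tokens attend to each other bidirectionally, encoder-style, while later tokens attend causally, decoder-style.

```python
import torch

def prefix_lm_mask(seq_len, prefix_len):
    # True = attention allowed. Start from a causal (lower-triangular) mask ...
    mask = torch.tril(torch.ones(seq_len, seq_len)).bool()
    # ... then let every position see the whole prefix, making that part bidirectional.
    mask[:, :prefix_len] = True
    return mask

print(prefix_lm_mask(6, 3).int())
```

With prefix_len equal to the full sequence the block behaves like an encoder; with prefix_len of zero it is a plain causal decoder, which is the sense in which one module covers both.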

Please credit the source when reposting: https://www.cnblogs.com/zhiyong-ITNote/
