以bert为例,了解Lora是如何添加到模型中的

以bert为例,了解Lora是如何添加到模型中的

本文以bert为例,对比了添加Lora模块前后的网络结构图
说明:

  • 1.为了加快速度,将bert修改为一层
  • 2.lora只加到intermediate.dense,方便对比
  • 3.使用了几种不同的可视化方式(onnx可视化,torchviz图,torch.fx可视化,tensorboard可视化)

可参考的点:

  • 1.peft使用
  • 2.几种不同的pytorch模型可视化方法

一.效果图

1.torch.fx可视化

A.添加前

B.添加后

2.onnx可视化

A.添加前

B.添加后

3.tensorboard可视化

A.添加前

B.添加后

二.复现步骤

1.生成配置文件(num_hidden_layers=1)

bash 复制代码
tee ./config.json <<-'EOF'
{
  "architectures": [
    "BertForMaskedLM"
  ],
  "attention_probs_dropout_prob": 0.1,
  "directionality": "bidi",
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.1,
  "hidden_size": 768,
  "initializer_range": 0.02,
  "intermediate_size": 3072,
  "layer_norm_eps": 1e-12,
  "max_position_embeddings": 512,
  "model_type": "bert",
  "num_attention_heads": 12,
  "num_hidden_layers": 1,
  "pad_token_id": 0,
  "pooler_fc_size": 768,
  "pooler_num_attention_heads": 12,
  "pooler_num_fc_layers": 3,
  "pooler_size_per_head": 128,
  "pooler_type": "first_token_transform",
  "type_vocab_size": 2,
  "vocab_size": 21128
}
EOF

2.运行测试脚本

bash 复制代码
tee bert_lora.py <<-'EOF'
import time
import os
import torch
import torchvision.models as models
import torch.nn as nn
import torch.nn.init as init
import time
import numpy as np
from peft import get_peft_config, get_peft_model, get_peft_model_state_dict, LoraConfig, TaskType
from torchviz import make_dot
from torch.utils.tensorboard import SummaryWriter
from torch._functorch.partitioners import draw_graph

def onnx_infer_shape(onnx_path):
    import onnx
    onnx_model  = onnx.load_model(onnx_path)
    new_onnx= onnx.shape_inference.infer_shapes(onnx_model)
    onnx.save_model(new_onnx, onnx_path)

def get_model():
    torch.manual_seed(1)
    from transformers import AutoModelForMaskedLM,BertConfig
    config=BertConfig.from_pretrained("./config.json")
    model = AutoModelForMaskedLM.from_config(config)
    return model,config

def my_compiler(fx_module: torch.fx.GraphModule, _):
    draw_graph(fx_module, f"bert.{time.time()}.svg")
    return fx_module.forward

if __name__ == "__main__":

    model,config=get_model()
    model.eval()
    input_tokens=torch.randint(0,config.vocab_size,(1,128))
    
    # 一.原始模型
    # 1.onnx可视化
    torch.onnx.export(model,input_tokens,
                  "bert_base.onnx",
                  export_params=False,
                  opset_version=11,
                  do_constant_folding=True)
    onnx_infer_shape("bert_base.onnx")
    
    # 2.torchviz图
    output = model(input_tokens)
    logits = output.logits
    viz = make_dot(logits, params=dict(model.named_parameters()))
    viz.render("bert_base", view=False)
    
    # 3.torch.fx可视化
    compiled_model = torch.compile(model, backend=my_compiler)
    output = compiled_model(input_tokens)

    # 4.tensorboard可视化
    writer = SummaryWriter('./runs')
    writer.add_graph(model, input_to_model = input_tokens,use_strict_trace=False)
    writer.close()
    
    # 二.Lora模型
    peft_config = LoraConfig(
        task_type=TaskType.CAUSAL_LM,
        inference_mode=True,
        r=8,
        lora_alpha=32,
        target_modules=['intermediate.dense'],
        lora_dropout=0.1,
    )
    lora_model = get_peft_model(model, peft_config)
    lora_model.eval()
    torch.onnx.export(lora_model,input_tokens,
                      "bert_base_lora_inference_mode.onnx",
                      export_params=False,
                      opset_version=11,
                      do_constant_folding=True)
    onnx_infer_shape("bert_base_lora_inference_mode.onnx")

    compiled_model = torch.compile(lora_model, backend=my_compiler)
    output = compiled_model(input_tokens)

    writer = SummaryWriter('./runs_lora')
    writer.add_graph(lora_model, input_to_model = input_tokens,use_strict_trace=False)
    writer.close()
EOF

# 安装依赖
apt install graphviz -y
pip install torchviz
pip install pydot

# 运行测试程序
python bert_lora.py
相关推荐
云知谷2 小时前
【C++基本功】C++适合做什么,哪些领域适合哪些领域不适合?
c语言·开发语言·c++·人工智能·团队开发
rit84324992 小时前
基于MATLAB实现基于距离的离群点检测算法
人工智能·算法·matlab
初学小刘3 小时前
深度学习:从图片数据到模型训练(十分类)
人工智能·深度学习
递归不收敛4 小时前
大语言模型(LLM)入门笔记:嵌入向量与位置信息
人工智能·笔记·语言模型
之墨_4 小时前
【大语言模型】—— 自注意力机制及其变体(交叉注意力、因果注意力、多头注意力)的代码实现
人工智能·语言模型·自然语言处理
2301_821919925 小时前
深度学习(四)
pytorch·深度学习
从孑开始5 小时前
ManySpeech.MoonshineAsr 使用指南
人工智能·ai·c#·.net·私有化部署·语音识别·onnx·asr·moonshine
涛涛讲AI5 小时前
一段音频多段字幕,让音频能够流畅自然对应字幕 AI生成视频,扣子生成剪映视频草稿
人工智能·音视频·语音识别
可触的未来,发芽的智生5 小时前
新奇特:黑猫警长的纳米世界,忆阻器与神经网络的智慧
javascript·人工智能·python·神经网络·架构
WWZZ20256 小时前
快速上手大模型:机器学习2(一元线性回归、代价函数、梯度下降法)
人工智能·算法·机器学习·计算机视觉·机器人·大模型·slam