accelerate 分布式技巧实战--部署ChatGLM-6B(三)

accelerate 分布式技巧实战--部署ChatGLM-6B(三)

基础环境

bash 复制代码
torch==2.0.0+cu118
transformers==4.28.1
accelerate==0.18.0
Tesla T4 15.3G
内存:11.8G

下载相关文件:

python 复制代码
git clone https://github.com/THUDM/ChatGLM-6B
cd ChatGLM-6B

git clone --depth=1 https://huggingface.co/THUDM/chatglm-6b THUDM/chatglm-6b
git clone --depth=1 https://huggingface.co/THUDM/chatglm-6b-int4 THUDM/chatglm-6b-int4

pip install -r requirements.txt
pip install gradio
pip install accelerate

正常情况下,我们使用Chat-GLM需要的显存大于13G,内存没有评估过,但上述的肯定是不够的,16G应该可以。

方案一:量化模型

python 复制代码
from accelerate import infer_auto_device_map, init_empty_weights, load_checkpoint_and_dispatch
from transformers import AutoConfig, AutoModel, AutoModelForCausalLM, AutoTokenizer
import gradio as gr
import torch
import time

tokenizer = AutoTokenizer.from_pretrained("./THUDM/chatglm-6b-int4", trust_remote_code=True)
model = AutoModel.from_pretrained("./THUDM/chatglm-6b-int4", trust_remote_code=True).half().cuda()

model = model.eval()

def predict(input, history=None):
    print(f'predict started: {time.time()}');
    if history is None:
        history = []
    response, history = model.chat(tokenizer, input, history)
    return response, history

while True:
  text = input(">>用户:")
  response, history = model.chat(tokenizer, input, history)
  print(">>CHatGLM:", response)

GPU使用4.9G,内存使用5.5G。

方案二:一块GPU

python 复制代码
from accelerate import infer_auto_device_map, init_empty_weights, load_checkpoint_and_dispatch
from transformers import AutoConfig, AutoModel, AutoModelForCausalLM, AutoTokenizer
import gradio as gr
import torch
import time


tokenizer = AutoTokenizer.from_pretrained("./THUDM/chatglm-6b", trust_remote_code=True)
config = AutoConfig.from_pretrained("./THUDM/chatglm-6b", trust_remote_code=True)
with init_empty_weights():
  model = AutoModel.from_config(config, trust_remote_code=True)

for name, _ in model.named_parameters():
  print(name)
# device_map = infer_auto_device_map(model, no_split_module_classes=["GLMBlock"])
# print(device_map)
device_map = {'transformer.word_embeddings': 0, 'transformer.layers.0': 0, 'transformer.layers.1': 0, 'transformer.layers.2': 0, 'transformer.layers.3': 0, 'transformer.layers.4': 0, 'transformer.layers.5': 0, 'transformer.layers.6': 0, 'transformer.layers.7': 0, 'transformer.layers.8': 0, 'transformer.layers.9': 0, 'transformer.layers.10': 0, 'transformer.layers.11': 0, 'transformer.layers.12': 0, 'transformer.layers.13': 0, 'transformer.layers.14': 0, 'transformer.layers.15': 0, 'transformer.layers.16': 0, 'transformer.layers.17': 0, 'transformer.layers.18': 0, 'transformer.layers.19': 0, 'transformer.layers.20': 0, 'transformer.layers.21': 'cpu', 'transformer.layers.22': 'cpu', 'transformer.layers.23': 'cpu', 'transformer.layers.24': 'cpu', 'transformer.layers.25': 'cpu', 'transformer.layers.26': 'cpu', 'transformer.layers.27': 'cpu', 'transformer.final_layernorm': 'cpu', 'lm_head': 'cpu'}
model = load_checkpoint_and_dispatch(model, "./THUDM/chatglm-6b", device_map=device_map, offload_folder="offload", offload_state_dict=True, no_split_module_classes=["GLMBlock"]).half()

def predict(input, history=None):
    print(f'predict started: {time.time()}');
    if history is None:
        history = []
    response, history = model.chat(tokenizer, input, history)
    return response, history

while True:
  history = None
  text = input(">>用户:")
  response, history = model.chat(tokenizer, text, history)
  print(">>CHatGLM:", response)

GPU使用9.7G,内存使用5.9G。第一轮输入你好后GPU使用11.2G。

方案三:accelerate,多块GPU

python 复制代码
import os
os.environ["cuda_visible_devices"] = "0,1"

from accelerate import infer_auto_device_map, init_empty_weights, load_checkpoint_and_dispatch
from transformers import AutoConfig, AutoModel, AutoModelForCausalLM, AutoTokenizer
# import gradio as gr
# import torch
import time


tokenizer = AutoTokenizer.from_pretrained(".\\chatglm-6b\\", trust_remote_code=True)
config = AutoConfig.from_pretrained(".\\chatglm-6b\\", trust_remote_code=True)
with init_empty_weights():
  model = AutoModel.from_config(config, trust_remote_code=True)

for name, _ in model.named_parameters():
  print(name)
# device_map = infer_auto_device_map(model, no_split_module_classes=["GLMBlock"])
# print(device_map)
# device_map = {'transformer.word_embeddings': 0, 'transformer.layers.0': 0, 'transformer.layers.1': 0, 'transformer.layers.2': 0, 'transformer.layers.3': 0, 'transformer.layers.4': 0, 'transformer.layers.5': 0, 'transformer.layers.6': 0, 'transformer.layers.7': 0, 'transformer.layers.8': 0, 'transformer.layers.9': 0, 'transformer.layers.10': 0, 'transformer.layers.11': 0, 'transformer.layers.12': 0, 'transformer.layers.13': 0, 'transformer.layers.14': 0, 'transformer.layers.15': 0, 'transformer.layers.16': 0, 'transformer.layers.17': 0, 'transformer.layers.18': 0, 'transformer.layers.19': 0, 'transformer.layers.20': 0, 'transformer.layers.21': 'cpu', 'transformer.layers.22': 'cpu', 'transformer.layers.23': 'cpu', 'transformer.layers.24': 'cpu', 'transformer.layers.25': 'cpu', 'transformer.layers.26': 'cpu', 'transformer.layers.27': 'cpu', 'transformer.final_layernorm': 'cpu', 'lm_head': 'cpu'}
model = load_checkpoint_and_dispatch(model, ".\\chatglm-6b\\", device_map="balanced", offload_folder="offload", offload_state_dict=True, no_split_module_classes=["GLMBlock"]).half()

def predict(input, history=None):
    print(f'predict started: {time.time()}')
    if history is None:
        history = []
    response, history = model.chat(tokenizer, input, history)
    return response, history

while True:
  history = None
  text = input(">>用户:")
  response, history = model.chat(tokenizer, text, history)
  print(">>CHatGLM:", response)

注意,这里我们设置设备映射为balanced,并只使用前两块GPU。显卡占用情况

参考

https://cloud.tencent.com/developer/article/2274903?areaSource=102001.17\&traceId=dUu9a81soH3zQ5nQGczRV

相关推荐
User_芊芊君子5 小时前
CANN数学计算基石ops-math深度解析:高性能科学计算与AI模型加速的核心引擎
人工智能·深度学习·神经网络·ai
小白|5 小时前
CANN与联邦学习融合:构建隐私安全的分布式AI推理与训练系统
人工智能·机器学习·自动驾驶
艾莉丝努力练剑5 小时前
hixl vs NCCL:昇腾生态通信库的独特优势分析
运维·c++·人工智能·cann
梦帮科技5 小时前
Node.js配置生成器CLI工具开发实战
前端·人工智能·windows·前端框架·node.js·json
程序员泠零澪回家种桔子5 小时前
Spring AI框架全方位详解
java·人工智能·后端·spring·ai·架构
Echo_NGC22375 小时前
【FFmpeg 使用指南】Part 3:码率控制策略与质量评估体系
人工智能·ffmpeg·视频·码率
纤纡.5 小时前
PyTorch 入门精讲:从框架选择到 MNIST 手写数字识别实战
人工智能·pytorch·python
大大大反派5 小时前
CANN 生态中的自动化部署引擎:深入 `mindx-sdk` 项目构建端到端 AI 应用
运维·人工智能·自动化
程序猿追5 小时前
深度解读 AIR (AI Runtime):揭秘 CANN 极致算力编排与调度的核心引擎
人工智能
2601_949593655 小时前
深入解析CANN-acl应用层接口:构建高效的AI应用开发框架
数据库·人工智能