计算语言模型计算每秒钟生成的token数量it/s

main() 函数的stream循环中,我们可以计算每秒钟生成的token数量,然后输出 it/s。在流式生成过程中,我们可以使用Python的time模块来计算速度。在测试时,生成速度会受到多个因素的影响,包括设备性能、模型大小、输入文本长度等。

python 复制代码
import os
import torch
import platform
from colorama import Fore, Style
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.generation.utils import GenerationConfig
import time


def init_model():
    print("init model ...")
    model = AutoModelForCausalLM.from_pretrained(
        "baichuan-inc/Baichuan-13B-Chat",
        torch_dtype=torch.float16,
        device_map="cuda",
        trust_remote_code=True
    )

    model.generation_config = GenerationConfig.from_pretrained(
        "baichuan-inc/Baichuan-13B-Chat"
    )
    tokenizer = AutoTokenizer.from_pretrained(
        "baichuan-inc/Baichuan-13B-Chat",
        use_fast=False,
        trust_remote_code=True
    )
    return model, tokenizer


def clear_screen():
    if platform.system() == "Windows":
        os.system("cls")
    else:
        os.system("clear")
    print(Fore.YELLOW + Style.BRIGHT + "欢迎使用百川大模型,输入进行对话,clear 清空历史,CTRL+C 中断生成,stream 开关流式生成,exit 结束。")
    return []


def main(stream=True):
    model, tokenizer = init_model()

    messages = clear_screen()
    while True:
        prompt = input(Fore.GREEN + Style.BRIGHT + "\n用户:" + Style.NORMAL)
        if prompt.strip() == "exit":
            break
        if prompt.strip() == "clear":
            messages = clear_screen()
            continue
        print(Fore.CYAN + Style.BRIGHT + "\nBaichuan:" + Style.NORMAL, end='')
        if prompt.strip() == "stream":
            stream = not stream
            print(Fore.YELLOW + "({}流式生成)\n".format("开启" if stream else "关闭"), end='')
            continue
        messages.append({"role": "user", "content": prompt})
        if stream:
            position = 0
            try:
                start_time = time.time()
                total_tokens = 0
                for response in model.chat(tokenizer, messages, stream=True):
                    print(response[position:], end='', flush=True)
                    position = len(response)
                    total_tokens += len(tokenizer(response, return_tensors='pt')['input_ids'][0])
                    if torch.backends.mps.is_available():
                        torch.mps.empty_cache()
                end_time = time.time()
                elapsed_time = end_time - start_time
                tokens_per_second = total_tokens / elapsed_time
                print(f"\n\n生成速度:{tokens_per_second:.2f} tokens/s")
            except KeyboardInterrupt:
                pass
            print()

        else:
            response = model.chat(tokenizer, messages)
            print(response)
            if torch.backends.mps.is_available():
                torch.mps.empty_cache()
        messages.append({"role": "assistant", "content": response})

    print(Style.RESET_ALL)


if __name__ == "__main__":
    main()
相关推荐
聆风吟º6 小时前
CANN runtime 全链路拆解:AI 异构计算运行时的任务管理与功能适配技术路径
人工智能·深度学习·神经网络·cann
User_芊芊君子6 小时前
CANN大模型推理加速引擎ascend-transformer-boost深度解析:毫秒级响应的Transformer优化方案
人工智能·深度学习·transformer
智驱力人工智能7 小时前
小区高空抛物AI实时预警方案 筑牢社区头顶安全的实践 高空抛物检测 高空抛物监控安装教程 高空抛物误报率优化方案 高空抛物监控案例分享
人工智能·深度学习·opencv·算法·安全·yolo·边缘计算
人工不智能5777 小时前
拆解 BERT:Output 中的 Hidden States 到底藏了什么秘密?
人工智能·深度学习·bert
h64648564h7 小时前
CANN 性能剖析与调优全指南:从 Profiling 到 Kernel 级优化
人工智能·深度学习
心疼你的一切7 小时前
解密CANN仓库:AIGC的算力底座、关键应用与API实战解析
数据仓库·深度学习·aigc·cann
学电子她就能回来吗9 小时前
深度学习速成:损失函数与反向传播
人工智能·深度学习·学习·计算机视觉·github
Coder_Boy_10 小时前
TensorFlow小白科普
人工智能·深度学习·tensorflow·neo4j
大模型玩家七七10 小时前
梯度累积真的省显存吗?它换走的是什么成本
java·javascript·数据库·人工智能·深度学习
kkzhang10 小时前
Concept Bottleneck Models-概念瓶颈模型用于可解释决策:进展、分类体系 与未来方向综述
深度学习