在 `main()` 函数的 stream 循环中,我们可以计算每秒钟生成的 token 数量,然后输出生成速度(it/s,即代码中打印的 tokens/s)。在流式生成过程中,我们可以使用 Python 的 `time` 模块来计算速度。在测试时,生成速度会受到多个因素的影响,包括设备性能、模型大小、输入文本长度等。
python
import os
import torch
import platform
from colorama import Fore, Style
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.generation.utils import GenerationConfig
import time
def init_model():
    """Load the Baichuan-13B-Chat model and its tokenizer.

    The model is loaded in float16 with ``device_map="cuda"``, and its
    generation config is loaded from the same checkpoint name.

    Returns:
        A ``(model, tokenizer)`` tuple.
    """
    print("init model ...")
    checkpoint = "baichuan-inc/Baichuan-13B-Chat"

    model = AutoModelForCausalLM.from_pretrained(
        checkpoint,
        torch_dtype=torch.float16,
        device_map="cuda",
        trust_remote_code=True,
    )
    model.generation_config = GenerationConfig.from_pretrained(checkpoint)

    tokenizer = AutoTokenizer.from_pretrained(
        checkpoint,
        use_fast=False,
        trust_remote_code=True,
    )
    return model, tokenizer
def clear_screen():
    """Clear the terminal, print the welcome banner, and start a new history.

    Returns:
        A new empty list, which the caller uses as the message history.
    """
    # Windows shells use "cls"; every other platform uses "clear".
    clear_command = "cls" if platform.system() == "Windows" else "clear"
    os.system(clear_command)

    banner = "欢迎使用百川大模型,输入进行对话,clear 清空历史,CTRL+C 中断生成,stream 开关流式生成,exit 结束。"
    print(Fore.YELLOW + Style.BRIGHT + banner)
    return []
def main(stream=True):
    """Run the interactive Baichuan chat loop in the terminal.

    Commands typed at the prompt (compared after stripping whitespace):
    ``exit`` leaves the loop, ``clear`` resets the screen and history, and
    ``stream`` toggles streaming output. Any other input is appended to the
    history as a user message and sent to the model.

    In streaming mode, the generation speed in tokens/s is printed after
    each reply that finishes without being interrupted.

    Args:
        stream: Whether replies are streamed when the loop starts.
    """
    model, tokenizer = init_model()
    messages = clear_screen()
    while True:
        prompt = input(Fore.GREEN + Style.BRIGHT + "\n用户:" + Style.NORMAL)
        command = prompt.strip()
        if command == "exit":
            break
        if command == "clear":
            messages = clear_screen()
            continue
        print(Fore.CYAN + Style.BRIGHT + "\nBaichuan:" + Style.NORMAL, end='')
        if command == "stream":
            stream = not stream
            print(Fore.YELLOW + "({}流式生成)\n".format("开启" if stream else "关闭"), end='')
            continue
        messages.append({"role": "user", "content": prompt})
        if stream:
            # Pre-set the reply so that a CTRL+C before the first chunk does
            # not leave `response` undefined (or stale from the previous
            # turn) when it is appended to the history below.
            response = ""
            position = 0
            # perf_counter is monotonic, which suits interval measurement.
            start_time = time.perf_counter()
            try:
                for response in model.chat(tokenizer, messages, stream=True):
                    # Each chunk is the full text so far; print only the
                    # part that has not been shown yet.
                    print(response[position:], end='', flush=True)
                    position = len(response)
                    if torch.backends.mps.is_available():
                        torch.mps.empty_cache()
            except KeyboardInterrupt:
                pass
            else:
                elapsed_time = time.perf_counter() - start_time
                # `response` is cumulative, so it is tokenized once after
                # the loop. Summing per-chunk counts would count earlier
                # tokens repeatedly, and tokenizing inside the loop would
                # add to the measured time.
                total_tokens = len(
                    tokenizer(response, add_special_tokens=False)["input_ids"]
                )
                if elapsed_time > 0:
                    tokens_per_second = total_tokens / elapsed_time
                    print(f"\n\n生成速度:{tokens_per_second:.2f} tokens/s")
            print()
        else:
            response = model.chat(tokenizer, messages)
            print(response)
            if torch.backends.mps.is_available():
                torch.mps.empty_cache()
        messages.append({"role": "assistant", "content": response})
    print(Style.RESET_ALL)
# Start the interactive chat only when this file is executed as a script.
if __name__ == "__main__":
    main()