Transformers CLI and Tooling System
CLI and Tooling Architecture Overview
[Figure: layered architecture diagram. Entry layer: cli/transformers.py (typer_factory). Command layer: Chat (cli/chat.py), Serve (cli/serve.py), Download (cli/download.py), Env and Version (cli/system.py), AddNewModelLike (cli/add_new_model_like.py). Rendering layer: RichInterface, rich terminal text (Console / Live / Markdown). Service layer: OpenAI-compatible API (transformers serve), AsyncInferenceClient.]
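1. Module Overview
The Transformers CLI is the command-line toolset shipped with the Hugging Face Transformers library. It is built on the typer framework and follows the CLI conventions shared across the Hugging Face ecosystem. It provides the following core commands: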
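| Command | Source file | Purpose |
|---|---|---|
| `transformers chat` | `cli/chat.py` | Interactive chat interface |
| `transformers serve` | `cli/serve.py` | Start an OpenAI-compatible API server |
| `transformers download` | `cli/download.py` | Download a model and tokenizer |
| `transformers env` | `cli/system.py` | Print environment information |
| `transformers version` | `cli/system.py` | Print the version number |
| `transformers add-new-model-like` | `cli/add_new_model_like.py` | Create a new model based on an existing one |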
2. CLI Entry Point: cli/transformers.py
2.1 Architecture
The CLI entry point creates its Typer application with typer_factory from huggingface_hub, which keeps its style consistent with huggingface-cli:
python
from huggingface_hub import check_cli_update, typer_factory

app = typer_factory(help="Transformers CLI")

# Register the subcommands
app.command()(add_new_model_like)  # function-style command
app.command(name="chat")(Chat)  # class-style command, name given explicitly
app.command()(download)
app.command()(env)
app.command(name="serve")(Serve)
app.command()(version)


def main():
    check_cli_update("transformers")  # check whether a newer CLI version is available
    app()
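Design notes:
- `typer_factory` is the shared Hugging Face CLI factory; it keeps the CLI style consistent across HF libraries.
- `check_cli_update` checks for updates on every run and prompts the user to upgrade.
- Commands can be registered in two ways: as plain functions (such as `download` and `env`) or as classes (such as `Chat` and `Serve`).
- When a class is used as a command, Typer turns the parameters of `__init__` into CLI parameters.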
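To make the class-as-command idea concrete, here is a minimal sketch of how a CLI layer can derive command-line parameters from a class's `__init__` signature. It uses only `inspect` and `argparse` from the standard library and is not how Typer is implemented; `Greeter`, `build_parser`, and `run_command` are hypothetical names chosen for illustration.

```python
import argparse
import inspect


class Greeter:
    """Hypothetical class-style command: __init__ parameters become CLI parameters."""

    def __init__(self, name: str, times: int = 1, shout: bool = False):
        text = f"hello {name}"
        self.lines = [text.upper() if shout else text] * times


def build_parser(command_cls) -> argparse.ArgumentParser:
    """Derive an argparse parser from the signature of command_cls.__init__."""
    parser = argparse.ArgumentParser(prog=command_cls.__name__.lower())
    signature = inspect.signature(command_cls.__init__)
    for name, param in signature.parameters.items():
        if name == "self":
            continue
        if param.default is inspect.Parameter.empty:
            # No default: treat it as a required positional argument.
            parser.add_argument(name, type=param.annotation)
        elif param.annotation is bool:
            # Boolean with a default: expose it as an on/off flag.
            parser.add_argument(f"--{name}", action="store_true", default=param.default)
        else:
            # Anything else with a default: an optional --name VALUE option.
            parser.add_argument(f"--{name}", type=param.annotation, default=param.default)
    return parser


def run_command(command_cls, argv: list[str]):
    """Parse argv and instantiate the class; instantiation is what runs the command."""
    args = build_parser(command_cls).parse_args(argv)
    return command_cls(**vars(args))


greeter = run_command(Greeter, ["world", "--times", "2", "--shout"])
print(greeter.lines)
```

The point of the pattern is that the constructor signature is the single source of truth: adding a parameter to `__init__` adds a CLI parameter without any separate registration step.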
3. Chat Command: cli/chat.py
3.1 Core Class Chat
The Chat class implements a terminal-based interactive chat interface that connects to a server started by transformers serve:
python
from typing import Annotated

import typer


class Chat:
    def __init__(
        self,
        model_id: Annotated[str, typer.Argument(help="ID of the model to use")],
        base_url: Annotated[str | None, typer.Argument(help="Base url to connect to")] = "http://localhost:8000",
        generate_flags: Annotated[list[str] | None, typer.Argument(help="Flags to pass to generate")] = None,
        user: Annotated[str | None, typer.Option(help="Username")] = None,
        system_prompt: Annotated[str | None, typer.Option(help="System prompt")] = None,
        save_folder: Annotated[str, typer.Option(help="Folder to save chat history")] = "./chat_history/",
        examples_path: Annotated[str | None, typer.Option(help="Path to yaml file with examples")] = None,
        generation_config: Annotated[str | None, typer.Option(help="Path to generation config")] = None,
    ):
        ...  # body elided
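Architecture notes:
- Chat is itself a client: it does not load a model but connects over HTTP to a `transformers serve` server.
- It uses `AsyncInferenceClient` (from `huggingface_hub`) for asynchronous streaming.
- It connects to `localhost:8000` by default and runs a health check at startup.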
3.2 RichInterface: Rich Terminal Rendering
python
class RichInterface:
    def __init__(self, model_id: str, user_id: str, base_url: str):
        self._console = Console()
        # ...

    async def stream_output(self, stream):
        """Stream the model reply, rendering Markdown live."""
        with Live(console=self._console, refresh_per_second=4) as live:
            text = ""
            async for token in await stream:
                outputs = token.choices[0].delta.content
                # Escape tags such as <think/> so that they render literally
                outputs = re.sub(r"<(/*)(\w*)>", r"\<\1\2\>", outputs)
                text += outputs
                # Markdown rendering: two trailing spaces per line turn "\n" into a line break
                # (`lines` is built from `text` in code elided here)
                markdown = Markdown("".join(lines).strip(), code_theme="github-dark")
                live.update(markdown, refresh=True)
        # Print tokens-per-second statistics
        # (`completion_tokens` and `elapsed` are computed in code elided here)
        tok_per_sec = completion_tokens / elapsed
        self._console.print(f"[dim]{completion_tokens} tokens in {elapsed:.1f}s ({tok_per_sec:.1f} tok/s)[/dim]")

    def print_model_load(self, model: str):
        """Show model loading progress streamed over SSE."""
        response = requests.post(f"{self.base_url}/load_model", json={"model": model}, stream=True)
        # Parse the SSE events and display download/loading progress bars
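Key design points:
- The `rich` library provides live Markdown rendering, progress bars, and colored output.
- Streaming output refreshes 4 times per second, a balance between performance and smoothness.
- Model loading progress comes from the `/load_model` SSE endpoint and is displayed in stages (processor → config → download → weights).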
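The SSE parsing step mentioned above can be sketched with the standard library alone. The sketch below groups `data:` lines into events and decodes each as JSON. The payload fields `stage` and `progress` are assumptions made for illustration; the actual event schema of `/load_model` is not shown in this article.

```python
import json
from collections.abc import Iterable, Iterator


def iter_sse_events(lines: Iterable[str]) -> Iterator[dict]:
    """Yield one decoded JSON payload per SSE event.

    An event is a run of 'data:' lines terminated by a blank line.
    """
    buffer: list[str] = []
    for raw in lines:
        line = raw.rstrip("\r\n")
        if line.startswith("data:"):
            buffer.append(line[len("data:"):].strip())
        elif line == "" and buffer:
            yield json.loads("\n".join(buffer))
            buffer = []
    if buffer:  # stream ended without a trailing blank line
        yield json.loads("\n".join(buffer))


# Hypothetical event stream; field names are illustrative only.
sample_stream = [
    'data: {"stage": "processor", "progress": 1.0}',
    "",
    'data: {"stage": "download", "progress": 0.5}',
    "",
    'data: {"stage": "weights", "progress": 1.0}',
    "",
]
events = list(iter_sse_events(sample_stream))
for event in events:
    print(f"{event['stage']:<10} {event['progress']:.0%}")
```

With a real HTTP response, the same generator could be fed from an iterator over decoded response lines, and each event would drive a progress bar update instead of a `print`.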
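3.3 Built-in Commands
Chat supports the following interactive commands:
| Command | Purpose |
|---|---|
| `!help` | Show help |
| `!clear` | Clear the conversation history |
| `!status` | Show the model and generation configuration |
| `!set arg=val` | Change a generation parameter on the fly |
| `!example NAME` | Load a preset example |
| `!save [NAME]` | Save the conversation history |
| `!exit` | Exit |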
3.4 Helper Functions
python
def parse_generate_flags(generate_flags: list[str] | None) -> dict:
    """Parse CLI flags into a kwargs dict for generate().

    Example: ["max_new_tokens=100", "do_sample=False"] -> {"max_new_tokens": 100, "do_sample": False}
    """
    # 1. Extract the flag=value pairs into generate_flags_as_dict (elided)
    # 2. Normalize the types: lowercase booleans, None -> null, quote strings
    # 3. Join everything into a JSON string and parse it
    generate_flags_string = "{" + ", ".join([f"{k}: {v}" for k, v in generate_flags_as_dict.items()]) + "}"
    return json.loads(generate_flags_string)


def load_generation_config(generation_config: str | None) -> GenerationConfig:
    """Load a generation config from a local JSON file or a Hugging Face repo name."""
    if generation_config is None:
        return GenerationConfig()
    if ".json" in generation_config:
        # Local file: split the path into its directory and file name
        dirname, filename = os.path.split(generation_config)
        return GenerationConfig.from_pretrained(dirname, filename)
    else:
        return GenerationConfig.from_pretrained(generation_config)
4. Serve Command: cli/serve.py and cli/serving/
4.1 The Serve Class: Service Entry Point
Serve is the most complex CLI command. It starts an OpenAI-compatible FastAPI server:
python
from typing import Annotated

import typer


class Serve:
    def __init__(
        self,
        force_model: Annotated[str | None, typer.Argument(help="Model to preload")] = None,
        # Model options
        continuous_batching: Annotated[bool, typer.Option(help="Enable continuous batching")] = False,
        cb_block_size: Annotated[int | None, typer.Option(help="KV cache block size")] = None,
        cb_num_blocks: Annotated[int | None, typer.Option(help="Number of KV cache blocks")] = None,
        attn_implementation: Annotated[str | None, typer.Option(help="Attention implementation")] = None,
        compile: Annotated[bool, typer.Option(help="Enable torch.compile")] = False,
        quantization: Annotated[str | None, typer.Option(help="Quantization: bnb-4bit/bnb-8bit")] = None,
        device: Annotated[str, typer.Option(help="Device")] = "auto",
        dtype: Annotated[str | None, typer.Option(help="Override dtype")] = "auto",
        model_timeout: Annotated[int, typer.Option(help="Idle model unload timeout")] = 300,
        # Server options
        host: Annotated[str, typer.Option(help="Listen address")] = "localhost",
        port: Annotated[int, typer.Option(help="Listen port")] = 8000,
        enable_cors: Annotated[bool, typer.Option(help="Enable CORS")] = False,
        log_level: Annotated[str, typer.Option(help="Logging level")] = "warning",
        default_seed: Annotated[int | None, typer.Option(help="Default torch seed")] = None,
        non_blocking: Annotated[bool, typer.Option(hidden=True)] = False,  # used by tests
    ):
        ...  # body elided
Architecture layers:
Serve (CLI entry point)
├── ModelManager: model loading, caching, and lifecycle management
├── GenerationState: generation state (manages regular and CB modes)
│   ├── GenerateManager: sequential generation (model.generate)
│   └── CBGenerateManager: continuous-batching generation
├── ChatCompletionHandler: /v1/chat/completions endpoint
├── CompletionHandler: /v1/completions endpoint
├── ResponseHandler: /v1/responses endpoint
├── TranscriptionHandler: /v1/audio/transcriptions endpoint
└── build_server(): FastAPI application factory
4.2 ModelManager: Model Lifecycle Management
python
class ModelManager:
    """Manages model loading, caching, and automatic unloading."""

    def __init__(self, device, dtype, trust_remote_code, attn_implementation,
                 quantization, model_timeout, force_model):
        # ... (constructor arguments such as model_timeout are stored on self)
        self.loaded_models: dict[str, TimedModel] = {}  # model cache
        self._model_locks: dict[str, threading.Lock] = {}  # locks for concurrent loads
        self._loading_subscribers: dict[str, list[asyncio.Queue]] = {}  # SSE subscribers

    def load_model_and_processor(self, model_id_and_revision, ...):
        """Load the model and processor (or return them from the cache) and reset the timeout timer."""
        with lock:  # per-model lock: prevents duplicate concurrent loads
            if model_id_and_revision not in self.loaded_models:
                processor = self._load_processor(model_id_and_revision)
                model = self._load_model(model_id_and_revision, ...)
                self.loaded_models[model_id_and_revision] = TimedModel(
                    model, timeout_seconds=self.model_timeout, processor=processor,
                    on_unload=lambda key: self.loaded_models.pop(key, None)
                )
            else:
                self.loaded_models[model_id_and_revision].reset_timer()

    async def load_model_streaming(self, model_id_and_revision):
        """SSE streaming load: several subscribers can wait on the same model load."""
        # Case 1: already cached -> emit a single ready event
        # Case 2: currently loading -> join the existing subscriber queue
        # Case 3: first request -> start loading and broadcast to all subscribers
        ...
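The per-model lock in `load_model_and_processor` guards against loading the same model twice when requests arrive concurrently. The pattern can be sketched in isolation as follows; `LoadOnceCache` and `slow_loader` are hypothetical names, and the extra `_locks_guard` lock is an assumption of this sketch (it protects the lock table itself), not something the article shows.

```python
import threading


class LoadOnceCache:
    """Hypothetical cache: each key is loaded at most once, even under concurrent requests."""

    def __init__(self, loader):
        self._loader = loader
        self._items: dict[str, object] = {}
        self._locks: dict[str, threading.Lock] = {}
        self._locks_guard = threading.Lock()  # protects the lock table itself

    def get(self, key: str):
        with self._locks_guard:
            lock = self._locks.setdefault(key, threading.Lock())
        with lock:  # only one thread per key runs the loader
            if key not in self._items:
                self._items[key] = self._loader(key)
            return self._items[key]


load_calls: list[str] = []


def slow_loader(key: str) -> str:
    """Stands in for an expensive model load; records each call."""
    load_calls.append(key)
    return f"model:{key}"


cache = LoadOnceCache(slow_loader)
threads = [threading.Thread(target=cache.get, args=("gpt2",)) for _ in range(8)]
for t in threads:
    t.start()
for t in threads:
    t.join()

print(load_calls)
```

Using one lock per key rather than one global lock means a slow load of one model does not block requests for a different model.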
TimedModel, the automatic unloading mechanism:
python
class TimedModel:
    """Wraps a model and its processor and unloads them after a timeout."""

    def __init__(self, model, timeout_seconds, processor, on_unload):
        self.timeout_seconds = timeout_seconds
        self._timer = threading.Timer(self.timeout_seconds, self._timeout_reached)
        self._timer.start()

    def reset_timer(self):
        """Restart the timer on every request."""
        self._timer.cancel()
        self._timer = threading.Timer(self.timeout_seconds, self._timeout_reached)
        self._timer.start()

    def _timeout_reached(self):
        """After the timeout, delete the model to free GPU memory."""
        self.delete_model()
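The idle-timeout idea behind `TimedModel` can be exercised with a small runnable sketch. `TimedResource` below is a hypothetical stand-in that wraps an arbitrary payload instead of a model; the `unloaded` event and the daemon flag are additions for the demo and are not taken from the article.

```python
import threading


class TimedResource:
    """Hypothetical stand-in for TimedModel: drops its payload after an idle timeout."""

    def __init__(self, payload, timeout_seconds: float, on_unload=None):
        self.payload = payload
        self.timeout_seconds = timeout_seconds
        self._on_unload = on_unload
        self.unloaded = threading.Event()
        self._timer = self._start_timer()

    def _start_timer(self) -> threading.Timer:
        timer = threading.Timer(self.timeout_seconds, self._timeout_reached)
        timer.daemon = True  # do not keep the process alive just for the timer
        timer.start()
        return timer

    def reset_timer(self) -> None:
        """Call on every request so that only idle time counts toward the timeout."""
        self._timer.cancel()
        self._timer = self._start_timer()

    def _timeout_reached(self) -> None:
        self.payload = None  # drop the reference so memory can be reclaimed
        if self._on_unload is not None:
            self._on_unload()
        self.unloaded.set()


registry = {}
resource = TimedResource("weights", timeout_seconds=0.2, on_unload=lambda: registry.pop("demo", None))
registry["demo"] = resource

resource.reset_timer()  # a request arrives: the countdown restarts
expired = resource.unloaded.wait(timeout=2)  # block until the idle timeout fires
print(expired, resource.payload, registry)
```

Because `threading.Timer` cannot be restarted, resetting means cancelling the old timer and creating a new one, which is the same shape as `reset_timer` in the excerpt above.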
4.3 GenerationState: Generation State Management
python
class GenerationState:
    """Manages the generation manager instances for each model."""

    def __init__(self, continuous_batching, compile, cb_config):
        self._generate_managers: dict[str, GenerateManager] = {}
        self._cb_manager: CBGenerateManager | None = None

    def use_continuous_batching(self, model, modality) -> bool:
        """Decide whether to use continuous batching (LLMs only, and the model must implement init_continuous_batching)."""

    def get_manager(self, model_id, use_cb) -> BaseGenerateManager:
        """Get or create the generation manager (one per model, lazily initialized)."""
4.4 Generation Manager Hierarchy
BaseGenerateManager (ABC)
├── GenerateManager: sequential generation
│   └── InferenceThread: persistent thread (required by torch.compile)
│       └── DirectStreamer: streaming output for model.generate()
└── CBGenerateManager: continuous batching
    └── CBStreamer: streaming decoder for CB output
DirectStreamer, the core of streaming output:
python
class DirectStreamer:
    """Implements the put/end protocol expected by model.generate()."""

    def __init__(self, tokenizer, loop, queue, skip_special_tokens=True, tool_config=None):
        self._decode_stream = DecodeStream([], skip_special_tokens)  # Rust decoder, O(1) per step
        self._stc_id = tool_config["stc_id"] if tool_config else None  # tool-call start token
        self._etc_id = tool_config["etc_id"] if tool_config else None  # tool-call end token
        # ... (other attributes used below, such as self._tokenizer, self._loop, self._queue,
        # self._inside_tool_call, and self._cancelled, are set in code elided here)

    def put(self, value):
        """Called by generate() at each step: decode tokens, filter tool-call markers, push to the queue."""
        for token_id in value.tolist():
            if token_id == self._stc_id:
                self._inside_tool_call = True
            elif token_id == self._etc_id:
                self._inside_tool_call = False
            text = self._decode_stream.step(self._tokenizer, token_id)
            if text and not self._inside_tool_call:
                self._loop.call_soon_threadsafe(self._queue.put_nowait, text)

    def cancel(self):
        """Cancel generation when the client disconnects, freeing the GPU."""
        self._cancelled.set()
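The key mechanism in `put` is handing text from the generation thread to an asyncio consumer through `loop.call_soon_threadsafe`. A minimal sketch of that bridge, with decoding and tool-call filtering left out, might look like this; `QueueStreamer`, `fake_generate`, and the `_DONE` sentinel are hypothetical.

```python
import asyncio
import threading


class QueueStreamer:
    """Hypothetical streamer: a worker thread calls put()/end(), an asyncio task consumes."""

    _DONE = object()  # sentinel marking the end of the stream

    def __init__(self, loop: asyncio.AbstractEventLoop, queue: asyncio.Queue):
        self._loop = loop
        self._queue = queue

    def put(self, text: str) -> None:
        # asyncio.Queue is not thread-safe, so hand the call over to the event loop.
        self._loop.call_soon_threadsafe(self._queue.put_nowait, text)

    def end(self) -> None:
        self._loop.call_soon_threadsafe(self._queue.put_nowait, self._DONE)

    async def collect(self) -> list[str]:
        chunks = []
        while (item := await self._queue.get()) is not self._DONE:
            chunks.append(item)
        return chunks


def fake_generate(streamer: QueueStreamer) -> None:
    """Stands in for model.generate(): emits pieces of text from a worker thread."""
    for piece in ["Hello", ", ", "world"]:
        streamer.put(piece)
    streamer.end()


async def main() -> list[str]:
    streamer = QueueStreamer(asyncio.get_running_loop(), asyncio.Queue())
    worker = threading.Thread(target=fake_generate, args=(streamer,))
    worker.start()
    chunks = await streamer.collect()
    worker.join()
    return chunks


chunks = asyncio.run(main())
print("".join(chunks))
```

Callbacks scheduled with `call_soon_threadsafe` from a single thread run in the order they were scheduled, so the consumer sees the pieces in generation order.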
4.5 Request Handlers
BaseHandler, the handler base class:
python
class BaseHandler:
    """Shared logic: model resolution, generation config construction, SSE formatting."""

    def _validate_request(self, body): ...  # validate the request fields
    def _resolve_model(self, body): ...  # resolve and load the model
    def _build_generation_config(self, body, model_generation_config): ...
    def get_processor_inputs_from_messages(self, messages, modality): ...

    @staticmethod
    def chunk_to_sse(chunk): ...  # format as an SSE "data:" line
ChatCompletionHandler, the Chat Completions endpoint:
python
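class ChatCompletionHandler(BaseHandler):
    async def handle_request(self, body, request_id):
        # 1. Validate the request
        # 2. Resolve the model and detect its modality (LLM/VLM/MULTIMODAL)
        # 3. Decide whether to use continuous batching
        # 4. Apply the chat template to the messages
        # 5. Build the generation config
        # 6. Dispatch to the streaming or non-streaming path
        ...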
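CompletionHandler, the legacy Completions endpoint:
- Accepts a plain-text `prompt` (no chat template).
- Supports the `suffix` parameter for fill-in-the-middle.
ResponseHandler, the Responses API endpoint:
- Supports more flexible input formats (string, list, or multi-turn conversation).
- Converts tool definitions from the flat format to the nested format.
- Streams output using the OpenAI Responses API event sequence.
TranscriptionHandler, the speech transcription endpoint:
- Implemented independently (it does not inherit from BaseHandler) because it takes a multipart form rather than JSON.
- Loads audio with librosa and runs inference through the model's generate().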
4.6 FastAPI Routes: serving/server.py
python
def build_server(model_manager, chat_handler, completion_handler,
                 response_handler, transcription_handler, generation_state, enable_cors) -> FastAPI:
    app = FastAPI(lifespan=lifespan)  # the lifespan handler calls model_manager.shutdown()
    # Route registration (the endpoint functions are defined in code elided here)
    app.post("/v1/chat/completions")(chat_completions)
    app.post("/v1/completions")(completions)
    app.post("/v1/responses")(responses)
    app.post("/v1/audio/transcriptions")(audio_transcriptions)
    app.post("/load_model")(load_model)  # SSE streaming load
    app.post("/reset")(reset)  # reset all models
    app.get("/v1/models")(list_models)  # list the available models
    app.get("/health")(health)  # health check
    # Middleware
    app.middleware("http")(request_id_middleware)  # request ID tracking
    # Exception handling
    app.exception_handler(CBWorkerDeadError)(_cb_dead_handler)  # dead CB worker -> 503
    return app
4.7 Modality Detection and Tool Calling
python
class Modality(enum.Enum):
    LLM = "LLM"
    VLM = "VLM"
    MULTIMODAL = "MULTIMODAL"  # text + image + video + audio
    STT = "STT"
    TTS = "TTS"


def get_tool_call_config(processor, model) -> dict | None:
    """Get the tool-call configuration.

    Read it from the tokenizer's stc_token/etc_token/response_schema,
    or fall back to a lookup in _TOOL_CALL_FALLBACKS.
    """
5. Download Command: cli/download.py
This is the simplest CLI command. It delegates directly to AutoModel and AutoTokenizer:
python
def download(model_id, cache_dir=None, force_download=False, trust_remote_code=False):
    """Download a model and its tokenizer into the local cache."""
    from ..models.auto import AutoModel, AutoTokenizer

    AutoModel.from_pretrained(model_id, cache_dir=cache_dir,
                              force_download=force_download, trust_remote_code=trust_remote_code)
    AutoTokenizer.from_pretrained(model_id, cache_dir=cache_dir,
                                  force_download=force_download, trust_remote_code=trust_remote_code)
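Design notes:
- The import is deferred (`from ..models.auto import ...`) so that starting the CLI does not load the whole library.
- It relies on the caching built into `from_pretrained`, so no separate download logic is needed.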
6. System Commands: cli/system.py
6.1 The env Command
Collects and prints information about the runtime environment, for use in GitHub issue reports:
python
def env(accelerate_config_file=None):
    info = {
        "`transformers` version": __version__,
        "Platform": platform.platform(),
        "Python version": platform.python_version(),
        "Huggingface_hub version": huggingface_hub.__version__,
        "Safetensors version": safetensors_version,
        "Accelerate version": accelerate_version,
        "Accelerate config": accelerate_config_str,
        "DeepSpeed version": deepspeed_version,
        "PyTorch version (accelerator?)": f"{pt_version} ({pt_accelerator})",
        "Using distributed or parallel set-up in script?": "<fill in>",
    }
    # Automatically detect the GPU/XPU/NPU/HPU type and version
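A reduced sketch of the same idea, using only the standard library: collect a few environment facts, look up package versions with a fallback for packages that are not installed, and render the result as a list that pastes cleanly into an issue. The helper names and the "not installed" fallback string are assumptions of this sketch, not the behavior of `transformers env`.

```python
import platform
from importlib import metadata


def package_version(name: str) -> str:
    """Return the installed version of a distribution, or 'not installed'."""
    try:
        return metadata.version(name)
    except metadata.PackageNotFoundError:
        return "not installed"


def collect_env_info(packages: list[str]) -> dict[str, str]:
    """Gather platform facts plus one version entry per requested package."""
    info = {
        "Platform": platform.platform(),
        "Python version": platform.python_version(),
    }
    for name in packages:
        info[f"{name} version"] = package_version(name)
    return info


def format_env_info(info: dict[str, str]) -> str:
    """Render the report as a bullet list."""
    return "\n".join(f"- {key}: {value}" for key, value in info.items())


env_info = collect_env_info(["transformers", "torch", "safetensors"])
print(format_env_info(env_info))
```

Reading versions through `importlib.metadata` avoids importing heavy packages just to report their version numbers.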
6.2 The version Command
python
def version():
    print(__version__)
7. Add-New-Model Command: cli/add_new_model_like.py
7.1 Overview
This is a developer tool that scaffolds a complete new model from an existing one, including:
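- Modular file: `modular_xxx.py` (reuses the old model's classes through inheritance)
- Init file: `__init__.py` (lazy-loading pattern)
- Auto mappings: registers the new model in `auto_mappings.py`
- Test files: copies the old model's tests and substitutes the names
- Documentation: creates a template doc file
- Code formatting: runs ruff check/format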
7.2 Core Class ModelInfos
python
class ModelInfos:
    """Retrieves information about the classes of an existing model."""

    def __init__(self, lowercase_name: str):
        self.lowercase_name = lowercase_name.lower().replace(" ", "_").replace("-", "_")
        self.config_class = CONFIG_MAPPING_NAMES[self.lowercase_name]
        self.camelcase_name = self.config_class.replace("Config", "")
        # Look up the tokenizer/image_processor/processor class names in the Auto mappings
        self.tokenizer_class = TOKENIZER_MAPPING_NAMES.get(self.lowercase_name, None)
        self.image_processor_classes = IMAGE_PROCESSOR_MAPPING_NAMES.get(self.lowercase_name, None)
        # ...
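The two name derivations in `ModelInfos.__init__` are easy to try in isolation. The sketch below uses a small hypothetical mapping, `CONFIG_NAMES`, in place of the real `CONFIG_MAPPING_NAMES`, and shows why the cased name is taken from the config class rather than computed from the lowercase name.

```python
# Hypothetical excerpt of a model_type -> config class name mapping.
CONFIG_NAMES = {"bert": "BertConfig", "gpt2": "GPT2Config", "big_bird": "BigBirdConfig"}


def normalize_model_name(name: str) -> str:
    """Lowercase the name and replace spaces and hyphens with underscores."""
    return name.lower().replace(" ", "_").replace("-", "_")


def camelcase_name(model_type: str) -> str:
    """Derive the class-name prefix by stripping 'Config' from the config class name."""
    return CONFIG_NAMES[model_type].replace("Config", "")


for raw in ["BERT", "Big Bird", "gpt2"]:
    key = normalize_model_name(raw)
    print(raw, "->", key, "->", camelcase_name(key))
```

Naive capitalization of `gpt2` would give `Gpt2`, whereas the lookup recovers the prefix `GPT2` that the existing classes actually use.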
7.3 Interactive Input Flow
python
def get_user_input():
    # 1. Choose the old model type to duplicate (with fuzzy-match suggestions)
    old_model_type = input("What model would you like to duplicate?")
    # 2. Enter the new model name
    new_lowercase_name = get_user_field("What is the new model name?")
    # 3. Enter the name as it appears in the paper
    new_model_paper_name = get_user_field("What is the fully cased name?")
    # 4. Confirm one by one whether to create a new tokenizer/image_processor/processor
    # 5. Return (old_model_infos, new_lowercase_name, new_model_paper_name, filenames_to_add)
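7.4 Modular File Generation
The tool uses libcst (a concrete syntax tree parser for Python) to analyze the class definitions in the old model's files and generate the modular file: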
python
def create_modular_file(repo_path, old_model_infos, new_lowercase_name, filenames_to_add):
    """Create the modular file, subclassing every class of the old model."""
    for filename, to_add in filenames_to_add:
        if to_add:
            imports, body, public_classes = find_modular_structure(
                old_folder_root / filename, old_model_infos, new_cased_name
            )
    # Structure of the generated file:
    #   from ..old_model import OldModelClass
    #   class NewModelClass(OldModelClass):
    #       pass
    #   __all__ = ["NewModelClass", ...]
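The generated file structure described above can be reproduced with a short sketch. Note the substitution: this sketch uses the standard-library `ast` module rather than libcst, and it only handles top-level classes. `find_public_classes` and `render_modular_file` are hypothetical helpers, not functions from the tool.

```python
import ast


def find_public_classes(source: str) -> list[str]:
    """Return the names of top-level classes that do not start with an underscore."""
    tree = ast.parse(source)
    return [
        node.name
        for node in tree.body
        if isinstance(node, ast.ClassDef) and not node.name.startswith("_")
    ]


def render_modular_file(old_module: str, old_prefix: str, new_prefix: str, class_names: list[str]) -> str:
    """Emit one empty subclass per old class, renamed with the new prefix."""
    new_names = [name.replace(old_prefix, new_prefix, 1) for name in class_names]
    lines = [f"from ..{old_module} import {', '.join(class_names)}", ""]
    for old, new in zip(class_names, new_names):
        lines += ["", f"class {new}({old}):", "    pass", ""]
    all_items = ", ".join(f'"{name}"' for name in new_names)
    lines += ["", f"__all__ = [{all_items}]", ""]
    return "\n".join(lines)


old_source = '''
class _Helper:
    pass

class BertModel:
    pass

class BertForMaskedLM(BertModel):
    pass
'''

classes = find_public_classes(old_source)
modular_source = render_modular_file("bert.modeling_bert", "Bert", "MyBert", classes)
print(modular_source)
```

Each generated class is an empty subclass, so the new model starts out behaving exactly like the old one and only the parts that differ need to be overridden later.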
7.5 Full Execution Flow
_add_new_model_like_internal():
1. Create the new model folder src/transformers/models/new_model/
2. Create modular_new_model.py
3. Create __init__.py
4. Register the model in models/__init__.py
5. Add it to auto_mappings.py
6. Create the test files under tests/models/new_model/
7. Create the doc file docs/source/en/model_doc/new_model.md
8. Run ruff check/format
9. Run modular_model_converter.py
8. Relationships Between Modules
┌──────────────────────────────────────────────────────────────┐
│                       CLI entry layer                        │
│  transformers.py (typer app)                                 │
│  ├── chat.py: Chat + RichInterface                           │
│  ├── serve.py: Serve                                         │
│  ├── download.py: download()                                 │
│  ├── system.py: env() + version()                            │
│  └── add_new_model_like.py: add_new_model_like()             │
└──────────────────────────────┬───────────────────────────────┘
                               │
┌──────────────────────────────▼───────────────────────────────┐
│                      Serving subsystem                       │
│  serving/                                                    │
│  ├── server.py: FastAPI application factory                  │
│  ├── model_manager.py: model lifecycle management            │
│  ├── chat_completion.py: Chat Completions handling           │
│  ├── completion.py: legacy Completions handling              │
│  ├── response.py: Responses API handling                     │
│  ├── transcription.py: speech transcription handling         │
│  └── utils.py: generation managers, streaming, tool calls    │
└──────────────────────────────┬───────────────────────────────┘
                               │
┌──────────────────────────────▼───────────────────────────────┐
│                  Transformers core library                   │
│  AutoModel / AutoTokenizer / GenerationConfig                │
│  model.generate() / ContinuousBatchingManager                │
│  PreTrainedModel / ProcessorMixin                            │
└──────────────────────────────────────────────────────────────┘
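Key dependencies:
- `chat.py` depends on `huggingface_hub.AsyncInferenceClient` and `rich`.
- `serve.py` depends on `fastapi`, `uvicorn`, and `openai` (type definitions).
- `download.py` depends on `AutoModel` and `AutoTokenizer`.
- `add_new_model_like.py` depends on `libcst` (syntax tree parsing) and the Auto mapping registries.
- `DirectStreamer` in `serving/utils.py` depends on `tokenizers.DecodeStream` (the Rust decoder).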