An In-Depth Analysis of the Hugging Face Transformers Multimodal Processing System
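Contents
- System architecture overview
- ProcessorMixin: the multimodal processor base class
- TypedDict kwargs system
- Image processing infrastructure
- Image processing backends (PIL/Torchvision)
- Image transforms
- Image utilities
- Feature extraction system
- Audio utilities
- Video processing system
- Vision utilities
- Module relationships and data flow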
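[Figure: architecture overview of the multimodal processing system. Unified entry layer: ProcessorMixin (extends PushToHubMixin; __call__ is the unified multimodal entry point). Sub-processor layer: Tokenizer (text), ImageProcessor (images), VideoProcessor (videos), FeatureExtractor (feature extraction). Backend layer: PilBackend (CPU), TorchvisionBackend (GPU). Transform layer: Resize, Rescale, Normalize, Pad. Utility layer: image_utils, video_utils, audio_utils, vision_utils.]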
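System architecture overview
The multimodal processing system in Transformers uses a layered architecture. ProcessorMixin acts as the central dispatcher and decouples the handling of each modality (text, image, video, audio) into an independent sub-processor. The overall structure is: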
           ┌─────────────────────────────┐
           │       ProcessorMixin        │ ← unified multimodal entry point
           │ (__call__ / _merge_kwargs)  │
           └──────────────┬──────────────┘
       ┌──────────────────┼──────────────────┐
       │                  │                  │
┌──────▼──────┐    ┌──────▼──────┐    ┌──────▼──────┐ ...
│  Tokenizer  │    │  ImageProc  │    │  VideoProc  │
│   (text)    │    │  (images)   │    │  (videos)   │
└─────────────┘    └──────┬──────┘    └──────┬──────┘
                          │                  │
               ┌──────────▼──────────┐       │
               │    Backend layer    │       │
               │ TorchvisionBackend  │       │
               │     PilBackend      │       │
               └──────────┬──────────┘       │
                          │                  │
               ┌──────────▼──────────┐       │
               │   Transform layer   │       │
               │ resize / normalize  │       │
               │    rescale / pad    │       │
               └──────────┬──────────┘       │
                          │                  │
               ┌──────────▼──────────┐       │
               │     Utils layer     │       │
               │     image_utils     │       │
               │     video_utils     │◄──────┘
               │     audio_utils     │
               └─────────────────────┘
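Core design principles:
- Modality decoupling: each modality has its own processor, and ProcessorMixin dispatches to all of them through a single entry point
- Swappable backends: image processing supports two backends, PIL (CPU) and Torchvision (GPU)
- TypedDict-driven parameter system: Python TypedDict is used for type-safe parameter passing
- Priority-based merging: kwargs are merged from highest to lowest priority (call-time flat kwargs > call-time modality-specific kwargs > init-time kwargs > defaults)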
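1. ProcessorMixin: the multimodal processor base class
Source file: [processing_utils.py](file:///workspace/src/transformers/processing_utils.py)
Module responsibilities
ProcessorMixin is the base class for all multimodal processors. It is responsible for:
- Managing the per-modality sub-processors (tokenizer, image_processor, video_processor, feature_extractor) in one place
- Dispatching user input to the matching sub-processor by modality
- Merging the per-modality results into a single BatchFeature output
- Managing chat templates and conversational multimodal input
- Providing save/load support (save_pretrained / from_pretrained)
Core classes and functions
ProcessorMixin class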
python
class ProcessorMixin(PushToHubMixin):
    # Sub-processor attributes, set dynamically by subclasses through the arguments of __init__
    tokenizer: Any
    feature_extractor: Any
    image_processor: Any
    video_processor: Any
    chat_template: str | dict[str, str] | None
    valid_processor_kwargs = ProcessingKwargs  # default kwargs type definition
__call__ method: the unified multimodal entry point
python
def __call__(
    self,
    images: ImageInput | None = None,
    text: TextInput | ... | None = None,
    videos: VideoInput | None = None,
    audio: AudioInput | None = None,
    **kwargs: Unpack[ProcessingKwargs],
):
    # 1. Merge kwargs from every source (call-time > modality-specific > init-time > defaults)
    kwargs = self._merge_kwargs(
        self.valid_processor_kwargs,
        tokenizer_init_kwargs=self.tokenizer.init_kwargs if hasattr(self, "tokenizer") else {},
        **kwargs,
    )
    # 2. Map each sub-processor to its input and its kwargs group
    attribute_to_kwargs = {
        "tokenizer": (text, "text_kwargs"),
        "image_processor": (images, "images_kwargs"),
        "video_processor": (videos, "videos_kwargs"),
        "feature_extractor": (audio, "audio_kwargs"),
    }
    # 3. Walk the sub-processors and dispatch each input with its kwargs
    outputs = {}
    for attribute_name in self.get_attributes():
        attribute = getattr(self, attribute_name, None)
        input_data, input_kwargs = attribute_to_kwargs[attribute_name]
        if input_data is not None and attribute is not None:
            attribute_output = attribute(input_data, **kwargs[input_kwargs])
            outputs.update(attribute_output)
    # 4. Merge all sub-processor outputs into a BatchFeature
    return BatchFeature(outputs)
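Design notes:
- The audio modality is handled by feature_extractor (for historical reasons, feature_extractor also performs audio feature extraction)
- Sub-processor outputs are merged with dict.update(), so their keys must not collide
- Unpack[ProcessingKwargs] provides IDE-friendly type hints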
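_merge_kwargs method: the core of parameter merging
This is the most complex part of the multimodal system. It merges parameters from several sources according to their priority:
python
def _merge_kwargs(
    self,
    ModelProcessorKwargs: ProcessingKwargs,
    tokenizer_init_kwargs: dict | None = None,
    **kwargs,
) -> dict[str, dict]: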
Parameter priority (highest to lowest):
| Priority | Source | Example |
|---|---|---|
| 1 (highest) | Flat kwargs passed directly at call time | processor(..., padding="max_length") |
| 2 | Modality-specific kwargs passed at call time | processor(..., text_kwargs={"padding": "max_length"}) |
| 3 | kwargs from sub-processor initialization | tokenizer = Tokenizer(..., padding="max_length") |
| 4 (lowest) | _defaults in ProcessingKwargs | _defaults = {"text_kwargs": {"padding": "max_length"}} |
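The four priority levels can be illustrated with plain dictionaries. The sketch below is not the library's `_merge_kwargs`; `merge_by_priority` and the sample values are hypothetical and only show how a higher-priority source shadows a lower one for a single modality.

```python
from collections import ChainMap


def merge_by_priority(call_flat, call_modality, init_kwargs, defaults):
    """Hypothetical helper: earlier arguments win over later ones."""
    # ChainMap looks keys up left to right, so the first mapping has the highest priority.
    return dict(ChainMap(call_flat, call_modality, init_kwargs, defaults))


defaults = {"padding": "max_length", "max_length": 64}   # priority 4
init_kwargs = {"padding": "longest"}                      # priority 3
call_modality = {"max_length": 128}                       # priority 2
call_flat = {"padding": False}                            # priority 1

merged = merge_by_priority(call_flat, call_modality, init_kwargs, defaults)
print(merged)
```

Here `padding` comes from the call-time flat kwargs and `max_length` from the call-time modality-specific kwargs, while the init-time and default values are shadowed.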
Merge flow:
python
# Step 1: initialize the output dict and the defaults dict
output_kwargs = {"text_kwargs": {}, "images_kwargs": {}, "audio_kwargs": {}, "videos_kwargs": {}}
default_kwargs = {"text_kwargs": {}, "images_kwargs": {}, "audio_kwargs": {}, "videos_kwargs": {}}
# Step 2: fill the defaults from ModelProcessorKwargs._defaults
for modality in default_kwargs:
    default_kwargs[modality] = ModelProcessorKwargs._defaults.get(modality, {}).copy()
    # Also collect the sub-processor's valid_kwargs so both views stay consistent
    modality_valid_kwargs = set(ModelProcessorKwargs.__annotations__[modality].__annotations__)
    if modality in map_preprocessor_kwargs:
        preprocessor = getattr(self, map_preprocessor_kwargs[modality], None)
        preprocessor_valid_kwargs = getattr(preprocessor, "valid_kwargs", None)
        modality_valid_kwargs.update(set(preprocessor_valid_kwargs.__annotations__))
    # Step 3: override the defaults with the tokenizer's init kwargs
    for modality_key in modality_valid_kwargs:
        if tokenizer_init_kwargs is not None and modality_key in tokenizer_init_kwargs:
            # Simplified here: the value is read straight from tokenizer_init_kwargs
            default_kwargs[modality][modality_key] = tokenizer_init_kwargs[modality_key]
# Step 4: handle common_kwargs (parameters shared across modalities)
common_kwargs = ModelProcessorKwargs._defaults.get("common_kwargs", {})
common_kwargs.update(kwargs.get("common_kwargs", {}))
if common_kwargs:
    for kwarg in output_kwargs.values():
        kwarg.update(common_kwargs)
# Step 5: override with the kwargs passed at call time (highest priority)
# ... dispatch flat kwargs and nested kwargs by modality
# Step 6: type validation
for key, typed_dict_obj in ModelProcessorKwargs.__annotations__.items():
    validate_typed_dict(typed_dict_obj, output_kwargs[key])
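apply_chat_template method: conversational multimodal processing
python
def apply_chat_template(
    self,
    conversation: list[dict[str, str]] | list[list[dict[str, str]]],
    chat_template: str | None = None,
    ...
) -> str:
This method converts an OpenAI-style multimodal conversation into model inputs:
- Template rendering: formats the conversation as text with a Jinja2 template
- Multimodal extraction: extracts image, video, and audio URLs or paths from the conversation content
- Input loading: calls load_audio, fetch_images, and similar helpers to load the actual data
- Format normalization: converts OpenAI's image_url format to Hugging Face's image format
- Unified processing: finally calls self() to process all modalities together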
get_attributes method: dynamic attribute discovery
python
@classmethod
def get_attributes(cls):
    # Extract every modality attribute from the __init__ signature
    args_in_init = inspect.signature(cls.__init__).parameters.keys()
    attributes = []
    for sub_processor_type in args_in_init:
        if any(modality in sub_processor_type for modality in MODALITY_TO_AUTOPROCESSOR_MAPPING.keys()):
            attributes.append(sub_processor_type)
    # Backward compatibility: infer attributes from <attribute>_class class attributes
    if not attributes:
        for attribute_name, value in cls.__dict__.items():
            if value is not None and attribute_name.endswith("_class"):
                inferred_attribute = attribute_name[:-len("_class")]
                attributes.append(inferred_attribute)
    return attributes
_LazyAutoProcessorMapping: a lazily loaded mapping
python
class _LazyAutoProcessorMapping(dict):
    _MAPPING_NAMES = {
        "image_processor": ("transformers.models.auto.image_processing_auto", "AutoImageProcessor"),
        "video_processor": ("transformers.models.auto.video_processing_auto", "AutoVideoProcessor"),
        "feature_extractor": ("transformers.models.auto.feature_extraction_auto", "AutoFeatureExtractor"),
        "audio_processor": ("transformers.models.auto.feature_extraction_auto", "AutoFeatureExtractor"),
        "tokenizer": ("transformers.models.auto.tokenization_auto", "AutoTokenizer"),
    }
Lazy imports avoid circular dependencies: the corresponding Auto class is imported only when it is actually accessed.
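A lazily resolved mapping of this kind can be sketched with `importlib`. The class below is a hypothetical stand-in, not the library's code: `LazyClassMapping` and its two standard-library entries are assumptions chosen so the example runs without Transformers installed.

```python
import importlib


class LazyClassMapping(dict):
    """Hypothetical sketch: resolve (module, attribute) pairs only on first access."""

    _MAPPING_NAMES = {
        "ordered": ("collections", "OrderedDict"),
        "fraction": ("fractions", "Fraction"),
    }

    def __getitem__(self, key):
        if key not in self._MAPPING_NAMES:
            raise KeyError(key)
        if not super().__contains__(key):
            module_name, attr_name = self._MAPPING_NAMES[key]
            # The import happens here, on first lookup, not when the class is defined.
            module = importlib.import_module(module_name)
            super().__setitem__(key, getattr(module, attr_name))
        return super().__getitem__(key)

    def __contains__(self, key):
        # Membership is answered from the name table, without importing anything.
        return key in self._MAPPING_NAMES


mapping = LazyClassMapping()
cls = mapping["fraction"]
print(cls(1, 3))
```

Because nothing is imported until a key is looked up, a module that defines such a mapping can itself be imported by the modules it points to without creating an import cycle.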
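Relationship to other modules
- Depends on image_utils (ImageInput), video_utils (VideoInput), and audio_utils (AudioInput) for the input type definitions
- Depends on feature_extraction_utils (BatchFeature) as the output container
- Depends on tokenization_utils_base (PaddingStrategy, TruncationStrategy, and so on) for text parameters
- Is depended on by every model-specific Processor class (such as LlavaProcessor and Qwen2VLProcessor)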
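2. TypedDict kwargs system
Source file: [processing_utils.py](file:///workspace/src/transformers/processing_utils.py) (L162-L475)
Module responsibilities
The TypedDict kwargs system is a type-safe parameter-passing mechanism introduced in Transformers v4.47+. It replaces the earlier, loosely typed **kwargs passing and provides:
- IDE autocompletion and type checking
- Runtime parameter validation (through validate_typed_dict)
- Automatic per-modality parameter dispatch
- Custom validators (such as positive_int and padding_validator)
Core classes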
TextKwargs: text processing parameters
python
class TextKwargs(TypedDict, total=False):
    text_pair: TextInput | PreTokenizedInput | list[TextInput] | ... | None
    add_special_tokens: bool | None
    padding: Annotated[bool | str | PaddingStrategy | None, padding_validator()]  # with a validator
    truncation: Annotated[bool | str | TruncationStrategy | None, truncation_validator()]
    max_length: Annotated[int | None, positive_int()]  # positive-integer validation
    return_tensors: Annotated[str | TensorType | None, tensor_type_validator()]
    # ... more parameters
ImagesKwargs: image processing parameters
python
class ImagesKwargs(TypedDict, total=False):
    do_convert_rgb: bool | None
    do_resize: bool | None
    size: Annotated[int | list[int] | tuple[int, ...] | dict[str, int] | None, image_size_validator()]
    resample: Annotated[Union["PILImageResampling", int] | None, resampling_validator()]
    do_rescale: bool | None
    rescale_factor: float | None
    do_normalize: bool | None
    image_mean: float | list[float] | tuple[float, ...] | None
    image_std: float | list[float] | tuple[float, ...] | None
    device: Annotated[Union[str, "torch.device"] | None, device_validator()]
    return_tensors: Annotated[str | TensorType | None, tensor_type_validator()]
    # ... more parameters
VideosKwargs: video processing parameters
python
class VideosKwargs(TypedDict, total=False):
    # Shares most of the image parameters and adds:
    do_sample_frames: bool | None
    video_metadata: Annotated[VideoMetadataType | None, video_metadata_validator()]
    fps: Annotated[int | float | None, positive_any_number()]
    num_frames: Annotated[int | None, positive_int()]
    return_metadata: bool | None
AudioKwargs: audio processing parameters
python
class AudioKwargs(TypedDict, total=False):
    sampling_rate: Annotated[int | None, positive_int()]
    raw_speech: Union["np.ndarray", list[float], ...] | None
    padding: Annotated[bool | str | PaddingStrategy | None, padding_validator()]
    max_length: Annotated[int | None, positive_int()]
    truncation: Annotated[bool | str | TruncationStrategy | None, truncation_validator()]
    return_attention_mask: bool | None
    return_tensors: Annotated[str | TensorType | None, tensor_type_validator()]
ProcessingKwargs: the aggregate container
python
class ProcessingKwargs(TypedDict, total=False):
    _defaults = {}  # subclasses can override this to supply defaults
    text_kwargs: TextKwargs = {**TextKwargs.__annotations__}
    images_kwargs: ImagesKwargs = {**ImagesKwargs.__annotations__}
    videos_kwargs: VideosKwargs = {**VideosKwargs.__annotations__}
    audio_kwargs: AudioKwargs = {**AudioKwargs.__annotations__}
Extension mechanism
A model-specific Processor can extend the kwargs through inheritance:
python
# Add image parameters specific to the model
class ModelImagesKwargs(ImagesKwargs, total=False):
    new_image_kwarg: Optional[bool]

# Add model-specific processing parameters and defaults
class ModelProcessorKwargs(ProcessingKwargs, total=False):
    images_kwargs: ModelImagesKwargs
    _defaults = {
        "images_kwargs": {
            "new_image_kwarg": False,
        },
        "text_kwargs": {
            "padding": "max_length",
        },
    }
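Validator system
Annotated types combined with validators provide runtime checks:
| Validator | Purpose |
|---|---|
| positive_int() | Ensures a positive integer |
| positive_any_number() | Ensures a positive number (int or float) |
| padding_validator() | Validates the padding strategy |
| truncation_validator() | Validates the truncation strategy |
| image_size_validator() | Validates the image size format |
| resampling_validator() | Validates the resampling method |
| device_validator() | Validates the device string |
| tensor_type_validator() | Validates the tensor type |
| video_metadata_validator() | Validates the video metadata |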
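3. Image processing infrastructure
3.1 ImageProcessingMixin (image_processing_base.py)
Source file: [image_processing_base.py](file:///workspace/src/transformers/image_processing_base.py)
Module responsibilities
Provides the infrastructure layer for image processors, including:
- Serialization and deserialization (to_dict / from_dict / save_pretrained / from_pretrained)
- Loading the configuration from the Hub (get_image_processor_dict)
- Fetching images (fetch_images)
Core classes
BatchFeature: the output container
python
class BatchFeature(BaseBatchFeature):
    """Container for image processor outputs; inherits from feature_extraction_utils.BatchFeature"""
    pass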
ImageProcessingMixin: the image processor mixin
python
class ImageProcessingMixin(PushToHubMixin):
    _auto_class = None

    def __init__(self, **kwargs):
        # Drop obsolete attribute names
        kwargs.pop("feature_extractor_type", None)
        kwargs.pop("processor_class", None)
        # Set every remaining kwarg as an instance attribute
        for key, value in kwargs.items():
            setattr(self, key, value)
from_pretrained loading flow:
python
@classmethod
def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
    # 1. Get the config dict (prefer the nested config in processor_config.json, fall back to the standalone file)
    image_processor_dict, kwargs = cls.get_image_processor_dict(pretrained_model_name_or_path, **kwargs)
    # 2. Instantiate from the dict
    return cls.from_dict(image_processor_dict, **kwargs)
Configuration loading priority (get_image_processor_dict):
the nested "image_processor" key in processor_config.json
↓ (if it does not exist)
the standalone preprocessor_config.json file
fetch_images: fetching images
python
def fetch_images(self, image_url_or_urls):
    if isinstance(image_url_or_urls, (list, tuple)):
        return [self.fetch_images(x) for x in image_url_or_urls]
    elif isinstance(image_url_or_urls, str):
        return load_image(image_url_or_urls)  # load from a URL or a path
    elif is_valid_image(image_url_or_urls):
        return image_url_or_urls  # already a valid image, return it unchanged
3.2 BaseImageProcessor (image_processing_utils.py)
Source file: [image_processing_utils.py](file:///workspace/src/transformers/image_processing_utils.py)
Module responsibilities
Defines the standard image processing pipeline, including:
- Parameter validation and standardization
- Input preparation (converting raw inputs to the backend's format)
- Dispatch to the backend's _preprocess method
- Backward-compatible rescale / normalize / center_crop methods
Class hierarchy
BaseImageProcessor (image_processing_utils.py)
├── TorchvisionBackend (image_processing_backends.py)   ← GPU acceleration
│   ├── BaseVideoProcessor (video_processing_utils.py)  ← video processing
│   └── ModelImageProcessor (per-model directories)     ← e.g. LlavaNextImageProcessor
└── PilBackend (image_processing_backends.py)           ← portable, CPU
    └── ModelImageProcessorPil (per-model directories)  ← e.g. CLIPImageProcessorPil
Preprocessing pipeline
python
# Call chain:
__call__() → preprocess() → _preprocess_image_like_inputs()
                                        ↓
                            _prepare_image_like_inputs()   ← per-image format conversion
                                        ↓
                            _preprocess()                  ← batched operations (implemented by the backend)
preprocess method:
python
@auto_docstring
def preprocess(self, images: ImageInput, *args, **kwargs: Unpack[ImagesKwargs]) -> BatchFeature:
    # 1. Type validation
    validate_typed_dict(self.valid_kwargs, kwargs)
    # 2. Fill kwargs that were not provided from the instance attributes
    for kwarg_name in self._valid_kwargs_names:
        kwargs.setdefault(kwarg_name, getattr(self, kwarg_name, None))
    # 3. Standardize kwargs (for example, convert size to SizeDict)
    kwargs = self._standardize_kwargs(**kwargs)
    # 4. Validate the preprocessing parameters
    self._validate_preprocess_kwargs(**kwargs)
    # 5. Run the preprocessing
    return self._preprocess_image_like_inputs(images, *args, **kwargs)
_prepare_image_like_inputs: per-image format conversion
python
def _prepare_image_like_inputs(self, images, *args, expected_ndims=3, **kwargs):
    # 1. Collect the images and flatten their structure
    images = self._prepare_images_structure(images, expected_ndims=expected_ndims)
    # 2. Call process_image (implemented by the backend) on every image
    process_image_partial = partial(self.process_image, *args, **kwargs)
    has_nested_structure = len(images) > 0 and isinstance(images[0], list | tuple)
    if has_nested_structure:
        processed_images = [[process_image_partial(img) for img in nested_list] for nested_list in images]
    else:
        processed_images = [process_image_partial(img) for img in images]
    return processed_images
_standardize_kwargs: parameter standardization
python
def _standardize_kwargs(self, size=None, crop_size=None, pad_size=None, ...):
    # Convert every accepted size format to SizeDict
    if size is not None and not isinstance(size, SizeDict):
        size = SizeDict(**get_size_dict(size=size, default_to_square=default_to_square))
    if crop_size is not None and not isinstance(crop_size, SizeDict):
        crop_size = SizeDict(**get_size_dict(crop_size, param_name="crop_size"))
    # Convert list-typed mean/std to tuples (hashable, so they can be cached)
    if isinstance(image_mean, list):
        image_mean = tuple(image_mean)
    # ...
SizeDict and size conventions
The get_size_dict function converts the accepted size formats into a standard dictionary:
python
# int + default_to_square=True  → {"height": 224, "width": 224}
# int + default_to_square=False → {"shortest_edge": 224}
# (h, w)                        → {"height": h, "width": w}
# Valid key combinations:
VALID_SIZE_DICT_KEYS = (
    {"height", "width"},
    {"shortest_edge"},
    {"shortest_edge", "longest_edge"},
    {"longest_edge"},
    {"max_height", "max_width"},
)
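The three conversions listed above can be sketched as a small standalone function. `to_size_dict` and `is_valid_size_dict` are hypothetical helpers written for this illustration; they mirror the documented mapping but are not the library's `get_size_dict`.

```python
def to_size_dict(size, default_to_square=True):
    """Hypothetical helper mirroring the int / (h, w) / dict conversions described above."""
    if isinstance(size, dict):
        return dict(size)
    if isinstance(size, int):
        if default_to_square:
            return {"height": size, "width": size}
        return {"shortest_edge": size}
    if isinstance(size, (tuple, list)) and len(size) == 2:
        height, width = size
        return {"height": height, "width": width}
    raise ValueError(f"unsupported size: {size!r}")


VALID_KEY_SETS = (
    {"height", "width"},
    {"shortest_edge"},
    {"shortest_edge", "longest_edge"},
    {"longest_edge"},
    {"max_height", "max_width"},
)


def is_valid_size_dict(size_dict):
    """A size dict is valid when its keys form exactly one of the allowed combinations."""
    return set(size_dict) in VALID_KEY_SETS


print(to_size_dict(224))
print(to_size_dict(224, default_to_square=False))
print(to_size_dict((480, 640)))
```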
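4. Image processing backends (PIL/Torchvision)
Source file: [image_processing_backends.py](file:///workspace/src/transformers/image_processing_backends.py)
Module responsibilities
Provides two image processing backend implementations; a processor picks its backend through inheritance:
| Backend | Data format | Device | Characteristics |
|---|---|---|---|
| TorchvisionBackend | torch.Tensor | GPU/CPU | Batched operations, fused optimizations, GPU acceleration |
| PilBackend | np.ndarray | CPU | Portable, no GPU dependency, per-image processing |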
TorchvisionBackend
python
@requires(backends=("torch", "torchvision"))
class TorchvisionBackend(BaseImageProcessor):
    """GPU-accelerated, batched image processing backend"""
    def __init__(self, **kwargs: Unpack[ImagesKwargs]):
        super().__init__(**kwargs)
        self._set_attributes(**kwargs)  # parse and set every valid_kwargs attribute
process_image: single-image format conversion
python
def process_image(self, image, do_convert_rgb=None, input_data_format=None, device=None, **kwargs):
    image_type = get_image_type(image)
    if do_convert_rgb:
        image = self.convert_to_rgb(image)
    # Convert everything to a channels-first torch.Tensor
    if image_type == ImageType.PIL:
        image = tvF.pil_to_tensor(image)  # PIL → Tensor (C, H, W)
    elif image_type == ImageType.NUMPY:
        image = torch.from_numpy(image).contiguous()
    if image.ndim == 2:
        image = image.unsqueeze(0)  # add a channel dimension to grayscale images
    if input_data_format == ChannelDimension.LAST:
        image = image.permute(2, 0, 1).contiguous()  # HWC → CHW
    if device is not None:
        image = image.to(device)
    return image
_preprocess: the batched processing pipeline
python
def _preprocess(self, images, do_resize, size, resample, do_center_crop, crop_size,
                do_rescale, rescale_factor, do_normalize, image_mean, image_std,
                do_pad, pad_size, disable_grouping, return_tensors, **kwargs):
    # Step 1: group by shape and resize in batches
    grouped_images, grouped_images_index = group_images_by_shape(images, disable_grouping=disable_grouping)
    resized_images_grouped = {}
    for shape, stacked_images in grouped_images.items():
        if do_resize:
            stacked_images = self.resize(image=stacked_images, size=size, resample=resample)
        resized_images_grouped[shape] = stacked_images
    resized_images = reorder_images(resized_images_grouped, grouped_images_index)
    # Step 2: regroup, then run center_crop and the fused rescale_and_normalize
    grouped_images, grouped_images_index = group_images_by_shape(resized_images, disable_grouping=disable_grouping)
    processed_images_grouped = {}
    for shape, stacked_images in grouped_images.items():
        if do_center_crop:
            stacked_images = self.center_crop(stacked_images, crop_size)
        stacked_images = self.rescale_and_normalize(
            stacked_images, do_rescale, rescale_factor, do_normalize, image_mean, image_std
        )
        processed_images_grouped[shape] = stacked_images
    processed_images = reorder_images(processed_images_grouped, grouped_images_index)
    # Step 3: padding
    if do_pad:
        processed_images = self.pad(processed_images, pad_size=pad_size, disable_grouping=disable_grouping)
    return BatchFeature(data={"pixel_values": processed_images}, tensor_type=return_tensors)
Fused optimization: rescale_and_normalize
python
def rescale_and_normalize(self, images, do_rescale, rescale_factor, do_normalize, image_mean, image_std):
    # Fuse rescale and normalize into a single operation to reduce GPU memory bandwidth
    image_mean, image_std, do_rescale = self._fuse_mean_std_and_rescale_factor(
        do_normalize=do_normalize, image_mean=image_mean, image_std=image_std,
        do_rescale=do_rescale, rescale_factor=rescale_factor, device=images.device,
    )
    if do_normalize:
        images = self.normalize(images.to(dtype=torch.float32), image_mean, image_std)
    elif do_rescale:
        images = self.rescale(images, rescale_factor)
    return images
Why the fusion works: normalize(rescale(x, factor), mean, std) is equivalent to normalize(x, mean/factor, std/factor), because (x * factor - mean) / std = (x - mean/factor) / (std/factor). The rescale step can therefore be skipped, and normalize is applied directly with the adjusted mean and std.
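The equivalence can be checked numerically on a few pixel values. This is a pure-Python sanity check under assumed sample statistics (`mean`, `std`, and `pixels` are illustrative values), not the backend's tensor code.

```python
import math

factor = 1 / 255
mean, std = 0.485, 0.229            # assumed statistics for a single channel
pixels = [0, 17, 128, 255]          # raw uint8-style values

# Two-step path: rescale to [0, 1], then normalize.
two_step = [(p * factor - mean) / std for p in pixels]

# Fused path: fold the rescale factor into the statistics, then normalize raw values once.
fused_mean, fused_std = mean / factor, std / factor
fused = [(p - fused_mean) / fused_std for p in pixels]

matches = all(math.isclose(a, b, rel_tol=1e-9, abs_tol=1e-12) for a, b in zip(two_step, fused))
print(matches)
```

The fused path touches each pixel once instead of twice, which is where the memory-bandwidth saving comes from on large batches.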
group_images_by_shape: batching by shape
python
def group_images_by_shape(images, *paired_inputs, disable_grouping=None, is_nested=False):
    # Automatic decision: disable grouping on CPU (per-image processing is faster), enable it on GPU
    if disable_grouping is None:
        device = _get_device_from_images(images, is_nested)
        disable_grouping = device == "cpu"
    if disable_grouping:
        # unsqueeze every image separately into a batch of 1
        return {key: img.unsqueeze(0) for key, img in _iterate_items(images, is_nested)}, ...
    # Group by shape and stack
    grouped_images, *paired_grouped_values, grouped_images_index = _group_images_by_shape(...)
    grouped_images = {shape: torch.stack(images_list, dim=0) for shape, images_list in grouped_images.items()}
    return grouped_images, ...
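The group, process, and reorder pattern can be sketched with NumPy arrays in place of tensors. `group_by_shape` and `reorder` below are hypothetical simplifications; they assume NumPy is installed and ignore the paired inputs, nesting, and device logic of the real helpers.

```python
import numpy as np


def group_by_shape(images):
    """Stack same-shaped arrays and remember where each original image went."""
    grouped, index = {}, []
    for image in images:
        bucket = grouped.setdefault(image.shape, [])
        index.append((image.shape, len(bucket)))  # (group key, position inside the group)
        bucket.append(image)
    stacked = {shape: np.stack(bucket, axis=0) for shape, bucket in grouped.items()}
    return stacked, index


def reorder(stacked, index):
    """Undo the grouping so the outputs line up with the original input order."""
    return [stacked[shape][pos] for shape, pos in index]


images = [np.zeros((3, 4, 4)), np.ones((3, 2, 2)), np.full((3, 4, 4), 2.0)]
stacked, index = group_by_shape(images)
restored = reorder(stacked, index)
print({shape: batch.shape for shape, batch in stacked.items()})
```

Each stacked batch can be transformed with one vectorized call, and the saved index restores the caller's ordering afterwards.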
PilBackend
python
@requires(backends=("vision",))
class PilBackend(BaseImageProcessor):
    """Portable CPU image processing backend built on NumPy operations"""
    def process_image(self, image, do_convert_rgb=None, input_data_format=None, **kwargs):
        image_type = get_image_type(image)
        if do_convert_rgb:
            image = self.convert_to_rgb(image)
        # Convert everything to a channels-first np.ndarray
        if image_type == ImageType.PIL:
            image = np.array(image)
            if image.ndim >= 3:
                input_data_format = ChannelDimension.LAST
        elif image_type == ImageType.TORCH:
            image = image.numpy()
        if image.ndim == 2:
            image = np.expand_dims(image, axis=0)
        if input_data_format == ChannelDimension.LAST:
            image = np.transpose(image, (2, 0, 1))  # HWC → CHW
        return image
PilBackend's _preprocess handles images one at a time, with no batching optimization:
python
def _preprocess(self, images, do_resize, size, resample, ...):
    processed_images = []
    for image in images:
        if do_resize:
            image = self.resize(image=image, size=size, resample=resample)
        if do_center_crop:
            image = self.center_crop(image, crop_size)
        if do_rescale:
            image = self.rescale(image, rescale_factor)
        if do_normalize:
            image = self.normalize(image, image_mean, image_std)
        processed_images.append(image)
    if do_pad:
        processed_images = self.pad(processed_images, pad_size=pad_size)
    return BatchFeature(data={"pixel_values": processed_images}, tensor_type=return_tensors)
5. Image transforms
Source file: [image_transforms.py](file:///workspace/src/transformers/image_transforms.py)
Module responsibilities
Provides low-level image transform functions that are independent of the processor backends. They are implemented with NumPy: the PIL backend uses them directly, while the Torchvision backend uses the corresponding Torch implementations. Every function is aware of the ChannelDimension format.
Core functions
to_channel_dimension_format: channel dimension conversion
python
def to_channel_dimension_format(image, channel_dim, input_channel_dim=None):
    """Convert an image to the requested channel dimension format (channels_first or channels_last)"""
    if input_channel_dim is None:
        input_channel_dim = infer_channel_dimension_format(image)
    target_channel_dim = ChannelDimension(channel_dim)
    if input_channel_dim == target_channel_dim:
        return image
    # Swap the last three dimensions with transpose
    if target_channel_dim == ChannelDimension.FIRST:
        axes = list(range(image.ndim - 3)) + [image.ndim - 1, image.ndim - 3, image.ndim - 2]
        image = image.transpose(axes)
    # ...
rescale: pixel value scaling
python
def rescale(image, scale, data_format=None, dtype=np.float32, input_data_format=None):
    """Multiply pixel values by a scale factor (usually 1/255, mapping [0, 255] to [0, 1])"""
    rescaled_image = image.astype(np.float64) * scale  # upcast first to avoid precision loss
    if data_format is not None:
        rescaled_image = to_channel_dimension_format(rescaled_image, data_format, input_data_format)
    rescaled_image = rescaled_image.astype(dtype)
    return rescaled_image
normalize: normalization
python
def normalize(image, mean, std, data_format=None, input_data_format=None):
    """Normalize: image = (image - mean) / std"""
    channel_axis = get_channel_dimension_axis(image, input_data_format=input_data_format)
    num_channels = image.shape[channel_axis]
    # Expand a scalar mean/std into an array that matches the number of channels
    if isinstance(mean, Collection):
        if len(mean) != num_channels:
            raise ValueError(...)
    else:
        mean = [mean] * num_channels
    mean = np.array(mean, dtype=image.dtype)
    # ... std is handled the same way
    if input_data_format == ChannelDimension.LAST:
        image = (image - mean) / std
    else:
        image = ((image.T - mean) / std).T  # channels-first needs a transpose to broadcast
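The transpose trick in the channels-first branch can be verified on a small array. The sketch below assumes NumPy is installed and uses made-up statistics; it compares the `.T` form with the more explicit reshape-based broadcast.

```python
import numpy as np

image = np.arange(24, dtype=np.float32).reshape(3, 2, 4)   # (C, H, W)
mean = np.array([1.0, 2.0, 3.0], dtype=np.float32)
std = np.array([2.0, 4.0, 8.0], dtype=np.float32)

# .T reverses the axes to (W, H, C), so the per-channel vectors broadcast over the last axis;
# the second .T restores the original (C, H, W) layout.
via_transpose = ((image.T - mean) / std).T

# Equivalent explicit form: reshape the statistics to (C, 1, 1).
via_reshape = (image - mean[:, None, None]) / std[:, None, None]

print(via_transpose.shape, np.allclose(via_transpose, via_reshape))
```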
resize: resizing
python
def resize(image, size, resample=None, reducing_gap=None, data_format=None,
           return_numpy=True, input_data_format=None):
    """Resize with PIL (a NumPy input is converted to PIL first, then resized)"""
    height, width = size
    # Detect whether the image must be rescaled before the PIL conversion
    do_rescale = False
    if not isinstance(image, PIL.Image.Image):
        do_rescale = _rescale_for_pil_conversion(image)
        image = to_pil_image(image, do_rescale=do_rescale, input_data_format=input_data_format)
    # PIL resize (note that PIL expects the size as width, height)
    resized_image = image.resize((width, height), resample=resample, reducing_gap=reducing_gap)
    if return_numpy:
        resized_image = np.array(resized_image)
        # Restore the channel dimension format and the original value range
        resized_image = to_channel_dimension_format(resized_image, data_format, ...)
        resized_image = rescale(resized_image, 1 / 255) if do_rescale else resized_image
    return resized_image
center_crop: center cropping
python
def center_crop(image, size, data_format=None, input_data_format=None):
    """Center crop; an image smaller than the target is zero-padded first, then cropped"""
    image = to_channel_dimension_format(image, ChannelDimension.FIRST, input_data_format)
    # Compute the crop boundaries
    top = (orig_height - crop_height) // 2
    # ...
    if top >= 0 and bottom <= orig_height and left >= 0 and right <= orig_width:
        image = image[..., top:bottom, left:right]  # crop directly
    else:
        # The image is too small: zero-pad first, then crop
        new_image = np.zeros_like(image, shape=new_shape)
        new_image[..., top_pad:bottom_pad, left_pad:right_pad] = image
        new_image = new_image[..., max(0, top):min(new_height, bottom), ...]
pad: padding
python
def pad(image, padding, mode=PaddingMode.CONSTANT, constant_values=0.0, ...):
    """Pad an image with np.pad; supports constant/reflect/replicate/symmetric modes"""
    # Expand the padding argument according to data_format
    padding = _expand_for_data_format(padding)
    if mode == PaddingMode.CONSTANT:
        image = np.pad(image, padding, mode="constant", constant_values=constant_values)
    elif mode == PaddingMode.REFLECT:
        image = np.pad(image, padding, mode="reflect")
    # ...
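Helper functions
| Function | Purpose |
|---|---|
| get_resize_output_image_size | Computes the target size after resizing |
| get_size_with_aspect_ratio | Computes the resized dimensions while keeping the aspect ratio |
| select_best_resolution | Picks the best match from a list of candidate resolutions |
| divide_to_patches | Splits an image into patches |
| group_images_by_shape | Groups images by shape (for batching) |
| reorder_images | Restores the original image order |
| center_to_corners_format | Converts bounding boxes from center format to corner format |
| rgb_to_id / id_to_rgb | Converts between colors and IDs (panoptic segmentation) |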
6. Image utilities
Source file: [image_utils.py](file:///workspace/src/transformers/image_utils.py)
Module responsibilities
Provides the basic utility layer of the image processing system, including:
- Type definitions (ImageInput, ChannelDimension, ImageType)
- Image validation and classification
- Image loading (URL, local path, base64)
- Channel dimension inference
- Normalizing the structure of image lists
Core type definitions
python
# Union type for image inputs
ImageInput = Union[
    "PIL.Image.Image", np.ndarray, "torch.Tensor",
    list["PIL.Image.Image"], list[np.ndarray], list["torch.Tensor"]
]

# Channel dimension enum
class ChannelDimension(ExplicitEnum):
    FIRST = "channels_first"  # (C, H, W)
    LAST = "channels_last"    # (H, W, C)

# Image type enum
class ImageType(ExplicitEnum):
    PIL = "pillow"
    TORCH = "torch"
    NUMPY = "numpy"
Core functions
infer_channel_dimension_format: inferring the channel dimension
python
def infer_channel_dimension_format(image, num_channels=None):
    """Infer where the channel dimension is (first or last)"""
    num_channels = num_channels if num_channels is not None else (1, 3)
    if image.ndim == 3:
        first_dim, last_dim = 0, 2
    elif image.ndim == 4:
        first_dim, last_dim = 1, 3
    # Check whether the first or the last dimension matches the channel count
    if image.shape[first_dim] in num_channels:
        return ChannelDimension.FIRST
    elif image.shape[last_dim] in num_channels:
        return ChannelDimension.LAST
make_list_of_images / make_flat_list_of_images: structure normalization
python
def make_list_of_images(images, expected_ndims=3):
    """Always return a list of images: single image → [image], batch → [image1, image2, ...]"""
    if is_batched(images):
        return images
    if is_pil_image(images):
        return [images]
    if is_valid_image(images):
        if images.ndim == expected_ndims + 1:
            images = list(images)  # 4D batch → list
        elif images.ndim == expected_ndims:
            images = [images]      # 3D single image → list
        return images

def make_flat_list_of_images(images, expected_ndims=3):
    """Always return a flat list of images: nested list → flat list"""
    if isinstance(images, (list, tuple)) and all(isinstance(i, (list, tuple)) for i in images):
        return [img for img_list in images for img in img_list]
load_image: image loading
Loads an image from a URL, a local path, or a base64 string. Internally it fetches remote resources with httpx and opens image files with PIL.
7. Feature extraction system
7.1 FeatureExtractionMixin (feature_extraction_utils.py)
Source file: [feature_extraction_utils.py](file:///workspace/src/transformers/feature_extraction_utils.py)
Module responsibilities
Provides the infrastructure for feature extractors, with a structure similar to ImageProcessingMixin:
- BatchFeature: the unified container for every processor's output
- FeatureExtractionMixin: the mixin for saving, loading, and serialization
BatchFeature: the core output container
python
class BatchFeature(UserDict):
    """Output container of every processor's __call__ method; essentially a dictionary"""
    def __init__(self, data=None, tensor_type=None, skip_tensor_conversion=None):
        super().__init__(data)
        self.skip_tensor_conversion = skip_tensor_conversion
        self.convert_to_tensors(tensor_type=tensor_type)  # convert to the tensor type automatically

    def convert_to_tensors(self, tensor_type=None, skip_tensor_conversion=None):
        """Convert the stored data to tensors of the requested type (PyTorch or NumPy)"""
        is_tensor, as_tensor = self._get_is_as_tensor_fns(tensor_type)
        for key, value in self.items():
            if skip_tensor_conversion and key in skip_tensor_conversion:
                continue
            if not _is_tensor_or_array_like(value):
                continue
            if not is_tensor(value):
                self[key] = as_tensor(value)

    def to(self, *args, **kwargs):
        """Move every tensor to the given device (PyTorch only)"""
        # ... device and non_blocking are extracted from *args / **kwargs (elided in this excerpt)
        # Only floating-point tensors get a dtype conversion; other tensors are only moved
        def maybe_to(v):
            if isinstance(v, torch.Tensor) and torch.is_floating_point(v):
                return v.to(*args, **kwargs)
            elif isinstance(v, torch.Tensor) and device is not None:
                return v.to(device=device, non_blocking=non_blocking)
            elif isinstance(v, (list, tuple)):
                return type(v)(maybe_to(item) for item in v)
            else:
                return v
        self.data = {k: maybe_to(v) for k, v in self.items()}
7.2 SequenceFeatureExtractor (feature_extraction_sequence_utils.py)
Source file: [feature_extraction_sequence_utils.py](file:///workspace/src/transformers/feature_extraction_sequence_utils.py)
Module responsibilities
Gives sequence-style feature extractors (mainly used for audio):
- Padding/truncation strategies
- Attention mask generation
- Batched processing
python
class SequenceFeatureExtractor(FeatureExtractionMixin):
    def __init__(self, feature_size, sampling_rate, padding_value, **kwargs):
        self.feature_size = feature_size
        self.sampling_rate = sampling_rate
        self.padding_value = padding_value
        self.padding_side = kwargs.pop("padding_side", "right")
        self.return_attention_mask = kwargs.pop("return_attention_mask", True)

    def pad(self, processed_features, padding=True, max_length=None,
            truncation=False, pad_to_multiple_of=None, ...):
        """Pad feature sequences; several strategies are supported"""
        # 1. Convert list[dict] to dict[list] (compatible with a DataLoader collate_fn)
        # 2. Convert everything to NumPy arrays
        # 3. Apply truncation
        # 4. Pad according to the padding strategy
        #    - LONGEST: pad to the longest sequence in the batch
        #    - MAX_LENGTH: pad to max_length
        # 5. Build the attention_mask
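Steps 3 to 5 can be illustrated with plain Python lists. `pad_batch` is a hypothetical helper written for this sketch; the output key names `input_values` and `attention_mask` are assumptions for the example, and the real `pad` method supports more options than shown here.

```python
def pad_batch(sequences, padding_value=0.0, max_length=None, padding_side="right"):
    """Hypothetical helper: pad 1-D float sequences and build matching attention masks."""
    # LONGEST strategy when max_length is None, MAX_LENGTH strategy otherwise.
    target = max_length if max_length is not None else max(len(s) for s in sequences)
    padded, masks = [], []
    for seq in sequences:
        seq = list(seq[:target])                      # truncate when longer than the target
        pad = [padding_value] * (target - len(seq))
        ones, zeros = [1] * len(seq), [0] * len(pad)  # 1 marks real values, 0 marks padding
        if padding_side == "right":
            padded.append(seq + pad)
            masks.append(ones + zeros)
        else:
            padded.append(pad + seq)
            masks.append(zeros + ones)
    return {"input_values": padded, "attention_mask": masks}


batch = pad_batch([[0.1, 0.2, 0.3], [0.5]])
print(batch)
```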
8. Audio utilities
Source file: [audio_utils.py](file:///workspace/src/transformers/audio_utils.py)
Module responsibilities
Provides the low-level tools for audio processing:
- Audio loading (several backends supported)
- Audio format conversion
- Spectral analysis tools (Mel filters, STFT, and so on)
- Audio validation and structure normalization
Core types and functions
AudioInput type
python
AudioInput = Union[np.ndarray, "torch.Tensor", Sequence[np.ndarray], Sequence["torch.Tensor"]]
load_audio: audio loading
python
def load_audio(audio, sampling_rate=16000, timeout=None):
    """Load audio as a NumPy array; supports URLs and local paths"""
    if isinstance(audio, str):
        # Prefer torchcodec (more efficient), fall back to librosa
        if is_torchcodec_available() and version.parse("0.3.0") <= TORCHCODEC_VERSION:
            audio = load_audio_torchcodec(audio, sampling_rate=sampling_rate, timeout=timeout)
        else:
            audio = load_audio_librosa(audio, sampling_rate=sampling_rate, timeout=timeout)
    elif not isinstance(audio, np.ndarray):
        raise TypeError(...)
    return audio
Audio loading backends:
| Backend | Dependency | Characteristics |
|---|---|---|
| torchcodec | torchcodec ≥ 0.3.0 | Recommended, efficient |
| librosa | librosa + soxr | Traditional option, broadly compatible |
Mel spectrogram tools
python
def hertz_to_mel(freq, mel_scale="htk"):
    """Hz → Mel conversion; supports the htk, kaldi, and slaney scales"""
    if mel_scale == "htk":
        return 2595.0 * np.log10(1.0 + (freq / 700.0))
    # ...

def mel_to_hertz(mels, mel_scale="htk"):
    """Mel → Hz conversion"""
    if mel_scale == "htk":
        return 700.0 * (np.power(10, mels / 2595.0) - 1.0)
    # ...

def _create_triangular_filter_bank(fft_freqs, filter_freqs):
    """Create a triangular filter bank (used for Mel spectrograms)"""
    filter_diff = np.diff(filter_freqs)
    slopes = np.expand_dims(filter_freqs, 0) - np.expand_dims(fft_freqs, 1)
    down_slopes = -slopes[:, :-2] / filter_diff[:-1]
    up_slopes = slopes[:, 2:] / filter_diff[1:]
    return np.maximum(np.zeros(1), np.minimum(down_slopes, up_slopes))
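The two HTK formulas above are inverses of each other, which a scalar round trip makes easy to see. The functions below restate only the `htk` branch with the standard `math` module; the `_htk` names are hypothetical and the other scales are not covered.

```python
import math


def hertz_to_mel_htk(freq):
    """HTK scale: mel = 2595 * log10(1 + f / 700)."""
    return 2595.0 * math.log10(1.0 + freq / 700.0)


def mel_to_hertz_htk(mels):
    """Inverse of the HTK scale: f = 700 * (10 ** (mel / 2595) - 1)."""
    return 700.0 * (10.0 ** (mels / 2595.0) - 1.0)


for freq in (0.0, 440.0, 1000.0, 8000.0):
    mel = hertz_to_mel_htk(freq)
    print(f"{freq:8.1f} Hz -> {mel:8.2f} mel -> {mel_to_hertz_htk(mel):8.1f} Hz")
```

On this scale 1000 Hz maps to roughly 1000 mel, and the spacing between equal mel steps widens at higher frequencies.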
9. Video processing system
9.1 video_utils.py: video utilities
Source file: [video_utils.py](file:///workspace/src/transformers/video_utils.py)
Module responsibilities
Provides the low-level tools for video processing:
- Video input type definitions and validation
- Video metadata (VideoMetadata)
- Video decoding (5 backends supported)
- Frame sampling strategies
- Video structure normalization
VideoInput type
python
VideoInput = Union[
    list["PIL.Image.Image"],            # list of frames
    np.ndarray,                         # 4D array (T, H, W, C) or (T, C, H, W)
    "torch.Tensor",                     # 4D tensor
    list[np.ndarray], list["torch.Tensor"],
    list[list["PIL.Image.Image"]],      # batch of frame lists
    URL, list[URL], list[list[URL]],    # URLs
    Path, list[Path], list[list[Path]], # local paths
]
VideoMetadata: video metadata
python
@dataclass
class VideoMetadata(Mapping):
    total_num_frames: int
    fps: float | None = None
    width: int | None = None
    height: int | None = None
    duration: float | None = None
    video_backend: str | None = None
    frames_indices: list[int] | None = None

    @property
    def timestamps(self) -> list[float]:
        """Timestamps of the sampled frames, in seconds"""
        return [frame_idx / self.fps for frame_idx in self.frames_indices]

    @property
    def sampled_fps(self) -> float:
        """Effective FPS after sampling"""
        return len(self.frames_indices) / self.total_num_frames * self.fps
Video decoding backends
python
VIDEO_DECODERS = {
    "decord": read_video_decord,            # Decord backend
    "opencv": read_video_opencv,            # OpenCV backend
    "pyav": read_video_pyav,                # PyAV backend (default)
    "torchvision": read_video_torchvision,  # Torchvision backend (deprecated)
    "torchcodec": read_video_torchcodec,    # Torchcodec backend (recommended)
}
load_video --- Video Loading
python
def load_video(video, num_frames=None, fps=None, backend="pyav",
               sample_indices_fn=None, **kwargs):
    """Load a video as a NumPy array."""
    # 1. Build the sampling function (uniform sampling by default)
    if sample_indices_fn is None:
        def sample_indices_fn_func(metadata, **fn_kwargs):
            return default_sample_indices_fn(metadata, num_frames=num_frames, fps=fps, **fn_kwargs)
        sample_indices_fn = sample_indices_fn_func
    # 2. Already an array or a list of PIL frames: return as-is
    if not isinstance(video, str):
        return video, [None] * len(video)
    # 3. Handle YouTube URLs (requires yt_dlp)
    # 4. Handle HTTP URLs (downloaded with httpx)
    #    (steps 3 and 4 are elided in this excerpt; they set `file_obj`, used below)
    # 5. Decode with the selected backend
    video_decoder = VIDEO_DECODERS[backend]
    video, metadata = video_decoder(file_obj, sample_indices_fn, **kwargs)
    return video, metadata
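The registry lookup and the early exit for already-decoded input can be sketched in isolation. Everything below is hypothetical: `_decode_fake_a`, `DECODERS`, and `load` are toy stand-ins, and the explicit `ValueError` for an unknown backend is an addition of this sketch (the excerpt above would raise `KeyError` from the dictionary lookup).

```python
def _decode_fake_a(path, sample_indices_fn):
    """Hypothetical decoder: pretends the file holds six frames."""
    frames = [f"{path}#frame{i}" for i in range(6)]
    keep = sample_indices_fn(len(frames))
    return [frames[i] for i in keep], {"backend": "fake_a"}


# Name -> decoder function, mirroring the shape of a decoder registry
DECODERS = {"fake_a": _decode_fake_a}


def load(video, backend="fake_a", sample_indices_fn=lambda n: list(range(n))):
    # Early exit: anything that is not a path/URL string is treated as decoded frames
    if not isinstance(video, str):
        return video, [None] * len(video)
    if backend not in DECODERS:
        raise ValueError(f"Unknown backend {backend!r}; choose from {sorted(DECODERS)}")
    return DECODERS[backend](video, sample_indices_fn)


frames, meta = load("clip.mp4", sample_indices_fn=lambda n: [0, 2, 4])
passthrough = load([1, 2, 3])
```

Passing the sampling function into the decoder, rather than sampling afterwards, lets a backend decode only the frames that are actually kept.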
Frame sampling strategy
python
def default_sample_indices_fn(metadata, num_frames=None, fps=None, **kwargs):
    """Default uniform sampling strategy."""
    total_num_frames = metadata.total_num_frames
    video_fps = metadata.fps
    # If only fps is given, derive num_frames from it
    if num_frames is None and fps is not None:
        num_frames = int(total_num_frames / video_fps * fps)
    if num_frames is not None:
        # Evenly spaced indices across the whole video
        indices = np.arange(0, total_num_frames, total_num_frames / num_frames, dtype=int)
    else:
        # No target given: keep every frame
        indices = np.arange(0, total_num_frames, dtype=int)
    return indices
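A worked example helps show how the `fps` and `num_frames` arguments interact. The sketch below repeats the logic of the snippet above in a standalone function; `uniform_indices` is an illustrative name, and `SimpleNamespace` stands in for the metadata object.

```python
from types import SimpleNamespace

import numpy as np


def uniform_indices(metadata, num_frames=None, fps=None):
    """Standalone copy of the uniform sampling logic, for illustration only."""
    total = metadata.total_num_frames
    if num_frames is None and fps is not None:
        # Clip duration in seconds (total / source fps) times the target fps
        num_frames = int(total / metadata.fps * fps)
    if num_frames is not None:
        return np.arange(0, total, total / num_frames, dtype=int)
    return np.arange(0, total, dtype=int)


meta = SimpleNamespace(total_num_frames=300, fps=30.0)  # a 10-second clip
by_fps = uniform_indices(meta, fps=1)         # 300 / 30 * 1 = 10 frames, every 30th frame
by_count = uniform_indices(meta, num_frames=8)
all_frames = uniform_indices(meta)
```

When the step `total / num_frames` is not an integer, the exact indices depend on how `np.arange` handles a float step with an integer dtype, so the sketch only relies on the count and the bounds in that case.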
9.2 video_processing_utils.py --- Video Processor
Source file: [video_processing_utils.py](file:///workspace/src/transformers/video_processing_utils.py)
Module responsibilities
Provides the full video processor implementation. It inherits from TorchvisionBackend and adds:
- Video decoding and frame sampling
- Video metadata management
- Video-specific RGB conversion
BaseVideoProcessor
python
@requires(backends=("vision", "torchvision"))
class BaseVideoProcessor(TorchvisionBackend):
    model_input_names = ["pixel_values_videos"]  # Note: differs from the image key "pixel_values"
    valid_kwargs = VideosKwargs

    def preprocess(self, videos, **kwargs):
        # 1. Validate arguments
        validate_kwargs(...)
        validate_typed_dict(self.valid_kwargs, kwargs)
        # 2. Fill in defaults from the processor's own attributes
        for kwarg_name in self.valid_kwargs.__annotations__:
            kwargs.setdefault(kwarg_name, getattr(self, kwarg_name, None))
        # (excerpt: do_sample_frames, video_metadata, input_data_format, device and
        #  return_metadata are read from kwargs in code omitted here)
        # 3. Decode and sample frames
        sample_indices_fn = partial(self.sample_frames, **kwargs) if do_sample_frames else None
        videos, video_metadata = self._decode_and_sample_videos(
            videos, video_metadata=video_metadata,
            do_sample_frames=do_sample_frames, sample_indices_fn=sample_indices_fn,
        )
        # 4. Prepare the input videos (convert to tensors, fix the channel dimension)
        videos = self._prepare_input_videos(videos=videos, input_data_format=input_data_format, device=device)
        # 5. Standardize and validate kwargs
        kwargs = self._standardize_kwargs(**kwargs)
        self._validate_preprocess_kwargs(**kwargs)
        # 6. Run preprocessing (reuses TorchvisionBackend's _preprocess)
        preprocessed_videos = self._preprocess(videos=videos, **kwargs)
        if return_metadata:
            preprocessed_videos["video_metadata"] = video_metadata
        return preprocessed_videos
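Step 2 of `preprocess` (filling defaults from the processor's attributes) is easy to try in isolation. In this sketch `DemoKwargs` and `DemoProcessor` are hypothetical; only the `setdefault` loop over `__annotations__` is taken from the excerpt.

```python
from typing import TypedDict


class DemoKwargs(TypedDict, total=False):  # hypothetical, stands in for VideosKwargs
    do_resize: bool
    size: int
    num_frames: int


class DemoProcessor:
    valid_kwargs = DemoKwargs
    do_resize = True  # processor-level defaults
    size = 224
    # num_frames deliberately has no default attribute

    def fill_defaults(self, **kwargs):
        # Caller-supplied values win; missing keys fall back to attributes, then None
        for name in self.valid_kwargs.__annotations__:
            kwargs.setdefault(name, getattr(self, name, None))
        return kwargs


filled = DemoProcessor().fill_defaults(size=384)
```

After this step every key declared in the TypedDict is present, so later stages can read `kwargs[name]` without checking for missing keys.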
_decode_and_sample_videos --- decoding and sampling:
python
def _decode_and_sample_videos(self, videos, video_metadata, do_sample_frames, sample_indices_fn):
    videos = make_batched_videos(videos)
    video_metadata = make_batched_metadata(videos, video_metadata=video_metadata)
    if is_valid_video(videos[0]) and do_sample_frames:
        # Already-decoded video arrays: sample frames directly by index
        sampled_videos = []
        for video, metadata in zip(videos, video_metadata):
            indices = sample_indices_fn(metadata=metadata)
            metadata.frames_indices = indices
            sampled_videos.append(video[indices])
        videos = sampled_videos
    elif not is_valid_video(videos[0]):
        # URLs or paths: decode first
        if isinstance(videos[0], list):
            # List of image URLs: load frame by frame, then stack
            videos = [torch.stack([self.process_image(image) for image in images])
                      for images in self.fetch_images(videos)]
        else:
            videos, video_metadata = self.fetch_videos(videos, sample_indices_fn=sample_indices_fn)
    return videos, video_metadata
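10. Vision Utilities
Source file: [vision_utils.py](file:///workspace/src/transformers/vision_utils.py)
Module responsibilities
Provides precomputation helpers for vision encoders, mainly to work around dynamic computation during torch.compile / torch.export tracing. These functions compute the auxiliary tensors a vision encoder needs (cumulative sequence lengths, position IDs, window attention indices, and so on), so the tensors can be computed ahead of time, before the model is compiled.
Core functions
get_vision_cu_seqlens --- Cumulative Sequence Lengths
python
def get_vision_cu_seqlens(grid_thw, kwargs=None):
    """Compute cumulative sequence lengths from grid_thw, for variable-length ops such as Flash Attention."""
    if kwargs is not None and (cu_seqlens := kwargs.pop("cu_seqlens", None)) is not None:
        return cu_seqlens  # already precomputed: return it directly
    # grid_thw: (num_images, 3), one (temporal, height, width) row per image
    # Patches per frame = h * w; repeat once per temporal step, then take the running sum
    cu_seqlens = torch.repeat_interleave(grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]).cumsum(...)
    return F.pad(cu_seqlens, (1, 0), value=0)  # prepend 0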
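The arithmetic is easier to follow on a tiny input. The sketch below redoes the same computation with plain Python lists instead of tensors; `cu_seqlens_from_grid` is an illustrative name, not part of the library.

```python
from itertools import accumulate


def cu_seqlens_from_grid(grid_thw):
    """grid_thw: list of (t, h, w). Returns cumulative patch counts with a leading 0."""
    per_frame = []
    for t, h, w in grid_thw:
        per_frame.extend([h * w] * t)   # one entry per temporal step (repeat_interleave)
    return [0, *accumulate(per_frame)]  # running sum with 0 prepended (the F.pad step)


# One 4x6 image, then a 2-frame video with 2x2 patches per frame
boundaries = cu_seqlens_from_grid([(1, 4, 6), (2, 2, 2)])
```

Each adjacent pair in the result marks one sequence: patches 0 to 24 belong to the image, 24 to 28 and 28 to 32 to the two video frames. A variable-length attention kernel uses such boundaries to keep attention from crossing between sequences.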
get_vision_position_ids --- Position IDs
python
def get_vision_position_ids(grid_thw, spatial_merge_size, kwargs=None):
    """Compute the (row, col) position IDs for the vision rotary embedding."""
    position_ids = []  # one tensor per image, concatenated at the end
    # For each image, build an h x w grid of positions, ordered by spatial-merge block
    for (t, h, w), merge_size in zip(grid_thw.tolist(), spatial_merge_size.tolist()):
        hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w)
        hpos_ids = hpos_ids.reshape(h // merge_size, merge_size, w // merge_size, merge_size)
        hpos_ids = hpos_ids.transpose(1, 2).flatten()
        # ... wpos_ids is computed the same way
        position_ids.append(torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1))
    return torch.cat(position_ids, dim=0)
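The reshape and transpose steps reorder positions so that the patches of each merge block are contiguous. A pure-Python sketch of that ordering follows; `merged_position_ids` is a hypothetical helper, and it assumes `h` and `w` are divisible by `merge_size`.

```python
def merged_position_ids(t, h, w, merge_size):
    """Return (row, col) pairs ordered block by block, repeated for t frames."""
    m = merge_size
    frame = [
        (bh * m + i, bw * m + j)
        for bh in range(h // m)  # block row
        for bw in range(w // m)  # block column
        for i in range(m)        # row inside the block
        for j in range(m)        # column inside the block
    ]
    return frame * t  # every temporal step reuses the same spatial positions


ids = merged_position_ids(t=1, h=4, w=4, merge_size=2)
```

For a 4x4 grid with 2x2 merging, the first four entries cover the top-left block and the next four cover the block to its right, rather than walking the grid row by row.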
get_vision_window_index --- Window Attention Indices
python
def get_vision_window_index(grid_thw, spatial_merge_size, window_size, patch_size, kwargs=None):
    """Compute the reordering index and window boundaries for window attention."""
    # Split each image into non-overlapping windows and build the token reordering index
    vit_merger_window_size = window_size // spatial_merge_size // patch_size
    for grid_t, grid_h, grid_w in grid_thw.tolist():
        # Build the index grid and pad it to a multiple of the window size
        index = torch.arange(grid_t * llm_grid_h * llm_grid_w).reshape(...)
        index_padded = F.pad(index, (0, pad_w, 0, pad_h), "constant", -100)
        # Rearrange into window layout
        index_padded = index_padded.permute(0, 1, 3, 2, 4).reshape(...)
    return window_index, cu_window_seqlens
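The idea behind the padded index grid can be illustrated with a simplified single-frame version. This is a sketch under stated assumptions, not the library function: `window_order` is a hypothetical name, it ignores the temporal dimension and spatial merging, and it pads only as much as needed to reach a multiple of the window size.

```python
PAD = -100  # same sentinel value as in the excerpt above


def window_order(grid_h, grid_w, win):
    """Return (token order grouped by window, cumulative tokens per window)."""
    padded_h = grid_h + (-grid_h) % win
    padded_w = grid_w + (-grid_w) % win
    # Index grid with PAD wherever the padded area lies outside the real grid
    grid = [
        [r * grid_w + c if r < grid_h and c < grid_w else PAD for c in range(padded_w)]
        for r in range(padded_h)
    ]
    order, cu = [], [0]
    for wr in range(0, padded_h, win):
        for wc in range(0, padded_w, win):
            tokens = [
                grid[r][c]
                for r in range(wr, wr + win)
                for c in range(wc, wc + win)
                if grid[r][c] != PAD  # drop padding slots
            ]
            order.extend(tokens)
            cu.append(cu[-1] + len(tokens))
    return order, cu


order, cu = window_order(grid_h=3, grid_w=3, win=2)
```

The `order` list says where each token moves so that the tokens of one window are adjacent, and `cu` gives the window boundaries in that reordered sequence. Edge windows are shorter because their padding slots are dropped.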
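get_vision_bilinear_indices_and_weights --- Bilinear Interpolation
python
def get_vision_bilinear_indices_and_weights(grid_thw, num_grid_per_side, spatial_merge_size, kwargs=None):
    """Compute bilinear interpolation indices and weights, used for 2D interpolation of position embeddings."""
    # For each image, build the interpolation grid along h and w
    for t, h, w in grid_thw.tolist():
        h_grid = torch.linspace(0, num_grid_per_side - 1, h)
        w_grid = torch.linspace(0, num_grid_per_side - 1, w)
        # Indices and weights of the four surrounding corners
        # (h_floor, h_ceil, w_floor, w_ceil, h_frac and w_frac are derived from the grids; elided here)
        corner_indices = [h_floor + w_floor, h_floor + w_ceil, h_ceil + w_floor, h_ceil + w_ceil]
        corner_weights = [(1-h_frac)*(1-w_frac), (1-h_frac)*w_frac, h_frac*(1-w_frac), h_frac*w_frac]
    return bilinear_indices, bilinear_weights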
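A pure-Python sketch of the corner computation is shown below. It is illustrative only: `bilinear_indices_and_weights` and `linspace` are hypothetical helpers, the source grid is assumed to be a square of `side` x `side` embeddings stored row-major, and a flat corner index is taken to be `row * side + col`.

```python
import math


def linspace(start, stop, n):
    """Evenly spaced values including both ends, like a tensor linspace."""
    if n == 1:
        return [float(start)]
    step = (stop - start) / (n - 1)
    return [start + i * step for i in range(n)]


def bilinear_indices_and_weights(side, h, w):
    """For each of the h*w target cells: four flat source indices and their weights."""
    indices, weights = [], []
    for y in linspace(0, side - 1, h):
        y0 = math.floor(y)
        y1 = min(y0 + 1, side - 1)  # clamp at the bottom edge
        fy = y - y0
        for x in linspace(0, side - 1, w):
            x0 = math.floor(x)
            x1 = min(x0 + 1, side - 1)  # clamp at the right edge
            fx = x - x0
            indices.append([y0 * side + x0, y0 * side + x1, y1 * side + x0, y1 * side + x1])
            weights.append([(1 - fy) * (1 - fx), (1 - fy) * fx, fy * (1 - fx), fy * fx])
    return indices, weights


idx, wts = bilinear_indices_and_weights(side=2, h=3, w=3)
```

Resampling a 2x2 grid to 3x3 places the center target cell exactly between all four source cells, so it reads all four with equal weight. The four weights of every cell sum to 1, which keeps the interpolated embedding on the same scale as the originals.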
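Precomputation pattern: every get_* function accepts a kwargs argument. If kwargs already contains the precomputed result, the function pops it and returns it directly instead of recomputing. This lets callers compute these dynamic tensors ahead of time, before torch.compile traces the vision encoder.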
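The pattern reduces to a small "use the cached value if present, otherwise compute" helper. The sketch below is generic Python; `get_or_compute` is a hypothetical name, and only the `kwargs.pop(...)` check mirrors the excerpt of get_vision_cu_seqlens.

```python
def get_or_compute(name, compute, kwargs=None):
    """Return kwargs[name] if a precomputed value is present, else call compute()."""
    if kwargs is not None and (cached := kwargs.pop(name, None)) is not None:
        return cached
    return compute()


calls = []


def expensive():
    calls.append(1)  # record that the computation actually ran
    return [0, 24, 28, 32]


fresh = get_or_compute("cu_seqlens", expensive)  # nothing cached: computed
extra = {"cu_seqlens": [0, 4], "other": 1}
reused = get_or_compute("cu_seqlens", expensive, kwargs=extra)  # taken from kwargs
```

Because the value is popped, it is consumed on first use and does not leak into later calls that receive the same kwargs dictionary.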
11. Module Relationships and Data Flow
End-to-end data flow
Taking the multimodal call processor(text="a description", images=img, videos=vid, audio=aud) as an example:
User calls processor(text, images, videos, audio, **kwargs)
    │
    ▼
ProcessorMixin.__call__()
    │
    ├── _merge_kwargs() ──── merge arguments from all sources
    │     ├── ProcessingKwargs._defaults (lowest priority)
    │     ├── tokenizer.init_kwargs
    │     ├── common_kwargs
    │     └── kwargs passed at call time (highest priority)
    │
    ├── tokenizer(text, **text_kwargs) ──── text processing
    │     └── returns {"input_ids": ..., "attention_mask": ...}
    │
    ├── image_processor(images, **images_kwargs) ──── image processing
    │     ├── BaseImageProcessor.preprocess()
    │     │     ├── validate_typed_dict() ──── type validation
    │     │     ├── _standardize_kwargs() ──── argument standardization
    │     │     └── _preprocess_image_like_inputs()
    │     │           ├── _prepare_image_like_inputs() ──── per-image format conversion
    │     │           │     └── process_image() ──── [TorchvisionBackend/PilBackend]
    │     │           └── _preprocess() ──── batch processing
    │     │                 ├── group_images_by_shape() ──── group by shape
    │     │                 ├── resize() ──── resize
    │     │                 ├── center_crop() ──── center crop
    │     │                 ├── rescale_and_normalize() ──── fused rescale + normalize
    │     │                 ├── pad() ──── padding
    │     │                 └── reorder_images() ──── restore the original order
    │     └── returns {"pixel_values": ...}
    │
    ├── video_processor(videos, **videos_kwargs) ──── video processing
    │     ├── BaseVideoProcessor.preprocess()
    │     │     ├── _decode_and_sample_videos() ──── decoding + frame sampling
    │     │     │     ├── make_batched_videos()
    │     │     │     └── sample_frames() / load_video()
    │     │     ├── _prepare_input_videos() ──── tensor conversion, channel layout
    │     │     └── _preprocess() ──── reuses TorchvisionBackend
    │     └── returns {"pixel_values_videos": ..., "video_metadata": ...} (video_metadata only when return_metadata is set)
    │
    ├── feature_extractor(audio, **audio_kwargs) ──── audio processing
    │     ├── SequenceFeatureExtractor.__call__()
    │     │     ├── load_audio() ──── load audio
    │     │     └── pad() ──── padding + attention_mask
    │     └── returns {"input_values": ..., "attention_mask": ...}
    │
    └── BatchFeature({
            "input_ids": ...,            ← from tokenizer
            "attention_mask": ...,       ← from tokenizer
            "pixel_values": ...,         ← from image_processor
            "pixel_values_videos": ...,  ← from video_processor
            "input_values": ...,         ← from feature_extractor
        })
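The `_merge_kwargs()` step in the flow above can be pictured as layered dictionary updates, applied from lowest to highest priority. The sketch below is a simplified illustration, not the library's implementation: `merge_layers` and the sample values are hypothetical, and the layer order simply follows the list in the diagram.

```python
def merge_layers(*layers):
    """Merge dicts left to right; later (higher-priority) layers override earlier ones."""
    merged = {}
    for layer in layers:
        merged.update(layer)
    return merged


defaults = {"padding": False, "max_length": 512}  # lowest priority
tokenizer_init = {"max_length": 1024}
common = {"return_tensors": "pt"}
call_time = {"padding": True}                     # highest priority

text_kwargs = merge_layers(defaults, tokenizer_init, common, call_time)
```

A key set at call time always wins, a key set only at tokenizer initialization overrides the class default, and untouched defaults pass through unchanged.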
Module dependency graph
processing_utils.py
├── image_utils.py (ImageInput, ChannelDimension)
├── video_utils.py (VideoInput, VideoMetadata)
├── audio_utils.py (AudioInput, load_audio)
├── feature_extraction_utils.py (BatchFeature)
├── tokenization_utils_base.py (PaddingStrategy, TruncationStrategy)
└── utils/type_validators.py (validators)
image_processing_utils.py (BaseImageProcessor)
├── image_processing_base.py (ImageProcessingMixin, BatchFeature)
├── image_transforms.py (resize, normalize, rescale, center_crop, pad)
├── image_utils.py (ImageInput, ChannelDimension, SizeDict)
└── processing_utils.py (ImagesKwargs)
image_processing_backends.py (TorchvisionBackend, PilBackend)
├── image_processing_utils.py (BaseImageProcessor)
├── image_transforms.py (the transform functions)
└── image_utils.py (ImageInput, ImageType)
video_processing_utils.py (BaseVideoProcessor)
├── image_processing_backends.py (TorchvisionBackend)
├── video_utils.py (VideoInput, VideoMetadata, load_video)
└── processing_utils.py (VideosKwargs)
vision_utils.py (standalone module, depends only on torch)
└── called from vision encoder model code
feature_extraction_utils.py (BatchFeature, FeatureExtractionMixin)
└── standalone infrastructure
feature_extraction_sequence_utils.py (SequenceFeatureExtractor)
├── feature_extraction_utils.py (BatchFeature, FeatureExtractionMixin)
└── audio_utils.py (load_audio, is_valid_audio)
audio_utils.py (standalone utility module)
└── depends on soundfile, librosa, torchcodec (optional)
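Key design patterns
- Mixin pattern: ProcessorMixin, ImageProcessingMixin, and FeatureExtractionMixin all inherit from PushToHubMixin and provide save/load capabilities through mixin composition.
- Strategy pattern: the image processing backend (Torchvision or PIL) is selected through inheritance rather than switched at runtime, for type safety and performance.
- TypedDict-driven parameter system: Python's TypedDict + Annotated provide type-safe parameter passing and validation, replacing loose **kwargs.
- Priority merging: _merge_kwargs merges arguments across four priority levels, so call-time arguments always take precedence while model-level defaults remain supported.
- Batching optimization: the Torchvision backend uses group_images_by_shape + reorder_images for automatic batching. Images of the same shape are stacked and processed as one batch; images of different shapes are processed in separate groups.
- Fused operations: rescale_and_normalize fuses two steps into one, reducing GPU memory bandwidth usage.
- Precomputation pattern: the functions in vision_utils.py can read precomputed results from kwargs, which addresses the problem of torch.compile tracing dynamic computation.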
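The batching optimization in the list above follows a group, process, restore-order shape. The sketch below shows that shape with plain Python values. The helper names `group_by_shape` and `reorder` are hypothetical and do not reproduce the signatures of group_images_by_shape or reorder_images; strings stand in for image tensors.

```python
from collections import defaultdict


def group_by_shape(images):
    """images: list of (shape, payload). Returns groups plus each item's (shape, slot)."""
    groups, positions = defaultdict(list), []
    for shape, payload in images:
        positions.append((shape, len(groups[shape])))  # remember where this item lands
        groups[shape].append(payload)
    return groups, positions


def reorder(processed_groups, positions):
    """Put processed items back into the original input order."""
    return [processed_groups[shape][slot] for shape, slot in positions]


images = [((2, 2), "a"), ((3, 3), "b"), ((2, 2), "c")]
groups, positions = group_by_shape(images)
# "Process" each same-shape group in one go (upper-casing stands in for a batched op)
processed = {shape: [p.upper() for p in batch] for shape, batch in groups.items()}
restored = reorder(processed, positions)
```

Grouping lets same-shape images be stacked into one tensor and transformed with a single call, while the recorded positions guarantee that the output order matches the input order.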