08-Hugging Face Transformers 多模态处理系统深度分析

Hugging Face Transformers 多模态处理系统深度分析

相关文章:
Hugging Face Transformers 源码全景解读
01-Hugging Face Transformers 核心基础设施深度分析
02-Hugging Face Transformers 配置系统深度分析
03-Hugging Face Transformers 模型系统深度分析
04-Hugging Face Transformers 注意力与掩码系统深度分析
05-Hugging Face Transformers 缓存系统深度分析
06-Hugging Face Transformers 生成系统深度分析
07-Hugging Face Transformers 分词器系统深度分析

目录

  1. 系统架构总览
  2. [ProcessorMixin --- 多模态处理器基类](#ProcessorMixin — 多模态处理器基类)
  3. [TypedDict kwargs 系统](#TypedDict kwargs 系统)
  4. 图像处理基础设施
  5. 图像处理后端(PIL/Torchvision)
  6. 图像变换
  7. 图像工具
  8. 特征提取系统
  9. 音频工具
  10. 视频处理系统
  11. 视觉工具
  12. 模块间关系与数据流

多模态处理系统架构总览

工具层
变换层
后端层
子处理器层
统一入口层
ProcessorMixin

PushToHubMixin
call 多模态统一入口
Tokenizer

文本处理
ImageProcessor

图像处理
VideoProcessor

视频处理
FeatureExtractor

特征提取
PilBackend

CPU后端
TorchvisionBackend

GPU后端
Resize

Rescale

Normalize

Pad
image_utils

video_utils

audio_utils

vision_utils


系统架构总览

Transformers 的多模态处理系统采用分层架构 设计,以 ProcessorMixin 为核心调度器,将不同模态(文本、图像、视频、音频)的处理逻辑解耦为独立的子处理器。整体架构如下:

复制代码
                         ┌─────────────────────┐
                         │    ProcessorMixin    │  ← 多模态统一入口
                         │  (__call__ / _merge_ │
                         │     kwargs)          │
                         └──────┬──────────────┘
               ┌────────────────┼────────────────┐
               │                │                │
        ┌──────▼──────┐ ┌──────▼──────┐ ┌───────▼──────┐
        │  Tokenizer   │ │ ImageProc   │ │ VideoProc    │  ...
        │  (文本处理)  │ │ (图像处理)  │ │ (视频处理)   │
        └─────────────┘ └──────┬──────┘ └───────┬──────┘
                               │                │
                    ┌──────────▼──────────┐     │
                    │  Backend 层          │     │
                    │  TorchvisionBackend │     │
                    │  PilBackend         │     │
                    └──────────┬──────────┘     │
                               │                │
                    ┌──────────▼──────────┐     │
                    │  Transform 层        │     │
                    │  resize / normalize │     │
                    │  rescale / pad      │     │
                    └──────────┬──────────┘     │
                               │                │
                    ┌──────────▼──────────┐     │
                    │  Utils 层            │     │
                    │  image_utils        │     │
                    │  video_utils        │◄────┘
                    │  audio_utils        │
                    └─────────────────────┘

核心设计原则

  • 模态解耦:每种模态有独立的处理器,通过 ProcessorMixin 统一调度
  • 后端可替换:图像处理支持 PIL(CPU)和 Torchvision(GPU)两种后端
  • TypedDict 驱动的参数系统:使用 Python TypedDict 实现类型安全的参数传递
  • 优先级合并:kwargs 按优先级从高到低合并(调用时 > 模态专用 > 初始化时 > 默认值)

1. ProcessorMixin --- 多模态处理器基类

源文件:[processing_utils.py](file:///workspace/src/transformers/processing_utils.py)

模块职责

ProcessorMixin 是所有多模态处理器的基类,负责:

  1. 统一管理各模态的子处理器(tokenizer、image_processor、video_processor、feature_extractor)
  2. 将用户输入按模态分发到对应子处理器
  3. 合并各模态处理结果为统一的 BatchFeature 输出
  4. 管理聊天模板(chat template)和对话式多模态输入
  5. 提供保存/加载(save_pretrained/from_pretrained)功能

核心类/函数

ProcessorMixin 类
python 复制代码
class ProcessorMixin(PushToHubMixin):
    # 子处理器属性,由子类通过 __init__ 的 attributes 动态设置
    tokenizer: Any
    feature_extractor: Any
    image_processor: Any
    video_processor: Any
    chat_template: str | dict[str, str] | None

    valid_processor_kwargs = ProcessingKwargs  # 默认的 kwargs 类型定义
__call__ 方法 --- 多模态统一入口
python 复制代码
def __call__(
    self,
    images: ImageInput | None = None,
    text: TextInput | ... | None = None,
    videos: VideoInput | None = None,
    audio: AudioInput | None = None,
    **kwargs: Unpack[ProcessingKwargs],
):
    # 1. 合并所有来源的 kwargs(调用时 > 模态专用 > 初始化时 > 默认值)
    kwargs = self._merge_kwargs(
        self.valid_processor_kwargs,
        tokenizer_init_kwargs=self.tokenizer.init_kwargs if hasattr(self, "tokenizer") else {},
        **kwargs,
    )

    # 2. 建立模态到子处理器和 kwargs 的映射
    attribute_to_kwargs = {
        "tokenizer": (text, "text_kwargs"),
        "image_processor": (images, "images_kwargs"),
        "video_processor": (videos, "videos_kwargs"),
        "feature_extractor": (audio, "audio_kwargs"),
    }

    # 3. 遍历所有子处理器,分发输入和对应的 kwargs
    outputs = {}
    for attribute_name in self.get_attributes():
        attribute = getattr(self, attribute_name, None)
        input_data, input_kwargs = attribute_to_kwargs[attribute_name]
        if input_data is not None and attribute is not None:
            attribute_output = attribute(input_data, **kwargs[input_kwargs])
            outputs.update(attribute_output)

    # 4. 合并所有子处理器输出为 BatchFeature
    return BatchFeature(outputs)

设计要点

  • 音频模态通过 feature_extractor 处理(历史原因,feature_extractor 同时负责音频特征提取)
  • 各子处理器输出合并时使用 dict.update(),因此 key 不能冲突
  • Unpack[ProcessingKwargs] 提供了 IDE 友好的类型提示
_merge_kwargs 方法 --- 参数合并核心

这是整个多模态系统中最复杂的部分,负责将来自不同来源的参数按优先级合并:

python 复制代码
def _merge_kwargs(
    self,
    ModelProcessorKwargs: ProcessingKwargs,
    tokenizer_init_kwargs: dict | None = None,
    **kwargs,
) -> dict[str, dict]:

参数优先级(从高到低):

优先级 来源 示例
1(最高) 调用时直接传入的 flat kwargs processor(..., padding="max_length")
2 调用时传入的模态专用 kwargs processor(..., text_kwargs={"padding": "max_length"})
3 子处理器初始化时的 kwargs tokenizer = Tokenizer(..., padding="max_length")
4(最低) ProcessingKwargs 中的 _defaults _defaults = {"text_kwargs": {"padding": "max_length"}}

合并流程

python 复制代码
# 步骤1:初始化输出字典和默认值字典
output_kwargs = {"text_kwargs": {}, "images_kwargs": {}, "audio_kwargs": {}, "videos_kwargs": {}}
default_kwargs = {"text_kwargs": {}, "images_kwargs": {}, "audio_kwargs": {}, "videos_kwargs": {}}

# 步骤2:从 ModelProcessorKwargs._defaults 填充默认值
for modality in default_kwargs:
    default_kwargs[modality] = ModelProcessorKwargs._defaults.get(modality, {}).copy()
    # 同时收集子处理器的 valid_kwargs 以确保一致性
    modality_valid_kwargs = set(ModelProcessorKwargs.__annotations__[modality].__annotations__)
    if modality in map_preprocessor_kwargs:
        preprocessor = getattr(self, map_preprocessor_kwargs[modality], None)
        preprocessor_valid_kwargs = getattr(preprocessor, "valid_kwargs", None)
        modality_valid_kwargs.update(set(preprocessor_valid_kwargs.__annotations__))

# 步骤3:用 tokenizer 初始化 kwargs 覆盖默认值
for modality_key in modality_valid_kwargs:
    if tokenizer_init_kwargs is not None and modality_key in tokenizer_init_kwargs:
        default_kwargs[modality][modality_key] = value

# 步骤4:处理 common_kwargs(跨模态共享参数)
common_kwargs = ModelProcessorKwargs._defaults.get("common_kwargs", {})
common_kwargs.update(kwargs.get("common_kwargs", {}))
if common_kwargs:
    for kwarg in output_kwargs.values():
        kwarg.update(common_kwargs)

# 步骤5:用调用时传入的 kwargs 覆盖(最高优先级)
# ... 按模态分发 flat kwargs 和嵌套 kwargs

# 步骤6:类型验证
for key, typed_dict_obj in ModelProcessorKwargs.__annotations__.items():
    validate_typed_dict(typed_dict_obj, output_kwargs[key])
apply_chat_template 方法 --- 对话式多模态处理
python 复制代码
def apply_chat_template(
    self,
    conversation: list[dict[str, str]] | list[list[dict[str, str]]],
    chat_template: str | None = None,
    ...
) -> str:

此方法将 OpenAI 风格的多模态对话转换为模型输入:

  1. 模板渲染:使用 Jinja2 模板将对话格式化为文本
  2. 多模态提取:从对话内容中提取图像/视频/音频的 URL 或路径
  3. 输入加载 :调用 load_audiofetch_images 等加载实际数据
  4. 格式归一化 :将 OpenAI 的 image_url 格式转换为 HuggingFace 的 image 格式
  5. 统一处理 :最终调用 self() 将所有模态一起处理
get_attributes 方法 --- 动态属性发现
python 复制代码
@classmethod
def get_attributes(cls):
    # 从 __init__ 签名中提取所有模态属性
    args_in_init = inspect.signature(cls.__init__).parameters.keys()
    attributes = []
    for sub_processor_type in args_in_init:
        if any(modality in sub_processor_type for modality in MODALITY_TO_AUTOPROCESSOR_MAPPING.keys()):
            attributes.append(sub_processor_type)
    # 兼容旧版:从 <attribute>_class 类属性推断
    if not attributes:
        for attribute_name, value in cls.__dict__.items():
            if value is not None and attribute_name.endswith("_class"):
                inferred_attribute = attribute_name[:-len("_class")]
                attributes.append(inferred_attribute)
    return attributes
_LazyAutoProcessorMapping --- 延迟加载映射
python 复制代码
class _LazyAutoProcessorMapping(dict):
    _MAPPING_NAMES = {
        "image_processor": ("transformers.models.auto.image_processing_auto", "AutoImageProcessor"),
        "video_processor": ("transformers.models.auto.video_processing_auto", "AutoVideoProcessor"),
        "feature_extractor": ("transformers.models.auto.feature_extraction_auto", "AutoFeatureExtractor"),
        "audio_processor": ("transformers.models.auto.feature_extraction_auto", "AutoFeatureExtractor"),
        "tokenizer": ("transformers.models.auto.tokenization_auto", "AutoTokenizer"),
    }

通过延迟导入避免循环依赖------只有在实际访问时才导入对应的 Auto 类。

与其他模块的关系

  • 依赖 image_utils(ImageInput)、video_utils(VideoInput)、audio_utils(AudioInput)定义输入类型
  • 依赖 feature_extraction_utils(BatchFeature)作为输出容器
  • 依赖 tokenization_utils_base(PaddingStrategy、TruncationStrategy 等)处理文本参数
  • 被依赖 于所有具体模型的 Processor 类(如 LlavaProcessor、Qwen2VLProcessor 等)

2. TypedDict kwargs 系统

源文件:[processing_utils.py](file:///workspace/src/transformers/processing_utils.py)(L162-L475)

模块职责

TypedDict kwargs 系统是 Transformers v4.47+ 引入的类型安全参数传递机制 ,替代了之前松散的 **kwargs 传递方式,提供:

  1. IDE 自动补全和类型检查
  2. 运行时参数验证(通过 validate_typed_dict
  3. 按模态自动分发参数
  4. 自定义验证器(如 positive_intpadding_validator

核心类

TextKwargs --- 文本处理参数
python 复制代码
class TextKwargs(TypedDict, total=False):
    text_pair: TextInput | PreTokenizedInput | list[TextInput] | ... | None
    add_special_tokens: bool | None
    padding: Annotated[bool | str | PaddingStrategy | None, padding_validator()]  # 带验证器
    truncation: Annotated[bool | str | TruncationStrategy | None, truncation_validator()]
    max_length: Annotated[int | None, positive_int()]  # 正整数验证
    return_tensors: Annotated[str | TensorType | None, tensor_type_validator()]
    # ... 更多参数
ImagesKwargs --- 图像处理参数
python 复制代码
class ImagesKwargs(TypedDict, total=False):
    do_convert_rgb: bool | None
    do_resize: bool | None
    size: Annotated[int | list[int] | tuple[int, ...] | dict[str, int] | None, image_size_validator()]
    resample: Annotated[Union["PILImageResampling", int] | None, resampling_validator()]
    do_rescale: bool | None
    rescale_factor: float | None
    do_normalize: bool | None
    image_mean: float | list[float] | tuple[float, ...] | None
    image_std: float | list[float] | tuple[float, ...] | None
    device: Annotated[Union[str, "torch.device"] | None, device_validator()]
    return_tensors: Annotated[str | TensorType | None, tensor_type_validator()]
    # ... 更多参数
VideosKwargs --- 视频处理参数
python 复制代码
class VideosKwargs(TypedDict, total=False):
    # 继承了大部分图像参数,额外增加:
    do_sample_frames: bool | None
    video_metadata: Annotated[VideoMetadataType | None, video_metadata_validator()]
    fps: Annotated[int | float | None, positive_any_number()]
    num_frames: Annotated[int | None, positive_int()]
    return_metadata: bool | None
AudioKwargs --- 音频处理参数
python 复制代码
class AudioKwargs(TypedDict, total=False):
    sampling_rate: Annotated[int | None, positive_int()]
    raw_speech: Union["np.ndarray", list[float], ...] | None
    padding: Annotated[bool | str | PaddingStrategy | None, padding_validator()]
    max_length: Annotated[int | None, positive_int()]
    truncation: Annotated[bool | str | TruncationStrategy | None, truncation_validator()]
    return_attention_mask: bool | None
    return_tensors: Annotated[str | TensorType | None, tensor_type_validator()]
ProcessingKwargs --- 聚合容器
python 复制代码
class ProcessingKwargs(TypedDict, total=False):
    _defaults = {}  # 子类可覆盖以提供默认值

    text_kwargs: TextKwargs = {**TextKwargs.__annotations__}
    images_kwargs: ImagesKwargs = {**ImagesKwargs.__annotations__}
    videos_kwargs: VideosKwargs = {**VideosKwargs.__annotations__}
    audio_kwargs: AudioKwargs = {**AudioKwargs.__annotations__}

扩展机制

模型特定的 Processor 可以通过继承扩展 kwargs:

python 复制代码
# 添加模型特有的图像参数
class ModelImagesKwargs(ImagesKwargs, total=False):
    new_image_kwarg: Optional[bool]

# 添加模型特有的处理参数和默认值
class ModelProcessorKwargs(ProcessingKwargs, total=False):
    images_kwargs: ModelImagesKwargs
    _defaults = {
        "images_kwargs": {
            "new_image_kwarg": False,
        },
        "text_kwargs": {
            "padding": "max_length",
        },
    }

验证器系统

Annotated 类型配合验证器实现运行时校验:

验证器 用途
positive_int() 确保正整数
positive_any_number() 确保正数(int 或 float)
padding_validator() 验证 padding 策略
truncation_validator() 验证 truncation 策略
image_size_validator() 验证图像尺寸格式
resampling_validator() 验证重采样方法
device_validator() 验证设备字符串
tensor_type_validator() 验证张量类型
video_metadata_validator() 验证视频元数据

3. 图像处理基础设施

3.1 ImageProcessingMixin(image_processing_base.py)

源文件:[image_processing_base.py](file:///workspace/src/transformers/image_processing_base.py)

模块职责

提供图像处理器的基础设施层,包括:

  1. 序列化/反序列化(to_dict/from_dict/save_pretrained/from_pretrained
  2. 从 Hub 加载配置(get_image_processor_dict
  3. 图像获取(fetch_images
核心类
BatchFeature --- 输出容器
python 复制代码
class BatchFeature(BaseBatchFeature):
    """图像处理器输出的容器,继承自 feature_extraction_utils.BatchFeature"""
    pass
ImageProcessingMixin --- 图像处理器混入类
python 复制代码
class ImageProcessingMixin(PushToHubMixin):
    _auto_class = None

    def __init__(self, **kwargs):
        # 移除过时的属性名
        kwargs.pop("feature_extractor_type", None)
        kwargs.pop("processor_class", None)
        # 将所有 kwargs 设置为实例属性
        for key, value in kwargs.items():
            setattr(self, key, value)

from_pretrained 加载流程

python 复制代码
@classmethod
def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
    # 1. 获取配置字典(优先从 processor_config.json 的嵌套配置,回退到独立配置文件)
    image_processor_dict, kwargs = cls.get_image_processor_dict(pretrained_model_name_or_path, **kwargs)
    # 2. 从字典实例化
    return cls.from_dict(image_processor_dict, **kwargs)

配置加载优先级get_image_processor_dict):

复制代码
processor_config.json 中的嵌套 "image_processor" 键
    ↓ (若不存在)
preprocessor_config.json 独立文件

fetch_images --- 图像获取

python 复制代码
def fetch_images(self, image_url_or_urls):
    if isinstance(image_url_or_urls, (list, tuple)):
        return [self.fetch_images(x) for x in image_url_or_urls]
    elif isinstance(image_url_or_urls, str):
        return load_image(image_url_or_urls)  # 从 URL 或路径加载
    elif is_valid_image(image_url_or_urls):
        return image_url_or_urls  # 已经是有效图像,直接返回

3.2 BaseImageProcessor(image_processing_utils.py)

源文件:[image_processing_utils.py](file:///workspace/src/transformers/image_processing_utils.py)

模块职责

定义图像处理的标准流水线,包括:

  1. 参数验证和标准化
  2. 输入准备(将原始输入转换为后端格式)
  3. 调度到后端的 _preprocess 方法
  4. 提供向后兼容的 rescale/normalize/center_crop 方法
类继承体系
复制代码
BaseImageProcessor (image_processing_utils.py)
├── TorchvisionBackend (image_processing_backends.py)  ← GPU 加速
│   └── BaseVideoProcessor (video_processing_utils.py)  ← 视频处理
│   └── ModelImageProcessor (各模型目录)  ← 如 LlavaNextImageProcessor
└── PilBackend (image_processing_backends.py)  ← CPU 便携
    └── ModelImageProcessorPil (各模型目录)  ← 如 CLIPImageProcessorPil
预处理流水线
python 复制代码
# 调用链:
__call__() → preprocess() → _preprocess_image_like_inputs()
                                    ↓
                          _prepare_image_like_inputs()  ← 逐图转换格式
                                    ↓
                              _preprocess()  ← 批量操作(由后端实现)

preprocess 方法

python 复制代码
@auto_docstring
def preprocess(self, images: ImageInput, *args, **kwargs: Unpack[ImagesKwargs]) -> BatchFeature:
    # 1. 类型验证
    validate_typed_dict(self.valid_kwargs, kwargs)

    # 2. 用实例属性填充未提供的 kwargs
    for kwarg_name in self._valid_kwargs_names:
        kwargs.setdefault(kwarg_name, getattr(self, kwarg_name, None))

    # 3. 标准化 kwargs(如将 size 转为 SizeDict)
    kwargs = self._standardize_kwargs(**kwargs)

    # 4. 验证预处理参数
    self._validate_preprocess_kwargs(**kwargs)

    # 5. 执行预处理
    return self._preprocess_image_like_inputs(images, *args, **kwargs)

_prepare_image_like_inputs --- 逐图格式转换

python 复制代码
def _prepare_image_like_inputs(self, images, *args, expected_ndims=3, **kwargs):
    # 1. 获取并展平图像结构
    images = self._prepare_images_structure(images, expected_ndims=expected_ndims)

    # 2. 对每张图像调用 process_image(由后端实现)
    process_image_partial = partial(self.process_image, *args, **kwargs)
    has_nested_structure = len(images) > 0 and isinstance(images[0], list | tuple)

    if has_nested_structure:
        processed_images = [[process_image_partial(img) for img in nested_list] for nested_list in images]
    else:
        processed_images = [process_image_partial(img) for img in images]

    return processed_images

_standardize_kwargs --- 参数标准化

python 复制代码
def _standardize_kwargs(self, size=None, crop_size=None, pad_size=None, ...):
    # 将各种格式的 size 统一转为 SizeDict
    if size is not None and not isinstance(size, SizeDict):
        size = SizeDict(**get_size_dict(size=size, default_to_square=default_to_square))
    if crop_size is not None and not isinstance(crop_size, SizeDict):
        crop_size = SizeDict(**get_size_dict(crop_size, param_name="crop_size"))
    # 将 list 类型的 mean/std 转为 tuple(可哈希,用于缓存)
    if isinstance(image_mean, list):
        image_mean = tuple(image_mean)
    # ...
SizeDict 与尺寸规范

get_size_dict 函数将各种尺寸格式统一为标准字典:

python 复制代码
# int + default_to_square=True → {"height": 224, "width": 224}
# int + default_to_square=False → {"shortest_edge": 224}
# (h, w) → {"height": h, "width": w}
# 合法的 key 组合:
VALID_SIZE_DICT_KEYS = (
    {"height", "width"},
    {"shortest_edge"},
    {"shortest_edge", "longest_edge"},
    {"longest_edge"},
    {"max_height", "max_width"},
)

4. 图像处理后端(PIL/Torchvision)

源文件:[image_processing_backends.py](file:///workspace/src/transformers/image_processing_backends.py)

模块职责

提供两种图像处理后端实现,通过继承选择后端:

后端 数据格式 运行设备 特点
TorchvisionBackend torch.Tensor GPU/CPU 批量化操作、融合优化、GPU 加速
PilBackend np.ndarray CPU 便携、无 GPU 依赖、逐图处理

TorchvisionBackend

python 复制代码
@requires(backends=("torch", "torchvision"))
class TorchvisionBackend(BaseImageProcessor):
    """GPU 加速的批量化图像处理后端"""

    def __init__(self, **kwargs: Unpack[ImagesKwargs]):
        super().__init__(**kwargs)
        self._set_attributes(**kwargs)  # 解析并设置所有 valid_kwargs 属性
process_image --- 单图格式转换
python 复制代码
def process_image(self, image, do_convert_rgb=None, input_data_format=None, device=None, **kwargs):
    image_type = get_image_type(image)
    if do_convert_rgb:
        image = self.convert_to_rgb(image)

    # 统一转为 torch.Tensor,channels-first 格式
    if image_type == ImageType.PIL:
        image = tvF.pil_to_tensor(image)      # PIL → Tensor (C, H, W)
    elif image_type == ImageType.NUMPY:
        image = torch.from_numpy(image).contiguous()

    if image.ndim == 2:
        image = image.unsqueeze(0)             # 灰度图添加通道维
    if input_data_format == ChannelDimension.LAST:
        image = image.permute(2, 0, 1).contiguous()  # HWC → CHW
    if device is not None:
        image = image.to(device)
    return image
_preprocess --- 批量化处理流水线
python 复制代码
def _preprocess(self, images, do_resize, size, resample, do_center_crop, crop_size,
                do_rescale, rescale_factor, do_normalize, image_mean, image_std,
                do_pad, pad_size, disable_grouping, return_tensors, **kwargs):
    # 步骤1:按形状分组,批量化 resize
    grouped_images, grouped_images_index = group_images_by_shape(images, disable_grouping=disable_grouping)
    resized_images_grouped = {}
    for shape, stacked_images in grouped_images.items():
        if do_resize:
            stacked_images = self.resize(image=stacked_images, size=size, resample=resample)
        resized_images_grouped[shape] = stacked_images
    resized_images = reorder_images(resized_images_grouped, grouped_images_index)

    # 步骤2:重新分组,执行 center_crop + 融合的 rescale_and_normalize
    grouped_images, grouped_images_index = group_images_by_shape(resized_images, disable_grouping=disable_grouping)
    processed_images_grouped = {}
    for shape, stacked_images in grouped_images.items():
        if do_center_crop:
            stacked_images = self.center_crop(stacked_images, crop_size)
        stacked_images = self.rescale_and_normalize(
            stacked_images, do_rescale, rescale_factor, do_normalize, image_mean, image_std
        )
        processed_images_grouped[shape] = stacked_images
    processed_images = reorder_images(processed_images_grouped, grouped_images_index)

    # 步骤3:padding
    if do_pad:
        processed_images = self.pad(processed_images, pad_size=pad_size, disable_grouping=disable_grouping)

    return BatchFeature(data={"pixel_values": processed_images}, tensor_type=return_tensors)
融合优化 --- rescale_and_normalize
python 复制代码
def rescale_and_normalize(self, images, do_rescale, rescale_factor, do_normalize, image_mean, image_std):
    # 融合 rescale 和 normalize 为单次操作,减少 GPU 内存带宽
    image_mean, image_std, do_rescale = self._fuse_mean_std_and_rescale_factor(
        do_normalize=do_normalize, image_mean=image_mean, image_std=image_std,
        do_rescale=do_rescale, rescale_factor=rescale_factor, device=images.device,
    )
    if do_normalize:
        images = self.normalize(images.to(dtype=torch.float32), image_mean, image_std)
    elif do_rescale:
        images = self.rescale(images, rescale_factor)
    return images

融合原理normalize(rescale(x, factor), mean, std) 等价于 normalize(x, mean/factor, std/factor),因此可以跳过 rescale 步骤,直接在 normalize 时使用调整后的 mean 和 std。

group_images_by_shape --- 按形状分组批处理
python 复制代码
def group_images_by_shape(images, *paired_inputs, disable_grouping=None, is_nested=False):
    # 自动决策:CPU 上禁用分组(单张处理更快),GPU 上启用分组
    if disable_grouping is None:
        device = _get_device_from_images(images, is_nested)
        disable_grouping = device == "cpu"

    if disable_grouping:
        # 每张图单独 unsqueeze 为 batch=1
        return {key: img.unsqueeze(0) for key, img in _iterate_items(images, is_nested)}, ...

    # 按形状分组并 stack
    grouped_images, *paired_grouped_values, grouped_images_index = _group_images_by_shape(...)
    grouped_images = {shape: torch.stack(images_list, dim=0) for shape, images_list in grouped_images.items()}
    return grouped_images, ...

PilBackend

python 复制代码
@requires(backends=("vision",))
class PilBackend(BaseImageProcessor):
    """CPU 便携的图像处理后端,使用 NumPy 操作"""

    def process_image(self, image, do_convert_rgb=None, input_data_format=None, **kwargs):
        image_type = get_image_type(image)
        if do_convert_rgb:
            image = self.convert_to_rgb(image)

        # 统一转为 np.ndarray,channels-first 格式
        if image_type == ImageType.PIL:
            image = np.array(image)
            if image.ndim >= 3:
                input_data_format = ChannelDimension.LAST
        elif image_type == ImageType.TORCH:
            image = image.numpy()

        if image.ndim == 2:
            image = np.expand_dims(image, axis=0)
        if input_data_format == ChannelDimension.LAST:
            image = np.transpose(image, (2, 0, 1))  # HWC → CHW
        return image

PilBackend 的 _preprocess:逐图处理,无批量化优化:

python 复制代码
def _preprocess(self, images, do_resize, size, resample, ...):
    processed_images = []
    for image in images:
        if do_resize:
            image = self.resize(image=image, size=size, resample=resample)
        if do_center_crop:
            image = self.center_crop(image, crop_size)
        if do_rescale:
            image = self.rescale(image, rescale_factor)
        if do_normalize:
            image = self.normalize(image, image_mean, image_std)
        processed_images.append(image)
    if do_pad:
        processed_images = self.pad(processed_images, pad_size=pad_size)
    return BatchFeature(data={"pixel_values": processed_images}, tensor_type=return_tensors)

5. 图像变换

源文件:[image_transforms.py](file:///workspace/src/transformers/image_transforms.py)

模块职责

提供与后端无关的底层图像变换函数 ,基于 NumPy 实现(PIL 后端直接使用,Torchvision 后端使用对应的 Torch 实现)。所有函数都支持 ChannelDimension 格式感知。

核心函数

to_channel_dimension_format --- 通道维度转换
python 复制代码
def to_channel_dimension_format(image, channel_dim, input_channel_dim=None):
    """将图像转换为指定的通道维度格式(channels_first 或 channels_last)"""
    if input_channel_dim is None:
        input_channel_dim = infer_channel_dimension_format(image)
    if input_channel_dim == target_channel_dim:
        return image
    # 通过 transpose 交换最后三个维度
    if target_channel_dim == ChannelDimension.FIRST:
        axes = list(range(image.ndim - 3)) + [image.ndim - 1, image.ndim - 3, image.ndim - 2]
        image = image.transpose(axes)
    # ...
rescale --- 像素值缩放
python 复制代码
def rescale(image, scale, data_format=None, dtype=np.float32, input_data_format=None):
    """将像素值乘以缩放因子(通常 1/255,从 [0,255] 缩放到 [0,1])"""
    rescaled_image = image.astype(np.float64) * scale  # 先上转避免精度损失
    if data_format is not None:
        rescaled_image = to_channel_dimension_format(rescaled_image, data_format, input_data_format)
    rescaled_image = rescaled_image.astype(dtype)
    return rescaled_image
normalize --- 标准化
python 复制代码
def normalize(image, mean, std, data_format=None, input_data_format=None):
    """标准化:image = (image - mean) / std"""
    channel_axis = get_channel_dimension_axis(image, input_data_format=input_data_format)
    num_channels = image.shape[channel_axis]
    # 将标量 mean/std 扩展为与通道数匹配的数组
    if isinstance(mean, Collection):
        if len(mean) != num_channels:
            raise ValueError(...)
    else:
        mean = [mean] * num_channels
    mean = np.array(mean, dtype=image.dtype)
    # ... 同理处理 std
    if input_data_format == ChannelDimension.LAST:
        image = (image - mean) / std
    else:
        image = ((image.T - mean) / std).T  # channels-first 需要转置广播
resize --- 缩放
python 复制代码
def resize(image, size, resample=None, reducing_gap=None, data_format=None,
           return_numpy=True, input_data_format=None):
    """使用 PIL 库缩放图像(即使输入是 NumPy 数组也先转 PIL 再缩放)"""
    # 检测是否需要先 rescale 再转 PIL
    do_rescale = False
    if not isinstance(image, PIL.Image.Image):
        do_rescale = _rescale_for_pil_conversion(image)
        image = to_pil_image(image, do_rescale=do_rescale, input_data_format=input_data_format)
    # PIL 缩放(注意 PIL 的尺寸顺序是 width, height)
    resized_image = image.resize((width, height), resample=resample, reducing_gap=reducing_gap)
    if return_numpy:
        resized_image = np.array(resized_image)
        # 恢复通道维度格式和原始值范围
        resized_image = to_channel_dimension_format(resized_image, data_format, ...)
        resized_image = rescale(resized_image, 1 / 255) if do_rescale else resized_image
    return resized_image
center_crop --- 中心裁剪
python 复制代码
def center_crop(image, size, data_format=None, input_data_format=None):
    """中心裁剪,若图像小于目标尺寸则先零填充再裁剪"""
    image = to_channel_dimension_format(image, ChannelDimension.FIRST, input_data_format)
    # 计算裁剪边界
    top = (orig_height - crop_height) // 2
    # ...
    if top >= 0 and bottom <= orig_height and left >= 0 and right <= orig_width:
        image = image[..., top:bottom, left:right]  # 直接裁剪
    else:
        # 图像太小,先零填充再裁剪
        new_image = np.zeros_like(image, shape=new_shape)
        new_image[..., top_pad:bottom_pad, left_pad:right_pad] = image
        new_image = new_image[..., max(0, top):min(new_height, bottom), ...]
pad --- 填充
python 复制代码
def pad(image, padding, mode=PaddingMode.CONSTANT, constant_values=0.0, ...):
    """使用 np.pad 填充图像,支持 constant/reflect/replicate/symmetric 模式"""
    # 根据 data_format 展开填充参数
    padding = _expand_for_data_format(padding)
    if mode == PaddingMode.CONSTANT:
        image = np.pad(image, padding, mode="constant", constant_values=constant_values)
    elif mode == PaddingMode.REFLECT:
        image = np.pad(image, padding, mode="reflect")
    # ...
辅助函数
函数 用途
get_resize_output_image_size 计算缩放后的目标尺寸
get_size_with_aspect_ratio 保持宽高比计算缩放尺寸
select_best_resolution 从候选分辨率中选择最佳匹配
divide_to_patches 将图像分割为 patch
group_images_by_shape 按形状分组图像(用于批量化)
reorder_images 恢复图像的原始顺序
center_to_corners_format 边界框格式转换(中心→角点)
rgb_to_id / id_to_rgb 颜色与 ID 互转(全景分割)

6. 图像工具

源文件:[image_utils.py](file:///workspace/src/transformers/image_utils.py)

模块职责

提供图像处理系统的基础工具层,包括:

  1. 类型定义(ImageInputChannelDimensionImageType
  2. 图像验证和分类
  3. 图像加载(URL、本地路径、base64)
  4. 通道维度推断
  5. 图像列表结构规范化

核心类型定义

python 复制代码
# 图像输入的联合类型
ImageInput = Union[
    "PIL.Image.Image", np.ndarray, "torch.Tensor",
    list["PIL.Image.Image"], list[np.ndarray], list["torch.Tensor"]
]

# 通道维度枚举
class ChannelDimension(ExplicitEnum):
    FIRST = "channels_first"   # (C, H, W)
    LAST = "channels_last"     # (H, W, C)

# 图像类型枚举
class ImageType(ExplicitEnum):
    PIL = "pillow"
    TORCH = "torch"
    NUMPY = "numpy"

核心函数

infer_channel_dimension_format --- 推断通道维度
python 复制代码
def infer_channel_dimension_format(image, num_channels=None):
    """推断图像的通道维度位置(first 或 last)"""
    num_channels = num_channels if num_channels is not None else (1, 3)
    if image.ndim == 3:
        first_dim, last_dim = 0, 2
    elif image.ndim == 4:
        first_dim, last_dim = 1, 3
    # 检查第一个维度和最后一个维度哪个匹配通道数
    if image.shape[first_dim] in num_channels:
        return ChannelDimension.FIRST
    elif image.shape[last_dim] in num_channels:
        return ChannelDimension.LAST
make_list_of_images / make_flat_list_of_images --- 结构规范化
python 复制代码
def make_list_of_images(images, expected_ndims=3):
    """确保输出为图像列表:单图 → [图],批量 → [图1, 图2, ...]"""
    if is_batched(images):
        return images
    if is_pil_image(images):
        return [images]
    if is_valid_image(images):
        if images.ndim == expected_ndims + 1:
            images = list(images)  # 4D 批量 → 列表
        elif images.ndim == expected_ndims:
            images = [images]      # 3D 单图 → 列表

def make_flat_list_of_images(images, expected_ndims=3):
    """确保输出为扁平的图像列表:嵌套列表 → 扁平列表"""
    if isinstance(images, (list, tuple)) and all(isinstance(i, (list, tuple)) for i in images):
        return [img for img_list in images for img in img_list]
load_image --- 图像加载

支持从 URL、本地路径、base64 字符串加载图像,内部使用 httpx 获取远程资源,PIL 打开图像文件。


7. 特征提取系统

7.1 FeatureExtractionMixin(feature_extraction_utils.py)

源文件:[feature_extraction_utils.py](file:///workspace/src/transformers/feature_extraction_utils.py)

模块职责

提供特征提取器的基础设施 ,与 ImageProcessingMixin 结构类似:

  1. BatchFeature --- 所有处理器输出的统一容器
  2. FeatureExtractionMixin --- 保存/加载/序列化的混入类
BatchFeature --- 核心输出容器
python 复制代码
class BatchFeature(UserDict):
    """所有处理器 __call__ 方法的输出容器,本质是一个字典"""

    def __init__(self, data=None, tensor_type=None, skip_tensor_conversion=None):
        super().__init__(data)
        self.skip_tensor_conversion = skip_tensor_conversion
        self.convert_to_tensors(tensor_type=tensor_type)  # 自动转换张量类型

    def convert_to_tensors(self, tensor_type=None, skip_tensor_conversion=None):
        """将内部数据转换为指定类型的张量(PyTorch 或 NumPy)"""
        is_tensor, as_tensor = self._get_is_as_tensor_fns(tensor_type)
        for key, value in self.items():
            if skip_tensor_conversion and key in skip_tensor_conversion:
                continue
            if not _is_tensor_or_array_like(value):
                continue
            if not is_tensor(value):
                self[key] = as_tensor(value)

    def to(self, *args, **kwargs):
        """将所有张量移到指定设备(PyTorch only)"""
        # 只转换浮点张量的 dtype,其他张量只移设备
        def maybe_to(v):
            if isinstance(v, torch.Tensor) and torch.is_floating_point(v):
                return v.to(*args, **kwargs)
            elif isinstance(v, torch.Tensor) and device is not None:
                return v.to(device=device, non_blocking=non_blocking)
            elif isinstance(v, (list, tuple)):
                return type(v)(maybe_to(item) for item in v)
            else:
                return v
        self.data = {k: maybe_to(v) for k, v in self.items()}

7.2 SequenceFeatureExtractor(feature_extraction_sequence_utils.py)

源文件:[feature_extraction_sequence_utils.py](file:///workspace/src/transformers/feature_extraction_sequence_utils.py)

模块职责

序列型特征提取器(主要用于音频)提供:

  1. Padding/Truncation 策略
  2. 注意力掩码生成
  3. 批量化处理
python 复制代码
class SequenceFeatureExtractor(FeatureExtractionMixin):
    def __init__(self, feature_size, sampling_rate, padding_value, **kwargs):
        self.feature_size = feature_size
        self.sampling_rate = sampling_rate
        self.padding_value = padding_value
        self.padding_side = kwargs.pop("padding_side", "right")
        self.return_attention_mask = kwargs.pop("return_attention_mask", True)

    def pad(self, processed_features, padding=True, max_length=None,
            truncation=False, pad_to_multiple_of=None, ...):
        """对特征序列进行 padding,支持多种策略"""
        # 1. 将 list[dict] 转为 dict[list](兼容 DataLoader collate_fn)
        # 2. 统一转为 NumPy 数组
        # 3. 执行 truncation
        # 4. 根据 padding 策略执行 padding
        #   - LONGEST: pad 到批次中最长序列
        #   - MAX_LENGTH: pad 到 max_length
        # 5. 生成 attention_mask

8. 音频工具

源文件:[audio_utils.py](file:///workspace/src/transformers/audio_utils.py)

模块职责

提供音频处理的底层工具

  1. 音频加载(支持多种后端)
  2. 音频格式转换
  3. 频谱分析工具(Mel 滤波器、STFT 等)
  4. 音频验证和结构规范化

核心类型和函数

AudioInput 类型
python 复制代码
AudioInput = Union[np.ndarray, "torch.Tensor", Sequence[np.ndarray], Sequence["torch.Tensor"]]
load_audio --- 音频加载
python 复制代码
def load_audio(audio, sampling_rate=16000, timeout=None):
    """加载音频为 NumPy 数组,支持 URL 和本地路径"""
    if isinstance(audio, str):
        # 优先使用 torchcodec(更高效),回退到 librosa
        if is_torchcodec_available() and version.parse("0.3.0") <= TORCHCODEC_VERSION:
            audio = load_audio_torchcodec(audio, sampling_rate=sampling_rate, timeout=timeout)
        else:
            audio = load_audio_librosa(audio, sampling_rate=sampling_rate, timeout=timeout)
    elif not isinstance(audio, np.ndarray):
        raise TypeError(...)
    return audio

音频加载后端

后端 依赖 特点
torchcodec torchcodec ≥ 0.3.0 推荐,高效
librosa librosa + soxr 传统方案,广泛兼容
Mel 频谱工具
python 复制代码
def hertz_to_mel(freq, mel_scale="htk"):
    """Hz → Mel 频率转换,支持 htk/kaldi/slaney 三种尺度"""
    if mel_scale == "htk":
        return 2595.0 * np.log10(1.0 + (freq / 700.0))
    # ...

def mel_to_hertz(mels, mel_scale="htk"):
    """Mel → Hz 频率转换"""
    if mel_scale == "htk":
        return 700.0 * (np.power(10, mels / 2595.0) - 1.0)
    # ...

def _create_triangular_filter_bank(fft_freqs, filter_freqs):
    """创建三角滤波器组(用于 Mel 频谱图)"""
    filter_diff = np.diff(filter_freqs)
    slopes = np.expand_dims(filter_freqs, 0) - np.expand_dims(fft_freqs, 1)
    down_slopes = -slopes[:, :-2] / filter_diff[:-1]
    up_slopes = slopes[:, 2:] / filter_diff[1:]
    return np.maximum(np.zeros(1), np.minimum(down_slopes, up_slopes))

9. 视频处理系统

9.1 video_utils.py --- 视频工具

源文件:[video_utils.py](file:///workspace/src/transformers/video_utils.py)

模块职责

提供视频处理的底层工具

  1. 视频输入类型定义和验证
  2. 视频元数据(VideoMetadata
  3. 视频解码(支持 5 种后端)
  4. 帧采样策略
  5. 视频结构规范化
VideoInput 类型
python 复制代码
VideoInput = Union[
    list["PIL.Image.Image"],          # 帧列表
    np.ndarray,                        # 4D 数组 (T, H, W, C) 或 (T, C, H, W)
    "torch.Tensor",                    # 4D 张量
    list[np.ndarray], list["torch.Tensor"],
    list[list["PIL.Image.Image"]],     # 批量帧列表
    URL, list[URL], list[list[URL]],   # URL
    Path, list[Path], list[list[Path]], # 本地路径
]
VideoMetadata --- 视频元数据
python 复制代码
@dataclass
class VideoMetadata(Mapping):
    total_num_frames: int
    fps: float | None = None
    width: int | None = None
    height: int | None = None
    duration: float | None = None
    video_backend: str | None = None
    frames_indices: list[int] | None = None

    @property
    def timestamps(self) -> list[float]:
        """采样帧的时间戳(秒)"""
        return [frame_idx / self.fps for frame_idx in self.frames_indices]

    @property
    def sampled_fps(self) -> float:
        """采样后的实际 FPS"""
        return len(self.frames_indices) / self.total_num_frames * self.fps
视频解码后端
python 复制代码
VIDEO_DECODERS = {
    "decord": read_video_decord,       # Decord 后端
    "opencv": read_video_opencv,       # OpenCV 后端
    "pyav": read_video_pyav,           # PyAV 后端(默认)
    "torchvision": read_video_torchvision,  # Torchvision 后端(已弃用)
    "torchcodec": read_video_torchcodec,    # Torchcodec 后端(推荐)
}
load_video --- 视频加载
python 复制代码
def load_video(video, num_frames=None, fps=None, backend="pyav",
               sample_indices_fn=None, **kwargs):
    """加载视频为 NumPy 数组"""
    # 1. 创建采样函数(默认均匀采样)
    if sample_indices_fn is None:
        def sample_indices_fn_func(metadata, **fn_kwargs):
            return default_sample_indices_fn(metadata, num_frames=num_frames, fps=fps, **fn_kwargs)
        sample_indices_fn = sample_indices_fn_func

    # 2. 如果已经是数组/PIL帧,直接返回
    if not isinstance(video, str):
        return video, [None] * len(video)

    # 3. 处理 YouTube URL(需要 yt_dlp)
    # 4. 处理 HTTP URL(使用 httpx 下载)
    # 5. 使用指定后端解码
    video_decoder = VIDEO_DECODERS[backend]
    video, metadata = video_decoder(file_obj, sample_indices_fn, **kwargs)
    return video, metadata
帧采样策略
python 复制代码
def default_sample_indices_fn(metadata, num_frames=None, fps=None, **kwargs):
    """默认均匀采样策略"""
    total_num_frames = metadata.total_num_frames
    video_fps = metadata.fps

    # 如果指定了 fps,从 fps 计算 num_frames
    if num_frames is None and fps is not None:
        num_frames = int(total_num_frames / video_fps * fps)

    if num_frames is not None:
        indices = np.arange(0, total_num_frames, total_num_frames / num_frames, dtype=int)
    else:
        indices = np.arange(0, total_num_frames, dtype=int)
    return indices

9.2 video_processing_utils.py --- 视频处理器

源文件:[video_processing_utils.py](file:///workspace/src/transformers/video_processing_utils.py)

模块职责

提供视频处理器的完整实现,继承自 TorchvisionBackend,增加了:

  1. 视频解码和帧采样
  2. 视频元数据管理
  3. 视频专用的 RGB 转换
BaseVideoProcessor
python 复制代码
@requires(backends=("vision", "torchvision"))
class BaseVideoProcessor(TorchvisionBackend):
    model_input_names = ["pixel_values_videos"]  # 注意:与图像的 "pixel_values" 不同
    valid_kwargs = VideosKwargs

    def preprocess(self, videos, **kwargs):
        # 1. 参数验证
        validate_kwargs(...)
        validate_typed_dict(self.valid_kwargs, kwargs)

        # 2. 填充默认值
        for kwarg_name in self.valid_kwargs.__annotations__:
            kwargs.setdefault(kwarg_name, getattr(self, kwarg_name, None))

        # 3. 解码和帧采样
        sample_indices_fn = partial(self.sample_frames, **kwargs) if do_sample_frames else None
        videos, video_metadata = self._decode_and_sample_videos(
            videos, video_metadata=video_metadata,
            do_sample_frames=do_sample_frames, sample_indices_fn=sample_indices_fn,
        )

        # 4. 准备输入视频(转为 Tensor、调整通道维度)
        videos = self._prepare_input_videos(videos=videos, input_data_format=input_data_format, device=device)

        # 5. 标准化和验证 kwargs
        kwargs = self._standardize_kwargs(**kwargs)
        self._validate_preprocess_kwargs(**kwargs)

        # 6. 执行预处理(复用 TorchvisionBackend 的 _preprocess)
        preprocessed_videos = self._preprocess(videos=videos, **kwargs)
        if return_metadata:
            preprocessed_videos["video_metadata"] = video_metadata
        return preprocessed_videos

_decode_and_sample_videos --- 解码和采样

python 复制代码
def _decode_and_sample_videos(self, videos, video_metadata, do_sample_frames, sample_indices_fn):
    videos = make_batched_videos(videos)
    video_metadata = make_batched_metadata(videos, video_metadata=video_metadata)

    if is_valid_video(videos[0]) and do_sample_frames:
        # 已加载的视频数组,直接按索引采样帧
        sampled_videos = []
        for video, metadata in zip(videos, video_metadata):
            indices = sample_indices_fn(metadata=metadata)
            metadata.frames_indices = indices
            sampled_videos.append(video[indices])
        videos = sampled_videos
    elif not is_valid_video(videos[0]):
        # URL/路径,需要先解码
        if isinstance(videos[0], list):
            # 图像 URL 列表 → 逐帧加载并 stack
            videos = [torch.stack([self.process_image(image) for image in images])
                      for images in self.fetch_images(videos)]
        else:
            videos, video_metadata = self.fetch_videos(videos, sample_indices_fn=sample_indices_fn)
    return videos, video_metadata

10. 视觉工具

源文件:[vision_utils.py](file:///workspace/src/transformers/vision_utils.py)

模块职责

提供视觉编码器的预计算工具函数 ,主要用于解决 torch.compile / torch.export 追踪中的动态计算问题。这些函数计算视觉编码器所需的辅助张量(如累积序列长度、位置 ID、窗口注意力索引等),可以在模型编译前预计算。

核心函数

get_vision_cu_seqlens --- 累积序列长度
python 复制代码
def get_vision_cu_seqlens(grid_thw, kwargs=None):
    """从 grid_thw 计算累积序列长度,用于 Flash Attention 等变长序列操作"""
    if kwargs is not None and (cu_seqlens := kwargs.pop("cu_seqlens", None)) is not None:
        return cu_seqlens  # 如果已预计算,直接返回

    # grid_thw: (num_images, 3) --- 每个图像的 (temporal, height, width)
    # 计算每个图像的 patch 数 = h * w,按 temporal 维度重复后累加
    cu_seqlens = torch.repeat_interleave(grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]).cumsum(...)
    return F.pad(cu_seqlens, (1, 0), value=0)  # 前补 0
get_vision_position_ids --- 位置 ID
python 复制代码
def get_vision_position_ids(grid_thw, spatial_merge_size, kwargs=None):
    """计算视觉 Rotary Embedding 的 (row, col) 位置 ID"""
    # 对每个图像,生成 h×w 的网格位置,考虑 spatial merge
    for (t, h, w), merge_size in zip(grid_thw.tolist(), spatial_merge_size.tolist()):
        hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w)
        hpos_ids = hpos_ids.reshape(h // merge_size, merge_size, w // merge_size, merge_size)
        hpos_ids = hpos_ids.transpose(1, 2).flatten()
        # ... 同理计算 wpos_ids
        position_ids.append(torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1))
    return torch.cat(position_ids, dim=0)
get_vision_window_index --- 窗口注意力索引
python 复制代码
def get_vision_window_index(grid_thw, spatial_merge_size, window_size, patch_size, kwargs=None):
    """计算窗口注意力的重排序索引和窗口边界"""
    # 将图像划分为不重叠的窗口,生成窗口内的 token 重排索引
    vit_merger_window_size = window_size // spatial_merge_size // patch_size
    for grid_t, grid_h, grid_w in grid_thw.tolist():
        # 创建索引网格,padding 到窗口大小的整数倍
        index = torch.arange(grid_t * llm_grid_h * llm_grid_w).reshape(...)
        index_padded = F.pad(index, (0, pad_w, 0, pad_h), "constant", -100)
        # 重排为窗口格式
        index_padded = index_padded.permute(0, 1, 3, 2, 4).reshape(...)
    return window_index, cu_window_seqlens
get_vision_bilinear_indices_and_weights --- 双线性插值
python 复制代码
def get_vision_bilinear_indices_and_weights(grid_thw, num_grid_per_side, spatial_merge_size, kwargs=None):
    """计算双线性插值的索引和权重,用于位置编码的二维插值"""
    # 对每个图像,在 h 和 w 方向上生成插值网格
    for t, h, w in grid_thw.tolist():
        h_grid = torch.linspace(0, side - 1, h)
        w_grid = torch.linspace(0, side - 1, w)
        # 计算四个角的索引和权重
        corner_indices = [h_floor + w_floor, h_floor + w_ceil, h_ceil + w_floor, h_ceil + w_ceil]
        corner_weights = [(1-h_frac)*(1-w_frac), (1-h_frac)*w_frac, h_frac*(1-w_frac), h_frac*w_frac]
    return bilinear_indices, bilinear_weights

预计算模式 :所有 get_* 函数都接受 kwargs 参数,如果 kwargs 中已包含预计算结果,则直接返回而不重新计算。这使得视觉编码器可以在 torch.compile 之前预计算这些动态张量。


11. 模块间关系与数据流

完整数据流

以多模态输入 processor(text="描述", images=img, videos=vid, audio=aud) 为例:

复制代码
用户调用 processor(text, images, videos, audio, **kwargs)
    │
    ▼
ProcessorMixin.__call__()
    │
    ├── _merge_kwargs() ──── 合并所有来源的参数
    │   ├── ProcessingKwargs._defaults(最低优先级)
    │   ├── tokenizer.init_kwargs
    │   ├── common_kwargs
    │   └── 调用时传入的 kwargs(最高优先级)
    │
    ├── tokenizer(text, **text_kwargs) ──── 文本处理
    │   └── 返回 {"input_ids": ..., "attention_mask": ...}
    │
    ├── image_processor(images, **images_kwargs) ──── 图像处理
    │   ├── BaseImageProcessor.preprocess()
    │   │   ├── validate_typed_dict() ──── 类型验证
    │   │   ├── _standardize_kwargs() ──── 参数标准化
    │   │   └── _preprocess_image_like_inputs()
    │   │       ├── _prepare_image_like_inputs() ──── 逐图格式转换
    │   │       │   └── process_image() ──── [TorchvisionBackend/PilBackend]
    │   │       └── _preprocess() ──── 批量处理
    │   │           ├── group_images_by_shape() ──── 按形状分组
    │   │           ├── resize() ──── 缩放
    │   │           ├── center_crop() ──── 中心裁剪
    │   │           ├── rescale_and_normalize() ──── 融合缩放+标准化
    │   │           ├── pad() ──── 填充
    │   │           └── reorder_images() ──── 恢复原始顺序
    │   └── 返回 {"pixel_values": ...}
    │
    ├── video_processor(videos, **videos_kwargs) ──── 视频处理
    │   ├── BaseVideoProcessor.preprocess()
    │   │   ├── _decode_and_sample_videos() ──── 解码+帧采样
    │   │   │   ├── make_batched_videos()
    │   │   │   ├── sample_frames() / load_video()
    │   │   │   └── _prepare_input_videos()
    │   │   └── _preprocess() ──── 复用 TorchvisionBackend
    │   └── 返回 {"pixel_values_videos": ..., "video_metadata": ...}
    │
    ├── feature_extractor(audio, **audio_kwargs) ──── 音频处理
    │   ├── SequenceFeatureExtractor.__call__()
    │   │   ├── load_audio() ──── 加载音频
    │   │   └── pad() ──── padding + attention_mask
    │   └── 返回 {"input_values": ..., "attention_mask": ...}
    │
    └── BatchFeature({
        "input_ids": ...,          ← 来自 tokenizer
        "attention_mask": ...,     ← 来自 tokenizer
        "pixel_values": ...,       ← 来自 image_processor
        "pixel_values_videos": ...,← 来自 video_processor
        "input_values": ...,       ← 来自 feature_extractor
    })

模块依赖关系图

复制代码
processing_utils.py
    ├── image_utils.py (ImageInput, ChannelDimension)
    ├── video_utils.py (VideoInput, VideoMetadata)
    ├── audio_utils.py (AudioInput, load_audio)
    ├── feature_extraction_utils.py (BatchFeature)
    ├── tokenization_utils_base.py (PaddingStrategy, TruncationStrategy)
    └── utils/type_validators.py (验证器)

image_processing_utils.py (BaseImageProcessor)
    ├── image_processing_base.py (ImageProcessingMixin, BatchFeature)
    ├── image_transforms.py (resize, normalize, rescale, center_crop, pad)
    ├── image_utils.py (ImageInput, ChannelDimension, SizeDict)
    └── processing_utils.py (ImagesKwargs)

image_processing_backends.py (TorchvisionBackend, PilBackend)
    ├── image_processing_utils.py (BaseImageProcessor)
    ├── image_transforms.py (各种变换函数)
    └── image_utils.py (ImageInput, ImageType)

video_processing_utils.py (BaseVideoProcessor)
    ├── image_processing_backends.py (TorchvisionBackend)
    ├── video_utils.py (VideoInput, VideoMetadata, load_video)
    └── processing_utils.py (VideosKwargs)

vision_utils.py (独立模块,仅依赖 torch)
    └── 被视觉编码器模型代码调用

feature_extraction_utils.py (BatchFeature, FeatureExtractionMixin)
    └── 独立基础设施

feature_extraction_sequence_utils.py (SequenceFeatureExtractor)
    ├── feature_extraction_utils.py (BatchFeature, FeatureExtractionMixin)
    └── audio_utils.py (load_audio, is_valid_audio)

audio_utils.py (独立工具模块)
    └── 依赖 soundfile, librosa, torchcodec(可选)

关键设计模式总结

  1. Mixin 模式ProcessorMixinImageProcessingMixinFeatureExtractionMixin 都继承自 PushToHubMixin,通过混入提供保存/加载能力。

  2. 策略模式:图像处理后端(Torchvision/PIL)通过继承选择,而非运行时切换,确保类型安全和性能优化。

  3. TypedDict 驱动的参数系统 :使用 Python 的 TypedDict + Annotated 实现类型安全的参数传递和验证,替代了传统的松散 **kwargs

  4. 优先级合并_merge_kwargs 实现了四级优先级的参数合并,确保用户调用时的参数具有最高优先级,同时支持模型级别的默认值。

  5. 批量化优化 :Torchvision 后端通过 group_images_by_shape + reorder_images 实现自动批量化,相同形状的图像 stack 后批量处理,不同形状的分组处理。

  6. 融合优化rescale_and_normalize 将两步操作融合为一步,减少 GPU 内存带宽占用。

  7. 预计算模式vision_utils.py 中的函数支持从 kwargs 中读取预计算结果,解决 torch.compile 追踪动态计算的问题。

相关推荐
tedcloud1239 分钟前
RTK部署教程:构建稳定的AI Workflow环境
服务器·javascript·人工智能·typescript·ocr
Raink老师11 分钟前
【AI面试临阵磨枪-71】如何用 AI 优化推荐系统、内容审核、广告创意、搜索体验?
人工智能·面试·职场和发展
AI医影跨模态组学13 分钟前
Biomarker Res(IF=11.5)安徽医科大学第一医院:基于机器学习的放射组学模型:子宫内膜癌患者的预后预测及机制探索
人工智能·深度学习·论文·医学·医学影像·影像组学
ftpeak19 分钟前
Mooncake:以 KVCache 为中心的分离式 LLM 服务架构
人工智能·ai·架构·ai编程·ai开发
Terrence Shen24 分钟前
Hermes agent的tools是怎么落地应用的系列
人工智能·llm·agent·hermes
Raink老师39 分钟前
【AI面试临阵磨枪-72】电商全场景 AI Agent 设计(商品咨询 / 订单 / 物流 / 售后 / 退款)
人工智能·面试·职场和发展
仙女修炼史1 小时前
CNN更看重Texture还是shape:imagenet-trained cnns are biased
论文阅读·人工智能·cnn
视***间1 小时前
视程空间 AIR SC6N0-C-MB NX 16GB 规格详解与机器人/机器狗适配说明
人工智能·机器人·边缘计算·机器狗·ai算力·具身机器人·视程空间
视***间1 小时前
小身板・强算力・全适配 —— 视程空间 AI 算力开发板如何完美适配机器人 / 机器狗
人工智能·机器人·边缘计算·ai算力·视程空间·算力开发板
网宿安全演武实验室1 小时前
当AI跑进容器:全链路容器安全检测与智能运营实
人工智能·安全·容器·k8s