16-Hugging Face Transformers之测试体系架构总览

Transformers 测试体系

相关文章:

Hugging Face Transformers 源码全景解读

01-Hugging Face Transformers 核心基础设施深度分析

02-Hugging Face Transformers 配置系统深度分析

03-Hugging Face Transformers 模型系统深度分析

04-Hugging Face Transformers 注意力与掩码系统深度分析

05-Hugging Face Transformers 缓存系统深度分析

06-Hugging Face Transformers 生成系统深度分析

07-Hugging Face Transformers 分词器系统深度分析

08-Hugging Face Transformers 多模态处理系统深度分析

09-Hugging Face Transformers 训练系统深度分析

10-Hugging Face Transformers 量化系统深度分析

11-Hugging Face Transformers 分布式与并行系统深度分析

12-Hugging Face Transformers之Pipeline 推理管道深入分析

13-Hugging Face Transformers之AutoModel 自动分发机制深入分析

14-Hugging Face Transformers 模型实现模式深度分析

15-Hugging Face Transformers之CLI 与工具架构总览

16-Hugging Face Transformers之测试体系架构总览

17-Hugging Face Transformers之BERT 案例详解:Transformers 框架全模块串联

18-Hugging Face Transformers之GPT-2 案例详解:Decoder-only 自回归模型的完整生命周期

19-Hugging Face Transformers之Qwen3.5-MoE 系列详解:混合专家 + 线性注意力 + 多模态的完整生命周期

测试体系架构总览

#mermaid-svg-Y58dQd9q5d61P08x{font-family:"trebuchet ms",verdana,arial,sans-serif;font-size:16px;fill:#333;}@keyframes edge-animation-frame{from{stroke-dashoffset:0;}}@keyframes dash{to{stroke-dashoffset:0;}}#mermaid-svg-Y58dQd9q5d61P08x .edge-animation-slow{stroke-dasharray:9,5!important;stroke-dashoffset:900;animation:dash 50s linear infinite;stroke-linecap:round;}#mermaid-svg-Y58dQd9q5d61P08x .edge-animation-fast{stroke-dasharray:9,5!important;stroke-dashoffset:900;animation:dash 20s linear infinite;stroke-linecap:round;}#mermaid-svg-Y58dQd9q5d61P08x .error-icon{fill:#552222;}#mermaid-svg-Y58dQd9q5d61P08x .error-text{fill:#552222;stroke:#552222;}#mermaid-svg-Y58dQd9q5d61P08x .edge-thickness-normal{stroke-width:1px;}#mermaid-svg-Y58dQd9q5d61P08x .edge-thickness-thick{stroke-width:3.5px;}#mermaid-svg-Y58dQd9q5d61P08x .edge-pattern-solid{stroke-dasharray:0;}#mermaid-svg-Y58dQd9q5d61P08x .edge-thickness-invisible{stroke-width:0;fill:none;}#mermaid-svg-Y58dQd9q5d61P08x .edge-pattern-dashed{stroke-dasharray:3;}#mermaid-svg-Y58dQd9q5d61P08x .edge-pattern-dotted{stroke-dasharray:2;}#mermaid-svg-Y58dQd9q5d61P08x .marker{fill:#333333;stroke:#333333;}#mermaid-svg-Y58dQd9q5d61P08x .marker.cross{stroke:#333333;}#mermaid-svg-Y58dQd9q5d61P08x svg{font-family:"trebuchet ms",verdana,arial,sans-serif;font-size:16px;}#mermaid-svg-Y58dQd9q5d61P08x p{margin:0;}#mermaid-svg-Y58dQd9q5d61P08x .label{font-family:"trebuchet ms",verdana,arial,sans-serif;color:#333;}#mermaid-svg-Y58dQd9q5d61P08x .cluster-label text{fill:#333;}#mermaid-svg-Y58dQd9q5d61P08x .cluster-label span{color:#333;}#mermaid-svg-Y58dQd9q5d61P08x .cluster-label span p{background-color:transparent;}#mermaid-svg-Y58dQd9q5d61P08x .label text,#mermaid-svg-Y58dQd9q5d61P08x span{fill:#333;color:#333;}#mermaid-svg-Y58dQd9q5d61P08x .node rect,#mermaid-svg-Y58dQd9q5d61P08x .node circle,#mermaid-svg-Y58dQd9q5d61P08x .node ellipse,#mermaid-svg-Y58dQd9q5d61P08x .node polygon,#mermaid-svg-Y58dQd9q5d61P08x .node path{fill:#ECECFF;stroke:#9370DB;stroke-width:1px;}#mermaid-svg-Y58dQd9q5d61P08x .rough-node .label text,#mermaid-svg-Y58dQd9q5d61P08x .node .label text,#mermaid-svg-Y58dQd9q5d61P08x .image-shape .label,#mermaid-svg-Y58dQd9q5d61P08x .icon-shape .label{text-anchor:middle;}#mermaid-svg-Y58dQd9q5d61P08x .node .katex path{fill:#000;stroke:#000;stroke-width:1px;}#mermaid-svg-Y58dQd9q5d61P08x .rough-node .label,#mermaid-svg-Y58dQd9q5d61P08x .node .label,#mermaid-svg-Y58dQd9q5d61P08x .image-shape .label,#mermaid-svg-Y58dQd9q5d61P08x .icon-shape .label{text-align:center;}#mermaid-svg-Y58dQd9q5d61P08x .node.clickable{cursor:pointer;}#mermaid-svg-Y58dQd9q5d61P08x .root .anchor path{fill:#333333!important;stroke-width:0;stroke:#333333;}#mermaid-svg-Y58dQd9q5d61P08x .arrowheadPath{fill:#333333;}#mermaid-svg-Y58dQd9q5d61P08x .edgePath .path{stroke:#333333;stroke-width:2.0px;}#mermaid-svg-Y58dQd9q5d61P08x .flowchart-link{stroke:#333333;fill:none;}#mermaid-svg-Y58dQd9q5d61P08x .edgeLabel{background-color:rgba(232,232,232, 0.8);text-align:center;}#mermaid-svg-Y58dQd9q5d61P08x .edgeLabel p{background-color:rgba(232,232,232, 0.8);}#mermaid-svg-Y58dQd9q5d61P08x .edgeLabel rect{opacity:0.5;background-color:rgba(232,232,232, 0.8);fill:rgba(232,232,232, 0.8);}#mermaid-svg-Y58dQd9q5d61P08x .labelBkg{background-color:rgba(232, 232, 232, 0.5);}#mermaid-svg-Y58dQd9q5d61P08x .cluster rect{fill:#ffffde;stroke:#aaaa33;stroke-width:1px;}#mermaid-svg-Y58dQd9q5d61P08x .cluster text{fill:#333;}#mermaid-svg-Y58dQd9q5d61P08x .cluster span{color:#333;}#mermaid-svg-Y58dQd9q5d61P08x div.mermaidTooltip{position:absolute;text-align:center;max-width:200px;padding:2px;font-family:"trebuchet ms",verdana,arial,sans-serif;font-size:12px;background:hsl(80, 100%, 96.2745098039%);border:1px solid #aaaa33;border-radius:2px;pointer-events:none;z-index:100;}#mermaid-svg-Y58dQd9q5d61P08x .flowchartTitleText{text-anchor:middle;font-size:18px;fill:#333;}#mermaid-svg-Y58dQd9q5d61P08x rect.text{fill:none;stroke-width:0;}#mermaid-svg-Y58dQd9q5d61P08x .icon-shape,#mermaid-svg-Y58dQd9q5d61P08x .image-shape{background-color:rgba(232,232,232, 0.8);text-align:center;}#mermaid-svg-Y58dQd9q5d61P08x .icon-shape p,#mermaid-svg-Y58dQd9q5d61P08x .image-shape p{background-color:rgba(232,232,232, 0.8);padding:2px;}#mermaid-svg-Y58dQd9q5d61P08x .icon-shape .label rect,#mermaid-svg-Y58dQd9q5d61P08x .image-shape .label rect{opacity:0.5;background-color:rgba(232,232,232, 0.8);fill:rgba(232,232,232, 0.8);}#mermaid-svg-Y58dQd9q5d61P08x .label-icon{display:inline-block;height:1em;overflow:visible;vertical-align:-0.125em;}#mermaid-svg-Y58dQd9q5d61P08x .node .label-icon path{fill:currentColor;stroke:revert;stroke-width:revert;}#mermaid-svg-Y58dQd9q5d61P08x :root{--mermaid-font-family:"trebuchet ms",verdana,arial,sans-serif;} 专项测试层
模型特定测试层
Mixin 测试层
pytest 配置层
conftest.py

全局配置/钩子
pytest_configure

注册标记
ModelTesterMixin

模型通用测试
ConfigTester

配置通用测试
TokenizerTesterMixin

分词器通用测试
PipelineTestMixin

Pipeline 测试
TrainingTestMixin

训练测试
TensorParallelTestMixin

张量并行测试
test_modeling_{name}.py
test_configuration_{name}.py
test_tokenization_{name}.py
generation/

生成测试
trainer/

训练器测试
pipelines/

Pipeline 测试
cli/

CLI 测试

一、测试体系总览

Transformers 的测试体系是一个多层级、高度参数化的架构,核心设计理念是通过 Mixin 类实现测试逻辑复用,每个模型只需提供配置和输入数据,即可自动继承数百个通用测试用例。

测试目录结构

复制代码
tests/
├── conftest.py                    # pytest 全局配置
├── test_modeling_common.py        # 模型通用测试 Mixin(ModelTesterMixin)
├── test_configuration_common.py   # 配置通用测试(ConfigTester)
├── test_tokenization_common.py    # 分词器通用测试 Mixin(TokenizerTesterMixin)
├── test_pipeline_mixin.py         # Pipeline 测试 Mixin
├── test_training_mixin.py         # 训练测试 Mixin
├── test_tensor_parallel_mixin.py  # 张量并行测试 Mixin
├── models/                        # 各模型特定测试
│   └── {model_name}/
│       ├── test_modeling_{name}.py
│       ├── test_configuration_{name}.py
│       └── test_tokenization_{name}.py
├── utils/                         # 工具类测试
├── generation/                    # 生成相关测试
├── trainer/                       # 训练器测试
├── pipelines/                     # Pipeline 测试
├── cli/                           # CLI 测试
│   ├── conftest.py
│   ├── test_chat.py
│   ├── test_download.py
│   ├── test_serve.py
│   └── test_system.py
└── fixtures/                      # 测试固件文件
    ├── sample_text.txt
    ├── vocab.txt / vocab.json
    └── ...

二、pytest 全局配置 --- conftest.py

2.1 核心配置

python 复制代码
# conftest.py --- tests 目录级配置,pytest 自动加载

import doctest
import _pytest
import pytest
from transformers.testing_utils import (
    HfDoctestModule, HfDocTestParser,
    is_torch_available, patch_testing_methods_to_collect_info,
    patch_torch_compile_force_graph,
)
from transformers.utils import enable_tf32

# 非 GPU 测试列表 --- 这些测试始终在 CPU 上运行
NOT_DEVICE_TESTS = {
    "test_tokenization", "test_configuration_utils",
    "test_data_collator", "test_optimization",
    "test_forward_signature", "test_model_get_set_embeddings",
    "ModelTest::test_pipeline_", "ModelTester::test_pipeline_",
    "/repo_utils/", "/utils/",
}

# 确保 src 目录在 sys.path 中(支持多仓库检出)
git_repo_path = abspath(join(dirname(__file__), "src"))
sys.path.insert(1, git_repo_path)

# 忽略 FutureWarning(测试中无法立即处理的弃用警告)
warnings.simplefilter(action="ignore", category=FutureWarning)

2.2 pytest 钩子函数

python 复制代码
def pytest_configure(config):
    """注册自定义标记"""
    config.addinivalue_line("markers", "is_pipeline_test: mark test to run only when pipelines are tested")
    config.addinivalue_line("markers", "is_staging_test: mark test to run only in the staging environment")
    config.addinivalue_line("markers", "accelerate_tests: mark test that require accelerate")
    config.addinivalue_line("markers", "not_device_test: mark the tests always running on cpu")
    config.addinivalue_line("markers", "torch_compile_test: mark test which tests torch compile")
    config.addinivalue_line("markers", "flash_attn_test: mark test which tests flash attention")
    config.addinivalue_line("markers", "training_ci: mark test for training CI validation")
    config.addinivalue_line("markers", "tensor_parallel_ci: mark test for tensor parallel CI validation")
    os.environ["DISABLE_SAFETENSORS_CONVERSION"] = "true"

def pytest_collection_modifyitems(items):
    """自动为 NOT_DEVICE_TESTS 中的测试添加 not_device_test 标记"""
    for item in items:
        if any(test_name in item.nodeid for test_name in NOT_DEVICE_TESTS):
            item.add_marker(pytest.mark.not_device_test)

def pytest_sessionfinish(session, exitstatus):
    """无测试收集时退出码 5 → 0(避免 CI 失败)"""
    if exitstatus == 5:
        session.exitstatus = 0

2.3 Doctest 自定义

python 复制代码
# 注册自定义 doctest 标志 IGNORE_RESULT
IGNORE_RESULT = doctest.register_optionflag("IGNORE_RESULT")

class CustomOutputChecker(OutputChecker):
    def check_output(self, want, got, optionflags):
        if IGNORE_RESULT & optionflags:
            return True  # 忽略输出比较
        return OutputChecker.check_output(self, want, got, optionflags)

# 替换 pytest 和 doctest 的默认实现
doctest.OutputChecker = CustomOutputChecker
_pytest.doctest.DoctestModule = HfDoctestModule
doctest.DocTestParser = HfDocTestParser

2.4 PyTorch 环境配置

python 复制代码
if is_torch_available():
    enable_tf32(False)  # CI 中禁用 TF32 以确保数值精度
    # 设置 cuDNN 卷积精度为 IEEE
    if hasattr(torch.backends.cudnn, "conv"):
        torch.backends.cudnn.conv.fp32_precision = "ieee"
    # 补丁 torch.compile:支持 TORCH_COMPILE_FORCE_FULLGRAPH 环境变量
    patch_torch_compile_force_graph()

三、通用模型测试 --- test_modeling_common.py

3.1 ModelTesterMixin --- 模型测试核心 Mixin

这是整个测试体系最重要的类,所有模型的测试类都继承自它:

python 复制代码
class ModelTesterMixin:
    model_tester = None              # 提供配置和输入数据的测试器
    all_model_classes = ()           # 该模型架构的所有模型类
    test_resize_embeddings = True
    test_resize_position_embeddings = False
    test_mismatched_shapes = True
    test_missing_keys = True
    test_torch_exportable = True
    is_encoder_decoder = False
    has_attentions = True
    _is_composite = False
    model_split_percents = [0.5, 0.7, 0.9]

    def __init_subclass__(cls, **kwargs):
        """自动为所有 test_ 方法添加 hub_retry 装饰器"""
        super().__init_subclass__(**kwargs)
        for attr_name in dir(cls):
            if attr_name.startswith("test_"):
                attr = getattr(cls, attr_name)
                if callable(attr):
                    setattr(cls, attr_name, hub_retry()(attr))

    @property
    def all_generative_model_classes(self):
        """过滤出支持生成的模型类"""
        return tuple(mc for mc in self.all_model_classes if mc.can_generate())

__init_subclass__ 的设计意义:

  • 所有继承 ModelTesterMixin 的测试类,其 test_ 方法自动获得 Hub 重试能力
  • 解决 CI 中 Hub 连接不稳定导致的测试失败

3.2 _prepare_for_class --- 输入数据适配

python 复制代码
def _prepare_for_class(self, inputs_dict, model_class, return_labels=False):
    """根据模型类类型调整输入数据格式"""
    inputs_dict = copy.deepcopy(inputs_dict)
    # 多选模型:增加 num_choices 维度
    if model_class.__name__ in get_values(MODEL_FOR_MULTIPLE_CHOICE_MAPPING_NAMES):
        inputs_dict = {k: v.unsqueeze(1).expand(-1, self.model_tester.num_choices, -1).contiguous()
                       if isinstance(v, torch.Tensor) and v.ndim > 1 else v
                       for k, v in inputs_dict.items()}
    # 如果需要标签,根据模型任务类型生成对应标签
    if return_labels:
        if model_class.__name__ in get_values(MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES):
            inputs_dict["labels"] = torch.zeros(batch_size, dtype=torch.long, device=torch_device)
        elif model_class.__name__ in get_values(MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES):
            inputs_dict["labels"] = torch.zeros((batch_size, seq_length), dtype=torch.long, device=torch_device)
        # ... 更多任务类型
    return inputs_dict

3.3 关键测试方法

ModelTesterMixin 包含大量通用测试,覆盖模型的核心功能:

测试方法 功能
test_model_get_set_embeddings 测试嵌入层的获取和设置
test_forward_signature 测试 forward 方法签名
test_torch_save_load 测试模型的保存和加载
test_tied_weights_keys 测试权重共享
test_can_use_safetensors 测试 safetensors 格式
test_resize_embeddings 测试嵌入层大小调整
test_model_is_small 确保测试模型足够小
test_eager_matches_sdpa_inference Eager vs SDPA 注意力一致性
test_eager_matches_batched_and_grouped_inference MoE 批处理/分组推理一致性

3.4 Eager vs SDPA 一致性测试

这是最复杂也最重要的测试之一,确保 SDPA(Scaled Dot Product Attention)与 Eager 实现的数值一致性:

python 复制代码
TEST_EAGER_MATCHES_SDPA_INFERENCE_PARAMETERIZATION = [
    (f"{dtype}_pad_{padding_side}{'' if use_attention_mask else '_no_attn_mask'}"
     f"{'_sdpa_kernels' if enable_kernels else ''}",
     *(dtype, padding_side, use_attention_mask, False, enable_kernels))
    for dtype in ("fp16", "fp32", "bf16")
    for padding_side in ("left", "right")
    for use_attention_mask in (True, False)
    for enable_kernels in (True, False)
] + [("fp32_pad_left_output_attentions", "fp32", "left", True, True, False)]

def _test_eager_matches_sdpa_inference(self, name, dtype, padding_side,
                                        use_attention_mask, output_attentions, enable_kernels):
    """对比 Eager 和 SDPA 注意力实现的输出一致性"""
    for model_class in self.all_model_classes:
        set_seed(42)
        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

        # 分别加载 Eager 和 SDPA 模型
        model_sdpa = model_class.from_pretrained(tmpdirname, attn_implementation="sdpa")
        model_eager = model_class.from_pretrained(tmpdirname, attn_implementation="eager")

        # 构造不同 padding 方式的输入
        # ...

        # 对比输出
        results = [torch.allclose(s, e, atol=atol, rtol=rtol) for s, e in zip(logits_sdpa, logits_eager)]
        if np.mean(results) < 0.8:  # 80% 的 batch 元素匹配即可
            raise ValueError(f"mean relative difference for {key}: {mean_relative_diff:.3e}")

容差设计:

  • 不同 dtype/device/内核组合有不同的 atol/rtol 容差
  • bf16 容差更大(1e-2),因为高量级输出时精度下降
  • 80% 匹配率阈值:避免个别异常 token 导致测试失败

3.5 MoE 推理一致性测试

python 复制代码
def _test_eager_matches_batched_and_grouped_inference(self, name, dtype):
    """对比 Eager/Batched_mm/Grouped_mm 三种 MoE 专家计算实现"""
    implementations = ["eager", "batched_mm", "grouped_mm"]
    mocks = {
        "batched_mm": Mock(wraps=batched_mm_experts_forward),
        "grouped_mm": Mock(wraps=grouped_mm_experts_forward),
    }
    # 对比三种实现的输出

四、通用配置测试 --- test_configuration_common.py

4.1 ConfigTester

python 复制代码
class ConfigTester:
    def __init__(self, parent, config_class=None, has_text_modality=True, common_properties=None, **kwargs):
        self.parent = parent              # 测试用例引用(用于断言)
        self.config_class = config_class  # 要测试的配置类
        self.has_text_modality = has_text_modality
        self.inputs_dict = kwargs         # 配置参数
        self.common_properties = common_properties

4.2 测试方法

python 复制代码
def run_common_tests(self):
    """运行所有通用配置测试"""
    self.create_and_test_config_common_properties()        # 验证通用属性存在
    self.create_and_test_config_to_json_string()           # JSON 序列化/反序列化
    self.create_and_test_config_to_json_file()             # JSON 文件保存/加载
    self.create_and_test_config_from_and_save_pretrained() # save_pretrained/from_pretrained
    self.create_and_test_config_from_and_save_pretrained_subfolder()  # 子文件夹支持
    self.create_and_test_config_from_and_save_pretrained_composite()  # 复合配置(VLM等)
    self.create_and_test_config_with_num_labels()          # num_labels 标签映射
    self.check_config_can_be_init_without_params()         # 无参数初始化
    self.check_config_arguments_init()                     # 通用参数初始化
    self.create_and_test_config_from_pretrained_custom_kwargs()  # 自定义 kwargs 覆盖

复合配置测试 --- 针对视觉语言模型等复合架构:

python 复制代码
def create_and_test_config_from_and_save_pretrained_composite(self):
    """测试复合/嵌套配置的加载和保存"""
    config = self.config_class(**self.inputs_dict)
    with tempfile.TemporaryDirectory() as tmpdirname:
        config.save_pretrained(tmpdirname)
        general_config_loaded = self.config_class.from_pretrained(tmpdirname)
        # 遍历所有子配置,验证独立加载与从整体加载一致
        for sub_config_key, sub_class in general_config_loaded.sub_configs.items():
            sub_config_loaded = sub_class.from_pretrained(tmpdirname)
            self.parent.assertEqual(sub_config_loaded.to_dict(), general_config_dict[sub_config_key])

五、通用分词器测试 --- test_tokenization_common.py

5.1 TokenizerTesterMixin --- 分词器测试核心

python 复制代码
class TokenizerTesterMixin:
    tokenizer_class = None                    # 要测试的分词器类
    space_between_special_tokens = False
    from_pretrained_kwargs = None
    from_pretrained_filter = None
    from_pretrained_id = None
    test_seq2seq = True
    test_tokenizer_from_extractor = True
    test_sentencepiece = False
    test_sentencepiece_ignore_case = False

    integration_test_input_string = """This is a test 😊
I was born in 92000, and this is falsé.
生活的真谛是
Hi  Hello
..."""  # 综合测试字符串,覆盖多语言、emoji、特殊字符

    @classmethod
    def setUpClass(cls):
        """加载预训练分词器到临时目录"""
        cls.tmpdirname = tempfile.mkdtemp()
        if cls.from_pretrained_id and cls.tokenizer_class is not None:
            tokenizer = AutoTokenizer.from_pretrained(cls.from_pretrained_id[0], ...)
            tokenizer.save_pretrained(cls.tmpdirname)

5.2 TokenizersExtractor --- 从 tokenizer.json 提取词表

python 复制代码
class TokenizersExtractor:
    """从 tokenizer.json 文件提取词表、合并规则和特殊标记"""
    def extract(self) -> tuple[dict, list, list, list]:
        # 支持多种词表格式:
        # - dict-based: BPE/WordPiece/WordLevel → {token: id}
        # - list-based: Unigram → [[token, score], ...]
        # 提取 merges(BPE 合并规则)
        # 提取 added_tokens(特殊标记)
        return vocab_ids, vocab_scores, merges, added_tokens_decoder

5.3 关键测试方法

TokenizerTesterMixin 提供了全面的分词器测试:

测试类别 方法示例
基本编解码 test_tokenize, test_encode, test_decode
特殊标记 test_special_tokens, test_added_tokens
填充/截断 test_padding, test_truncation
批处理 test_batch_encoding, test_batch_decode
保存/加载 test_save_pretrained, test_from_pretrained
子词采样 check_subword_sampling
集成测试 tokenizer_integration_test_util
提取器测试 test_tokenizer_from_extractor

5.4 集成测试工具

python 复制代码
def tokenizer_integration_test_util(self, expected_encoding, model_name, revision=None,
                                     sequences=None, decode_kwargs=None, padding=True):
    """分词器集成测试:编码→解码→验证一致性"""
    tokenizer = tokenizer_class.from_pretrained(model_name, revision=revision)
    encoding = tokenizer(sequences, padding=padding)
    decoded_sequences = [tokenizer.decode(seq, skip_special_tokens=True, **decode_kwargs)
                         for seq in encoding["input_ids"]]
    # 验证编码结果与预期一致
    self.assertDictEqual(encoding.data, expected_encoding)
    # 验证解码后与原文一致
    for expected, decoded in zip(sequences, decoded_sequences):
        self.assertEqual(expected, decoded)

5.5 辅助工具函数

python 复制代码
def use_cache_if_possible(func):
    """装饰器:缓存分词器结果以加速测试,但深拷贝避免状态污染"""
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        cached = func(*args, **kwargs)
        copied = copy.deepcopy(cached)
        # Rust tokenizer对象无法深拷贝,从原始对象恢复
        if hasattr(cached, "_tokenizer"):
            copied._tokenizer = cached._tokenizer
        return copied
    return wrapper

def check_subword_sampling(tokenizer, text=None):
    """验证子词正则化是否产生不同结果"""
    tokens_list = [tokenizer.tokenize(text) for _ in range(5)]
    # 确认至少有一对不同的分词结果
    subword_sampling_found = any(a != b for a, b in itertools.combinations(tokens_list, 2))

六、测试工具库 --- testing_utils.py

6.1 环境控制装饰器

Transformers 定义了大量条件跳过装饰器,根据运行环境自动跳过不满足条件的测试:

python 复制代码
# 环境变量控制
_run_slow_tests = parse_flag_from_env("RUN_SLOW", default=False)
_run_flaky_tests = parse_flag_from_env("RUN_FLAKY", default=True)
_run_pipeline_tests = parse_flag_from_env("RUN_PIPELINE_TESTS", default=True)
_run_training_tests = parse_flag_from_env("RUN_TRAINING_TESTS", default=True)
_run_tensor_parallel_tests = parse_flag_from_env("RUN_TENSOR_PARALLEL_TESTS", default=True)

# 速度分类装饰器
def slow(test_case):
    """标记慢速测试,默认跳过,需设置 RUN_SLOW=1 运行"""
    return unittest.skipUnless(_run_slow_tests, "test is slow")(test_case)

def tooslow(test_case):
    """标记过慢测试,始终跳过(修复前不应留在 CI 中)"""
    return unittest.skip(reason="test is too slow")(test_case)

# 依赖检查装饰器(50+ 个)
def require_torch(test_case): ...
def require_accelerate(test_case, min_version=...): ...
def require_flash_attn(test_case): ...
def require_bitsandbytes(test_case): ...
def require_deepspeed(test_case): ...
def require_torch_gpu(test_case): ...
def require_torch_multi_gpu(test_case): ...
# ... 等等

6.2 TestCasePlus --- 增强测试基类

python 复制代码
class TestCasePlus(unittest.TestCase):
    """扩展 unittest.TestCase,提供:
    1. 路径访问器(自动解析仓库根目录等)
    2. 自动清理的临时目录
    3. 环境变量管理
    """

    def setUp(self):
        self.teardown_tmp_dirs = []
        # 自动解析路径
        self._test_file_path = inspect.getfile(self.__class__)
        self._repo_root_dir = ...  # 向上查找包含 src/ 和 tests/ 的目录
        self._tests_dir = self._repo_root_dir / "tests"
        self._src_dir = self._repo_root_dir / "src"

    def get_auto_remove_tmp_dir(self, tmp_dir=None, before=True, after=True):
        """创建自动清理的临时目录"""
        if tmp_dir is None:
            tmp_dir = tempfile.mkdtemp()
        if after:
            self.teardown_tmp_dirs.append(tmp_dir)
        return tmp_dir

6.3 设备管理

python 复制代码
# 根据环境自动选择测试设备
if is_torch_available():
    if torch.cuda.is_available():
        torch_device = "cuda"
    elif is_torch_xpu_available():
        torch_device = "xpu"
    elif is_torch_npu_available():
        torch_device = "npu"
    elif is_torch_hpu_available():
        torch_device = "hpu"
    else:
        torch_device = "cpu"

6.4 Hub 测试工具

python 复制代码
class TemporaryHubRepo:
    """在 HuggingFace Hub 上创建临时仓库用于测试,测试结束后自动删除"""
    def __init__(self, namespace, repo_name, token=None):
        self.repo_id = f"{namespace}/{repo_name}"
        create_repo(self.repo_id, token=self.token, repo_type="model", exist_ok=True)

    def __enter__(self):
        return self

    def __exit__(self, *args):
        delete_repo(self.repo_id, token=self.token, repo_type="model")

def hub_retry(max_retries=5, base_wait=30):
    """Hub 请求重试装饰器,处理网络不稳定"""
    def decorator(test_func):
        @functools.wraps(test_func)
        def wrapper(*args, **kwargs):
            for attempt in range(max_retries):
                try:
                    return test_func(*args, **kwargs)
                except (httpx.ConnectError, httpx.TimeoutException) as e:
                    if attempt == max_retries - 1:
                        raise
                    wait_time = base_wait * (2 ** attempt)
                    time.sleep(wait_time)
        return wrapper
    return decorator

6.5 其他工具

python 复制代码
class CaptureLogger:
    """捕获日志输出用于断言"""
    def __init__(self, logger):
        self.logger = logger
        self.io = StringIO()
    def __enter__(self):
        self.handler = logging.StreamHandler(self.io)
        self.logger.addHandler(self.handler)
        return self
    def __exit__(self, *args):
        self.logger.removeHandler(self.handler)

class Expectations(UserDict):
    """设备特定的测试期望值管理"""
    # 根据设备属性(GPU型号、内存等)选择不同的容差

class Colors / ColoredFormatter:
    """CI 日志着色"""

class CPUMemoryMonitor:
    """CPU 内存监控,检测内存泄漏"""

七、CLI 测试 --- tests/cli/

7.1 CLI 测试固件

python 复制代码
# tests/cli/conftest.py
@pytest.fixture
def cli():
    """创建 CLI 测试运行器"""
    def _cli_invoke(*args):
        runner = CliRunner()
        # 修补 stdout/stderr.close 避免测试中的关闭错误
        old_out_close = sys.stdout.close
        old_err_close = sys.stderr.close
        sys.stdout.close = _noop
        sys.stderr.close = _noop
        try:
            return runner.invoke(transformers.cli.transformers.app, list(args), catch_exceptions=False)
        finally:
            sys.stdout.close = old_out_close
            sys.stderr.close = old_err_close
    return _cli_invoke

八、CI 配置 --- .circleci/config.yml

8.1 CI 流程概览

复制代码
┌──────────────────────┐
│ check_circleci_user  │ ← 确认运行在 huggingface 组织下
└──────────┬───────────┘
           │
┌──────────▼───────────┐
│    fetch_tests       │ ← 分析 git diff,确定需要运行的测试
│  (tests_fetcher.py)  │ ← 生成动态 CircleCI 配置
└──────────┬───────────┘
           │
┌──────────▼───────────┐
│  continuation/continue│ ← 使用生成的配置继续流水线
└──────────┬───────────┘
           │
     ┌─────┴─────────────────────┐
     │                            │
┌────▼─────┐  ┌──────────────────▼──────────────┐
│check_code│  │  动态生成的模型测试 jobs           │
│_quality  │  │  (每个模型一个 job,并行执行)       │
└──────────┘  └─────────────────────────────────┘

8.2 关键 Job

fetch_tests --- 智能测试选择:

yaml 复制代码
fetch_tests:
    steps:
        - run: python utils/tests_fetcher.py | tee tests_fetched_summary.txt
        - run: python utils/tests_fetcher.py --filter_tests || true
        - run: python .circleci/create_circleci_config.py --fetcher_folder test_preparation
        - continuation/continue:
            configuration_path: test_preparation/generated_config.yml

check_code_quality --- 代码质量检查:

yaml 复制代码
check_code_quality:
    docker:
        - image: huggingface/transformers-quality
    steps:
        - run: uv pip install -e ".[quality,serving]"
        - run: make check-code-quality

check_repository_consistency --- 仓库一致性检查:

yaml 复制代码
check_repository_consistency:
    steps:
        - run: make check-repository-consistency
        # 测试不同依赖组合下的导入
        - run: python -c "from transformers import *"  # 全部后端
        - run: |  # 仅 torch(无 PIL)
            uv pip uninstall Pillow torchvision -q
            python -c "from transformers import *"
        - run: |  # 仅 PIL(无 torch)
            uv pip uninstall torch torchvision torchaudio -q
            python -c "from transformers import *"

8.3 夜间测试

yaml 复制代码
parameters:
    nightly:
        type: boolean
        default: false

fetch_all_tests:
    steps:
        - run: python utils/tests_fetcher.py --fetch_all  # 运行所有测试

九、测试体系设计原理

9.1 Mixin 模式 --- 测试逻辑复用

Transformers 测试体系的核心设计模式是 Mixin 继承链

复制代码
具体模型测试类(如 LlamaModelTest)
    ├── ModelTesterMixin           --- 通用模型测试
    ├── GenerationTesterMixin      --- 生成相关测试
    ├── PipelineTesterMixin        --- Pipeline 测试
    └── unittest.TestCase          --- 标准测试基类

每个 Mixin 提供一组 test_ 方法,具体模型只需:

  1. 定义 model_tester(提供配置和输入数据)
  2. 定义 all_model_classes(列出要测试的模型类)
  3. 可选覆盖特定测试方法

9.2 参数化测试

使用 parameterized 库实现多维度参数化:

python 复制代码
# Eager vs SDPA 测试:3 dtype × 2 padding × 2 attention_mask × 2 kernels = 24 组
TEST_EAGER_MATCHES_SDPA_INFERENCE_PARAMETERIZATION = [
    (f"{dtype}_pad_{padding_side}...", *(dtype, padding_side, use_attention_mask, ...))
    for dtype in ("fp16", "fp32", "bf16")
    for padding_side in ("left", "right")
    for use_attention_mask in (True, False)
    for enable_kernels in (True, False)
]

9.3 容错与稳定性

  • hub_retry:自动重试 Hub 请求(指数退避)
  • is_flaky:标记已知不稳定测试
  • set_config_for_less_flaky_test/set_model_for_less_flaky_test:调整配置减少测试抖动
  • 80% 匹配率阈值:SDPA 一致性测试允许少量 batch 元素不匹配
  • 设备自适应容差:不同设备/dtype 使用不同的 atol/rtol

9.4 测试分类与过滤

复制代码
测试分类维度:
├── 速度: slow / tooslow / normal
├── 依赖: require_torch / require_flash_attn / require_deepspeed / ...
├── 功能: is_pipeline_test / is_training_test / is_staging_test
├── 设备: not_device_test / require_torch_gpu / require_torch_multi_gpu
├── 特性: torch_compile_test / flash_attn_test / tensor_parallel_ci
└── 环境: RUN_SLOW / RUN_PIPELINE_TESTS / RUN_TRAINING_TESTS / ...

十、模块间关系

复制代码
┌───────────────────────────────────────────────────────────┐
│                     CI 层                                  │
│  .circleci/config.yml                                      │
│  ├── fetch_tests (tests_fetcher.py)                        │
│  ├── check_code_quality (make check-code-quality)          │
│  └── check_repository_consistency                          │
└──────────────────────┬────────────────────────────────────┘
                       │
┌──────────────────────▼────────────────────────────────────┐
│                  pytest 框架层                              │
│  conftest.py                                               │
│  ├── 自定义标记注册                                         │
│  ├── NOT_DEVICE_TESTS 自动标记                              │
│  ├── Doctest 自定义(HfDoctestModule / HfDocTestParser)    │
│  └── PyTorch 环境配置(TF32/cuDNN/torch.compile)           │
└──────────────────────┬────────────────────────────────────┘
                       │
┌──────────────────────▼────────────────────────────────────┐
│                通用测试 Mixin 层                             │
│  test_modeling_common.py    → ModelTesterMixin             │
│  test_configuration_common.py → ConfigTester               │
│  test_tokenization_common.py → TokenizerTesterMixin        │
│  test_pipeline_mixin.py     → PipelineTesterMixin          │
│  test_training_mixin.py     → TrainingTesterMixin          │
│  generation/test_utils.py   → GenerationTesterMixin        │
└──────────────────────┬────────────────────────────────────┘
                       │
┌──────────────────────▼────────────────────────────────────┐
│                测试工具层                                   │
│  testing_utils.py                                          │
│  ├── 环境控制: slow/tooslow/require_* 装饰器               │
│  ├── TestCasePlus: 路径管理/临时目录/环境变量               │
│  ├── CaptureLogger: 日志捕获                               │
│  ├── TemporaryHubRepo: Hub 临时仓库                        │
│  ├── hub_retry: Hub 请求重试                               │
│  └── 设备管理: torch_device 自适应                         │
└──────────────────────┬────────────────────────────────────┘
                       │
┌──────────────────────▼────────────────────────────────────┐
│              具体模型测试层                                  │
│  tests/models/{model_name}/                                │
│  ├── test_modeling_{name}.py                               │
│  │     class XxxModelTest(ModelTesterMixin, unittest.TestCase):
│  │         model_tester = XxxModelTester(...)              │
│  │         all_model_classes = (XxxModel, XxxForCausalLM, ...)│
│  ├── test_configuration_{name}.py                          │
│  └── test_tokenization_{name}.py                           │
└───────────────────────────────────────────────────────────┘

关键依赖链:

  • 具体模型测试 → Mixin 类 → testing_utils.py(装饰器/工具) → conftest.py(全局配置)
  • CI 配置 → tests_fetcher.py(智能选择测试) → create_circleci_config.py(动态生成 CI 配置)
  • ModelTesterMixin.__init_subclass__ 自动为所有测试方法添加 hub_retry,确保 Hub 不稳定时自动重试
相关推荐
来让爷抱一个1 小时前
MonkeyCode 实战:AI 驱动的 GitHub PR 工作流优化
人工智能·开源·ai编程
梦奇不是胖猫1 小时前
《从0到1将 AI核心名词连成线》
人工智能
泠不丁1 小时前
个人数字化效率系统:从 Obsidian 复盘到自动化时间管理的进阶实践
人工智能
专注搞钱1 小时前
半导体MES智能化升级方案:基于机器学习与Transformer大模型落地实战手册
人工智能·机器学习·transformer
专注搞钱1 小时前
【行业思考】半导体CIM+AI+SKILL融合探索|FAB设备智能自动化演进解析
运维·人工智能·自动化
lpd_lt1 小时前
如何让AI生成项目的单元测试,propmt技巧详解
java·人工智能·单元测试·ai编程
俊哥V1 小时前
每日 AI 研究简报 · 2026-06-05
人工智能·ai
Herlie1 小时前
2026小白做小红书封面AI工具指南:3款对比
人工智能
mit6.8241 小时前
Agent思维模式 | 评估
人工智能