【TrOCR】根据任务特性设计词表vocab.json

预训练权重的限制

基于对microsoft/trocr-base-printed预训练模型文件夹中的各个文件的理解,思考以下问题:

1、你觉得哪些文件在TrOCR的训练中起到作用,并简短说明是什么作用。

2、哪些文件经过TrOCR的训练调参后发生改变(也就是说通过训练得到的,而不是人物设计的)。

3、一些对于TrOCR模型的训练和推理都没有作用的文件,是在哪里起作用,是识别完文字的下一个阶段用到?

4、训练不改变内容,在仅仅用于文字识别TrOCR任务的训练推理时,不起作用的文件是否就可以删除?

5、预训练权重是这些,那微调后的权重文件夹内容也是这些吗?

1、在TrOCR训练中起作用的文件及作用

  • config.json:定义模型整体架构(编码器/解码器层数、维度等),是训练时模型结构的"蓝图",决定网络参数的初始化和计算逻辑。
  • preprocessor_config.json:规定图像预处理规则(缩放、归一化等),确保训练数据的输入格式与模型要求一致。
  • tokenizer_config.jsonvocab.jsonmerges.txtspecial_tokens_map.json:共同定义文本分词逻辑(词表、子词合并规则、特殊标记),将训练文本转换为模型可理解的Token ID。
  • model.safetensors:存储初始预训练权重,训练中会基于新数据更新这些权重参数。
  • generation_config.json:控制训练时文本生成的策略(如起始/结束标记、生成长度限制),影响模型的拟合方向。

2、训练调参后发生改变的文件(训练得到,非人工设计)

model.safetensors

该文件存储模型的所有权重参数(编码器/解码器的注意力权重、层归一化参数等),训练过程中通过反向传播不断更新这些参数,是唯一由训练直接生成的文件。其他文件(如配置文件、分词文件)均为人工设计的"规则/结构定义",训练中不会被修改。

3、对训练和推理无直接作用的文件及作用场景

  • README.md:是模型的说明文档,用于指导用户理解模型用途、使用方法、引用信息等,不参与训练/推理计算,仅在用户学习和使用模型时提供参考。
  • .gitattributes :是Git版本控制的配置文件,用于指定大文件(如model.safetensors)由Git LFS管理,确保仓库高效存储和协作,与模型的训练/推理功能无关,仅作用于代码仓库的版本管理环节。

这些文件与"文字识别后的下阶段"无关,主要服务于模型的分发、管理和用户指导。

4、训练不改变内容且对识别任务无直接作用的文件是否可删除?

不建议删除。

  • README.md 是快速上手模型的关键参考,删除后用户难以了解模型用法和限制;
  • .gitattributes 缺失会导致Git无法正确处理大文件(如model.safetensors可能被当作普通文件存储,导致仓库膨胀),影响后续版本管理和协作。
    仅在极端极简部署场景(如嵌入式设备),且明确无需版本管理和用户指导时,可考虑删除,但需谨慎。

5、预训练权重与微调后权重的文件夹内容是否相同?

基本相同,核心文件结构一致。

微调仅更新 model.safetensors (权重参数),可能会调整 generation_config.json(若微调时修改了生成策略),但其他文件(config.json、分词文件、预处理配置等)均保持不变------因为微调是在预训练模型的结构基础上调整权重,不会改变模型的架构、数据处理规则等"底层设计"。

模型前向传播

查看模型结构和词表长度

现在我print出了加载预训练权重模型结构:

python 复制代码
from transformers import VisionEncoderDecoderModel

PRETRAINED_MODEL_PATH = r"C:\Users\Virgil\Desktop\TrOCR\trocr-base-printed"
model = VisionEncoderDecoderModel.from_pretrained(PRETRAINED_MODEL_PATH)

print("模型结构")
print(model)

print("------------------------------------------")
model.config.vocab_size = model.config.decoder.vocab_size

print("词表长度:")
print(model.config.vocab_size)
复制代码
VisionEncoderDecoderModel(
  (encoder): ViTModel(
    (embeddings): ViTEmbeddings(
      (patch_embeddings): ViTPatchEmbeddings(
        (projection): Conv2d(3, 768, kernel_size=(16, 16), stride=(16, 16))
      )
      (dropout): Dropout(p=0.0, inplace=False)
    )
    (encoder): ViTEncoder(
      (layer): ModuleList(
        (0-11): 12 x ViTLayer(
          (attention): ViTAttention(
            (attention): ViTSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=False)
              (key): Linear(in_features=768, out_features=768, bias=False)
              (value): Linear(in_features=768, out_features=768, bias=False)
            )
            (output): ViTSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.0, inplace=False)
            )
          )
          (intermediate): ViTIntermediate(
            (dense): Linear(in_features=768, out_features=3072, bias=True)
            (intermediate_act_fn): GELUActivation()
          )
          (output): ViTOutput(
            (dense): Linear(in_features=3072, out_features=768, bias=True)
            (dropout): Dropout(p=0.0, inplace=False)
          )
          (layernorm_before): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
          (layernorm_after): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
        )
      )
    )
    (layernorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
    (pooler): ViTPooler(
      (dense): Linear(in_features=768, out_features=768, bias=True)
      (activation): Tanh()
    )
  )
  (decoder): TrOCRForCausalLM(
    (model): TrOCRDecoderWrapper(
      (decoder): TrOCRDecoder(
        (embed_tokens): TrOCRScaledWordEmbedding(50265, 1024, padding_idx=1)
        (embed_positions): TrOCRLearnedPositionalEmbedding(514, 1024)
        (layernorm_embedding): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
        (layers): ModuleList(
          (0-11): 12 x TrOCRDecoderLayer(
            (self_attn): TrOCRAttention(
              (k_proj): Linear(in_features=1024, out_features=1024, bias=True)
              (v_proj): Linear(in_features=1024, out_features=1024, bias=True)
              (q_proj): Linear(in_features=1024, out_features=1024, bias=True)
              (out_proj): Linear(in_features=1024, out_features=1024, bias=True)
            )
            (activation_fn): GELUActivation()
            (self_attn_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
            (encoder_attn): TrOCRAttention(
              (k_proj): Linear(in_features=768, out_features=1024, bias=True)
              (v_proj): Linear(in_features=768, out_features=1024, bias=True)
              (q_proj): Linear(in_features=1024, out_features=1024, bias=True)
              (out_proj): Linear(in_features=1024, out_features=1024, bias=True)
            )
            (encoder_attn_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
            (fc1): Linear(in_features=1024, out_features=4096, bias=True)
            (fc2): Linear(in_features=4096, out_features=1024, bias=True)
            (final_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
          )
        )
      )
    )
    (output_projection): Linear(in_features=1024, out_features=50265, bias=False)
  )
)

打印出的词表长度:50265

前向传播与输出logits讲解

每一轮训练

python 复制代码
        # 前向传播(获取解码器输出logits)
        outputs = model(pixel_values=pixel_values, labels=labels)
        logits = outputs.logits  
        # 形状: (batch_size, max_length, vocab_size)

在TrOCR训练的每一轮前向传播中,model(pixel_values=pixel_values, labels=labels) 是核心计算过程,结合模型结构可拆解为以下步骤,最终得到的 logits 是模型预测的核心输出:

1. model()的输入解析

  • pixel_values :预处理后的图像张量,形状通常为 (batch_size, 3, 384, 384)(批量大小、RGB通道、图像尺寸)。由 preprocessor_config.json 定义的规则处理(缩放、归一化等),确保与编码器输入要求一致。
  • labels :图像对应的文本标签经分词后的Token ID张量,形状为 (batch_size, max_length)max_length 为文本最大长度,短文本用 <pad> 填充)。用于计算预测损失,指导模型参数更新。

2. 模型前向传播过程(结合结构细节)

模型按"编码器→解码器"流程处理输入,最终输出 logits

(1)编码器(ViTModel)处理图像
  • 图像分块与嵌入
    输入图像先经 ViTEmbeddings.patch_embeddings(16x16卷积)分割为 24x24=576 个补丁(384/16=24),每个补丁被投影为 768 维向量(与 config.json 中编码器 hidden_size=768 对应),再添加位置嵌入并经 dropout 处理。
  • Transformer编码
    嵌入后的补丁序列(576个向量)输入 ViTEncoder 的12层 ViTLayer
    • 每层通过 ViTAttention(多头自注意力)捕捉补丁间的空间关系;
    • ViTIntermediate(3072维前馈网络)和 ViTOutput 转换特征,配合两层LayerNorm(layernorm_before/after)稳定训练。
  • 输出图像特征
    编码器最终输出形状为 (batch_size, 576, 768) 的图像特征(批量、补丁数、特征维度),作为解码器的"视觉输入"。
(2)解码器(TrOCRForCausalLM)生成文本预测

解码器以"图像特征+文本标签"为输入,通过自回归逻辑生成文本预测:

  • 文本嵌入
    labels(文本Token ID)先经 TrOCRScaledWordEmbedding 转换为 1024 维词向量(与解码器 d_model=1024 对应),添加 TrOCRLearnedPositionalEmbedding(位置信息)并经LayerNorm处理。
  • Transformer解码(12层 TrOCRDecoderLayer
    每层包含两个关键注意力机制:
    • 自注意力(self_attn) :关注已生成的文本Token(如第i步关注前i-1个Token),通过 k_proj/q_proj/v_proj 计算1024维注意力分布,捕捉文本序列的上下文关系;
    • 交叉注意力(encoder_attn) :关注编码器输出的图像特征(768维),通过 k_proj/v_proj 将图像特征转换为1024维(与解码器维度匹配),再与文本特征计算注意力,确保生成的文本与图像内容对齐。
      每层还通过 fc1(1024→4096)和 fc2(4096→1024)的前馈网络增强特征表达,配合LayerNorm稳定训练。
  • 输出投影
    解码器最终输出的文本特征(1024维)经 output_projection(线性层)映射到词表空间,得到 logits

3. logits 的含义与形状

  • 形状(batch_size, max_length, vocab_size),对应 (批量大小, 文本长度, 词表大小)。其中 vocab_size=50265(与 vocab.json 词表大小一致)。
  • 含义logits[i, j, k] 表示"第i个样本、第j个位置的Token为词表中第k个词"的原始预测分数(未归一化的概率)。后续会通过softmax转换为概率分布,用于计算与 labels 的交叉熵损失(衡量预测与真实标签的差异)。

这一步前向传播的核心是:编码器将图像转换为视觉特征,解码器结合视觉特征和文本标签生成逐位置的词表预测(logits)logits 既是预测结果的直接输出,也是计算损失、反向传播更新 model.safetensors 中权重参数的关键依据,最终让模型学习到"图像→文本"的映射规律。

词表设计思路与分析

词表vocab文件里的各个token很多都不是单个字母或者单个数字。这个OCR任务就是单纯的识别出是什么字母数字或者汉字。那词表没必要这么复杂。可以设计为仅保留 "单个字符 + 核心特殊 Token" 的词表设计(总规模几千个)

  • 核心特殊 Token
  • 常用的几千个中文汉字,
  • 52个英文字母
  • 10个数字
  • 各个特殊符号

这种仅保留"单个字符+核心特殊Token"的词表设计(总规模几千个)在特定场景下有其合理性,但也会带来显著的优缺点,且需要配套调整模型和训练策略,具体分析如下:

一、这种设计的核心优势

  1. 词表规模大幅减小

    从50265缩减到几千,直接降低模型的词嵌入层参数规模 (原解码器词嵌入层参数为 50265×1024≈5100万,新设计为 5000×1024≈512万,参数减少90%),可降低内存占用和计算成本,适合资源有限的场景(如嵌入式设备)。

  2. 分词逻辑简化

    无需BPE分词的复杂合并规则(merges.txt 可省略),文本直接按单个字符拆分(如"apple"拆分为 a→p→p→l→e),分词过程更直观,减少因子词拆分导致的识别误差(如特殊组合被错误拆分)。

  3. 适配"纯字符识别"场景

    若OCR任务仅针对高度标准化的文本(如身份证、发票上的固定格式字符,无复杂缩写、生僻组合),这种词表可覆盖所有可能输入,避免子词表带来的冗余。

二、潜在问题与局限性

  1. 丢失"上下文依赖"捕捉能力

    原词表的子词(如"app""ing")本质是对"高频字符组合"的预编码,模型可通过这些子词快速学习字符间的关联(如"q"后几乎必接"u")。而纯字符词表中,每个字符独立拆分,模型需要从头学习所有组合规律,对模糊、倾斜或噪声图像的鲁棒性会下降(例如"th"在模糊图像中,子词表模型可能通过整体特征识别,而纯字符模型可能因单个"t"或"h"模糊导致错误)。

  2. 预训练权重无法直接复用

    原预训练模型的解码器(基于RoBERTa)是为50265词表设计的,其词嵌入层、注意力权重等参数均与该词表绑定。换成新词表后:

    • 词嵌入层维度不匹配(需重新初始化或裁剪,导致预训练的"语言知识"丢失);
    • 解码器的交叉注意力、前馈网络等参数是基于子词分布训练的,换成纯字符后,参数分布与新数据不匹配,需要从头训练或大规模微调,成本显著增加。
  3. 对复杂文本的覆盖不足

三、适用场景与建议

  • 适合场景:高度标准化、字符有限的封闭场景(如快递单号识别、车牌识别,字符集固定且无复杂组合)。
  • 不适合场景:通用OCR(如书籍、网页、手写体),或包含复杂文本结构的场景。

若坚持使用这种设计,需配套调整:

  1. 重新初始化解码器的词嵌入层(适配新词表大小);
  2. 增加训练数据中字符级别的增强(如模糊、旋转、噪声),帮助模型学习单个字符的鲁棒特征;
  3. 移除原词表相关文件(merges.txt 不再需要),仅保留字符映射的 vocab.json 和特殊Token配置。

这种"极简字符词表"是对特定场景的优化,而非通用方案。它通过牺牲对复杂文本的适应能力,换取参数规模和计算效率的提升,但需要重新训练模型,且无法复用原预训练的子词级语言知识。是否采用,需根据具体任务的文本复杂度和资源约束决定。

词表生成

52个英文字母和10个数字很简单,

特殊Token就直接用预训练权重的前4个特殊token:

python 复制代码
  "<s>": 0,
  "<pad>": 1,
  "</s>": 2,
  "<unk>": 3,

汉字和特殊符号需要找:

5020 个常用汉字资源文件:收录于 "常用汉字大全.txt",项目地址为 https://gitcode.com/Open-source-documentation-tutorial/4afa4
https://github.com/wy-luke/All-Chinese-Character-Set/blob/main/symbols.txt

https://github.com/DenverCoder1/latex-gboard-dictionary/blob/master/dictionary.txt

python 复制代码
import json

def read_chars_from_file(file_path):
    """从文件中读取字符,每行一个字符"""
    chars = []
    try:
        with open(file_path, 'r', encoding='utf-8') as f:
            for line in f:
                # 去除每行的换行符,但保留原始字符(包括空格、tab等)
                char = line.rstrip('\n')
                if char:  # 确保不添加空字符串
                    chars.append(char)
    except FileNotFoundError:
        print(f"警告:未找到文件 {file_path},将跳过该文件的处理")
    return chars

# 存储所有被跳过的字符
skipped_chars = []

# 1. 定义核心特殊token(按优先级排序)
special_tokens = [
    "<s>",    # 起始标记
    "<pad>",  # 填充标记
    "</s>",   # 结束标记
    "<unk>"   # 未知标记
]
special_count = len(special_tokens)
special_start_idx = 0
special_end_idx = special_count - 1

# 2. 从文件读取字符
symbols = read_chars_from_file("symbols_2.txt")
symbol_count = len(symbols)
symbol_start_idx = special_end_idx + 1 if special_count > 0 else 0
symbol_end_idx = symbol_start_idx + symbol_count - 1 if symbol_count > 0 else -1

math_chars = read_chars_from_file("math.txt")
original_math_count = len(math_chars)

chinese_chars = read_chars_from_file("chinese_5021.txt")
original_chinese_count = len(chinese_chars)

# 3. 处理数学符号:去除与特殊符号重复的字符
unique_math_chars = []
symbol_set = set(symbols)
skipped_math_count = 0

for char in math_chars:
    if char not in symbol_set and char not in unique_math_chars:
        unique_math_chars.append(char)
    else:
        skipped_math_count += 1
        skipped_chars.append(char)  # 记录被跳过的数学符号

math_count = len(unique_math_chars)
math_start_idx = symbol_end_idx + 1 if symbol_count > 0 else special_end_idx + 1
math_end_idx = math_start_idx + math_count - 1 if math_count > 0 else -1

# 4. 处理中文字符:确保唯一性
existing_chars = set(special_tokens + symbols + unique_math_chars)  # 前面字符集合
chinese_internal_duplicates = 0  # 中文内部重复计数
chinese_external_duplicates = 0  # 与其他文件重复计数
unique_chinese_chars = []

for char in chinese_chars:
    if char in existing_chars:
        # 与前面的特殊符号/数学符号等重复
        chinese_external_duplicates += 1
        skipped_chars.append(char)  # 记录与其他文件重复的中文
    elif char in unique_chinese_chars:
        # 中文内部重复
        chinese_internal_duplicates += 1
        skipped_chars.append(char)  # 记录中文内部重复的字符
    else:
        unique_chinese_chars.append(char)
        existing_chars.add(char)

# 总跳过数量 = 内部重复 + 外部重复
skipped_chinese_count = chinese_internal_duplicates + chinese_external_duplicates

chinese_count = len(unique_chinese_chars)
chinese_start_idx = math_end_idx + 1 if math_count > 0 else symbol_end_idx + 1
chinese_end_idx = chinese_start_idx + chinese_count - 1 if chinese_count > 0 else -1

# 5. 合并所有字符,保持指定顺序
all_chars = special_tokens + symbols + unique_math_chars + unique_chinese_chars
total_count = len(all_chars)

# 6. 生成词表字典(字符到索引的映射)
vocab = {char: idx for idx, char in enumerate(all_chars)}

# 7. 保存为JSON文件
with open("vocab_2.json", 'w', encoding='utf-8') as f:
    json.dump(vocab, f, ensure_ascii=False, indent=2)

# 8. 保存所有被跳过的字符到skip.txt
with open("skip.txt", 'w', encoding='utf-8') as f:
    for char in skipped_chars:
        f.write(char + '\n')

# 9. 输出统计信息
print("=" * 50)
print("词表生成统计信息:")
print("=" * 50)
print(f"特殊Token:共 {special_count} 个")
print(f"  索引区间:[{special_start_idx}, {special_end_idx}]")
print(f"  内容:{special_tokens}")
print("-" * 50)
print(f"特殊符号:共 {symbol_count} 个")
print(f"  索引区间:[{symbol_start_idx}, {symbol_end_idx}]")
print("-" * 50)
print(f"数学符号:")
print(f"  原始数量:{original_math_count} 个")
print(f"  写入数量:{math_count} 个(去重后)")
print(f"  跳过数量:{skipped_math_count} 个(与特殊符号重复)")
print(f"  索引区间:[{math_start_idx}, {math_end_idx}]")
print("-" * 50)
print(f"中文字符:")
print(f"  原始数量:{original_chinese_count} 个")
print(f"  写入数量:{chinese_count} 个(去重后)")
print(f"  与其他文件重复:{chinese_external_duplicates} 个")
print(f"  中文内部重复:{chinese_internal_duplicates} 个")
print(f"  总跳过数量:{skipped_chinese_count} 个")
print(f"  索引区间:[{chinese_start_idx}, {chinese_end_idx}]")
print("-" * 50)
print(f"所有被跳过的字符已保存到:skip.txt(共 {len(skipped_chars)} 个)")
print(f"词表总字符数:{total_count} 个")
print(f"词表已保存为:vocab_3.json")
print("=" * 50)
复制代码
==================================================
词表生成统计信息:
==================================================
特殊Token:共 4 个
  索引区间:[0, 3]
  内容:['<s>', '<pad>', '</s>', '<unk>']
--------------------------------------------------
特殊符号:共 128 个
  索引区间:[4, 131]
--------------------------------------------------
数学符号:
  原始数量:919 个
  写入数量:708 个(去重后)
  跳过数量:211 个(与特殊符号重复)
  索引区间:[132, 839]
--------------------------------------------------
中文字符:
  原始数量:5021 个
  写入数量:2501 个(去重后)
  与其他文件重复:2520 个
  中文内部重复:0 个
  总跳过数量:2520 个
  索引区间:[840, 3340]
--------------------------------------------------
所有被跳过的字符已保存到:skip.txt(共 2731 个)
词表总字符数:3341 个
词表已保存为:vocab_3.json
==================================================

修改model模型文件

vocab.json换成新的同名的词表文件
config.json需要修改词表大小:

python 复制代码
    "vocab_size": 3340

generation_config.json文件不用修改,因为四个特殊token的id都没有变。
tokenizer_config.json文件不用修改

python 复制代码
{
  "_from_model_config": true,
  "bos_token_id": 0,
  "decoder_start_token_id": 2,
  "eos_token_id": 2,
  "pad_token_id": 1,
  "transformers_version": "4.27.0.dev0",
  "use_cache": false
}

权重文件model.safetensors

TrOCR(基于 Transformer 的 OCR 模型)的核心组件(如解码器的嵌入层和输出层)的维度与词表大小强相关:

  • 嵌入层(Embedding Layer):输入维度为vocab_size,输出维度为模型隐藏层大小(如 768)。预训练模型的嵌入层权重是基于原词表的,若新词表大小(3340)与原词表不同,嵌入层的参数维度会不匹配,直接加载会报错。

    • 词嵌入层实际路径是 model.decoder.model.decoder.embed_tokens(嵌套在 model→decoder 层级下)
  • 输出层(Output Layer):输出维度为vocab_size(用于预测每个字符的概率),预训练的输出层权重同样依赖原词表大小,新词表大小变化后,输出层参数维度也会不匹配。

    • 输出层实际路径是 model.decoder.output_projection(直接在解码器下)
相关推荐
ALex_zry1 小时前
JSON::Value 功能详解:从三目运算符到高级用法
json
嵌R式小Z3 天前
JSON&cJSON
json
tan77º5 天前
【项目】分布式Json-RPC框架 - 项目介绍与前置知识准备
linux·网络·分布式·网络协议·tcp/ip·rpc·json
Yn3126 天前
在 Python 中使用 json 模块的完整指南
开发语言·python·json
陈涛5758 天前
5个最好用的 JSON 工具推荐:让数据处理变得简单高效
json
bkspiderx9 天前
pb2json.hpp 文档:Protobuf 与 JSON 通用转换工具类
json·protobuf·protobuf与json转换
万粉变现经纪人10 天前
何解决PyCharm中pip install安装Python报错ModuleNotFoundError: No module named ‘json’问题
python·pycharm·json·beautifulsoup·scikit-learn·matplotlib·pip
晨欣10 天前
orjson 与 json:实战对比与选型指南(含示例)(GPT-5 回答)
gpt·json
Pi_Qiu_11 天前
Python初学者笔记第二十二期 -- (JSON数据解析)
笔记·python·json