【TrOCR】用Transformer和torch库实现TrOCR模型

项目结构:

复制代码
TrOCR/
├── config.py               # 所有配置参数(路径、超参数等)
├── dataset.py              # 数据集类 + 数据增强(合并 data_augmentation)
├── model.py                # 模型加载与配置
├── utils.py                # 通用工具(含日志功能,合并 logger.py)
├── train.py                # 训练逻辑 + 入口(合并 trainer.py)
├── predict.py              # 推理接口 + 入口(合并 inference.py)
├── evaluate.py             # 评估指标 + 入口(合并 metrics.py)
├── requirements.txt        # 依赖库
├── README.md               # 项目说明
├── data/                   # 数据集
│   ├── train/(images + labels.json)
│   ├── val/(images + labels.json)
│   └── test/(images + labels.json)
├── models/                 # 保存训练好的模型
└── logs/                   # 训练日志、评估报告

数据集的数据结构

我的数据集路径:C:\Users\Virgil\Desktop\dataetOCR\ChineseOcr2k,目录下有train和val两个文件夹,分别是images和labels.json。

标签JSON文件的数据结构:labels.json内容是这样的:

json 复制代码
    [
      {
        "file_name": "20587062_124836763.jpg",
        "text": "设施一流的绿色、舒适"
      },
      {
        "file_name": "20487921_757563219.jpg",
        "text": "瘦削,二○○六年五月"
      },
      {
        "file_name": "20567468_1494490742.jpg",
        "text": "分置改革方案》在法规"
    ...

损失函数:

TrOCR 官方使用的损失函数是交叉熵损失(Cross-Entropy Loss),

主要用于计算解码器生成文本与真实标签之间的差异,具体是通过 标签移位(label shifting) 策略实现的序列到序列(Seq2Seq)损失计算。

TrOCR 是典型的编码器 - 解码器架构(图像编码器 + 文本解码器),

其损失计算逻辑与大多数 Seq2Seq 模型一致:

  • 输入与标签设计:解码器的输入是 "真实文本标签左移一位 + 起始符号(如 [CLS])",标签是 "真实文本标签 + 终止符号(如 [SEP])"。
  • 损失计算:对解码器每个时间步的输出 logits 计算交叉熵损失,忽略 padding 位置(通过将 pad token 替换为 -100 实现,PyTorch 会自动忽略 -100 标签的损失)。