项目结构:
TrOCR/
├── config.py # 所有配置参数(路径、超参数等)
├── dataset.py # 数据集类 + 数据增强(合并 data_augmentation)
├── model.py # 模型加载与配置
├── utils.py # 通用工具(含日志功能,合并 logger.py)
├── train.py # 训练逻辑 + 入口(合并 trainer.py)
├── predict.py # 推理接口 + 入口(合并 inference.py)
├── evaluate.py # 评估指标 + 入口(合并 metrics.py)
├── requirements.txt # 依赖库
├── README.md # 项目说明
├── data/ # 数据集
│ ├── train/(images + labels.json)
│ ├── val/(images + labels.json)
│ └── test/(images + labels.json)
├── models/ # 保存训练好的模型
└── logs/ # 训练日志、评估报告
数据集的数据结构
我的数据集路径:C:\Users\Virgil\Desktop\dataetOCR\ChineseOcr2k
,目录下有train和val两个文件夹,分别是images和labels.json。
标签JSON文件的数据结构:labels.json内容是这样的:
json
[
{
"file_name": "20587062_124836763.jpg",
"text": "设施一流的绿色、舒适"
},
{
"file_name": "20487921_757563219.jpg",
"text": "瘦削,二○○六年五月"
},
{
"file_name": "20567468_1494490742.jpg",
"text": "分置改革方案》在法规"
...
损失函数:
TrOCR 官方使用的损失函数是交叉熵损失(Cross-Entropy Loss),
主要用于计算解码器生成文本与真实标签之间的差异,具体是通过 标签移位(label shifting) 策略实现的序列到序列(Seq2Seq)损失计算。
TrOCR 是典型的编码器 - 解码器架构(图像编码器 + 文本解码器),
其损失计算逻辑与大多数 Seq2Seq 模型一致:
- 输入与标签设计:解码器的输入是 "真实文本标签左移一位 + 起始符号(如 [CLS])",标签是 "真实文本标签 + 终止符号(如 [SEP])"。
- 损失计算:对解码器每个时间步的输出 logits 计算交叉熵损失,忽略 padding 位置(通过将 pad token 替换为 -100 实现,PyTorch 会自动忽略 -100 标签的损失)。