TrOCR模型微调

参考链接:【Transformers-Tutorials/TrOCR/Fine_tune_TrOCR_on_IAM_Handwriting_Database_using_native_PyTorch.ipynb】

1.根据任务类型构建数据集

python 复制代码
import torch
from torch.utils.data import Dataset
from PIL import Image


class ORCDataset(Dataset):
    """Dataset pairing text-line images with tokenized transcriptions.

    Each item is a dict with ``pixel_values`` (the processed image) and
    ``labels`` (token ids of the transcription, padding replaced by -100).

    Args:
        root_dir: String prepended verbatim to every ``file_name``. It is
            concatenated, not path-joined, so include a trailing separator
            (or pass ``''`` when ``file_name`` is already a usable path).
        df: Table with ``file_name`` and ``text`` columns, one row per sample.
        processor: Callable on an image that returns an object exposing
            ``pixel_values``; must also expose ``tokenizer``.
        max_target_length: Every label sequence is padded or truncated to
            exactly this many tokens.
    """

    def __init__(self, root_dir, df, processor, max_target_length=256):
        self.root_dir = root_dir
        self.df = df
        self.processor = processor
        self.max_target_length = max_target_length

    def __len__(self):
        return len(self.df)

    @staticmethod
    def _value_at(column, idx):
        """Return the ``idx``-th element of ``column`` by position."""
        # A DataLoader hands out positions 0..len-1. Plain ``column[idx]`` on
        # a pandas Series is a *label* lookup, which breaks when the frame
        # was filtered or split without resetting its index. Use ``.iloc``
        # when it exists and fall back to ordinary indexing otherwise.
        return getattr(column, "iloc", column)[idx]

    def __getitem__(self, idx):
        # get file name + text
        file_name = self._value_at(self.df['file_name'], idx)
        text = self._value_at(self.df['text'], idx)

        # prepare image (i.e. resize + normalize); the context manager
        # releases the underlying file handle as soon as the copy is made
        with Image.open(self.root_dir + file_name) as raw_image:
            image = raw_image.convert("RGB")
        pixel_values = self.processor(image, return_tensors="pt").pixel_values

        # add labels (input_ids) by encoding the text; truncation keeps every
        # label sequence at max_target_length so batches can be stacked
        tokenizer = self.processor.tokenizer
        labels = tokenizer(text,
                           padding="max_length",
                           truncation=True,
                           max_length=self.max_target_length).input_ids

        # important: make sure that PAD tokens are ignored by the loss function
        pad_token_id = tokenizer.pad_token_id
        labels = [-100 if token_id == pad_token_id else token_id
                  for token_id in labels]

        return {"pixel_values": pixel_values.squeeze(),
                "labels": torch.tensor(labels)}
        
        
# Directory holding the locally stored pretrained TrOCR files
# (processor config, tokenizer files and model weights).
cache_dir = "./pretrain"
# You can first cache the pretrained files locally like this:
# from transformers import VisionEncoderDecoderModel
# model = VisionEncoderDecoderModel.from_pretrained("microsoft/trocr-base-printed",cache_dir='./trocr-base-printed')
# Or download them from the official site and put everything in one folder.

# TrOCRProcessor is used below, so it has to be imported before that call.
from transformers import TrOCRProcessor
from torch.utils.data import DataLoader

processor = TrOCRProcessor.from_pretrained(cache_dir)

# NOTE(review): train_df / test_df are not defined anywhere in this snippet;
# they must be created beforehand. ORCDataset reads their 'file_name' and
# 'text' columns. root_dir='' means 'file_name' is used as the image path
# unchanged.
train_dataset = ORCDataset(root_dir='',
                           df=train_df,
                           processor=processor)
eval_dataset = ORCDataset(root_dir='',
                          df=test_df,
                          processor=processor)
print("Number of training examples:", len(train_dataset))
print("Number of validation examples:", len(eval_dataset))

# Only the training data is shuffled; evaluation keeps the dataset order.
train_dataloader = DataLoader(train_dataset, batch_size=16, shuffle=True)
eval_dataloader = DataLoader(eval_dataset, batch_size=16)

2.加载模型

python 复制代码
from transformers import VisionEncoderDecoderModel
import torch

# Use CUDA when PyTorch can see a GPU, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Load the encoder-decoder model from the same local folder as the processor.
model = VisionEncoderDecoderModel.from_pretrained(cache_dir)
model.to(device)
print(model.encoder)


# Layer-wise (encoder vs. decoder) learning rates could be set up here:
#decoder_param_id = list(map(id,model.decoder.parameters()))
# encoder_params = filter(lambda p: id(p) not in decoder_param_id, model.parameters())
# encoder_params = model.encoder.parameters()
# decoder_params = model.decoder.parameters()

# Special tokens used for creating the decoder_input_ids from the labels,
# the vocabulary size copied from the decoder config, and the beam-search
# parameters are all written onto model.config in the order listed.
_config_overrides = {
    "decoder_start_token_id": processor.tokenizer.cls_token_id,
    "pad_token_id": processor.tokenizer.pad_token_id,
    # make sure vocab size is set correctly
    "vocab_size": model.config.decoder.vocab_size,
    # beam search parameters
    "eos_token_id": processor.tokenizer.sep_token_id,
    "max_length": 128,
    "early_stopping": True,
    "no_repeat_ngram_size": 3,
    "length_penalty": 2.0,
    "num_beams": 4,
}
for _option, _setting in _config_overrides.items():
    setattr(model.config, _option, _setting)

3.加载评价指标

python 复制代码
from datasets import load_metric

# Character error rate (CER) metric object; compute_cer() calls its
# .compute(predictions=..., references=...) method.
# NOTE(review): `load_metric` is reported as deprecated in newer `datasets`
# releases in favour of the separate `evaluate` package — confirm it is
# available in the installed version.
cer_metric = load_metric("cer")
# import datasets
# Alternatively, load the CER metric from a local script:
# cer_metric = datasets.load_metric("./cer.py")

def compute_cer(pred_ids, label_ids):
    """Decode predictions and labels and score them with ``cer_metric``.

    Args:
        pred_ids: Batch of generated token ids, passed to
            ``processor.batch_decode``.
        label_ids: Batch of reference token ids in which ignored positions
            are marked with -100. The caller's object is left unmodified.

    Returns:
        Whatever ``cer_metric.compute`` returns for the decoded strings.
    """
    pred_str = processor.batch_decode(pred_ids, skip_special_tokens=True)

    # Work on a copy so the -100 markers in the caller's labels survive;
    # torch tensors expose .clone(), numpy arrays expose .copy().
    if hasattr(label_ids, "clone"):
        label_ids = label_ids.clone()
    else:
        label_ids = label_ids.copy()

    # -100 is not a real token id; swap it for the pad token so decoding
    # works, then let skip_special_tokens drop the padding again.
    label_ids[label_ids == -100] = processor.tokenizer.pad_token_id
    label_str = processor.batch_decode(label_ids, skip_special_tokens=True)

    return cer_metric.compute(predictions=pred_str, references=label_str)

4.原生PyTorch训练流程

python 复制代码
from tqdm import tqdm

optimizer = torch.optim.AdamW(model.parameters(), lr=5e-5)

for epoch in range(100):  # one full pass over the training data per epoch
    # ---- training pass ----
    model.train()
    train_loss = 0.0
    for batch in tqdm(train_dataloader):
        # move every tensor of the batch onto the target device
        batch = {key: tensor.to(device) for key, tensor in batch.items()}

        # forward pass, backward pass, parameter update, gradient reset
        loss = model(**batch).loss
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

        train_loss += loss.item()

    print(f"Loss after epoch {epoch}:", train_loss/len(train_dataloader))

    # ---- evaluation pass ----
    model.eval()
    valid_cer = 0.0
    with torch.no_grad():
        for batch in tqdm(eval_dataloader):
            # generate token ids for the batch, then score them against the labels
            generated_ids = model.generate(batch["pixel_values"].to(device))
            valid_cer += compute_cer(pred_ids=generated_ids, label_ids=batch["labels"])

    # mean of the per-batch CER values
    print("Validation CER:", valid_cer / len(eval_dataloader))

# write the fine-tuned model into the current working directory
model.save_pretrained(".")
相关推荐
yanxing.D6 小时前
OpenCV轻松入门_面向python(第六章 阈值处理)
人工智能·python·opencv·计算机视觉
JJJJ_iii7 小时前
【机器学习01】监督学习、无监督学习、线性回归、代价函数
人工智能·笔记·python·学习·机器学习·jupyter·线性回归
Python图像识别10 小时前
71_基于深度学习的布料瑕疵检测识别系统(yolo11、yolov8、yolov5+UI界面+Python项目源码+模型+标注好的数据集)
python·深度学习·yolo
千码君201611 小时前
React Native:从react的解构看编程众多语言中的解构
java·javascript·python·react native·react.js·解包·解构
淮北49411 小时前
windows安装minicoda
windows·python·conda
哥布林学者12 小时前
吴恩达深度学习课程一:神经网络和深度学习 第三周:浅层神经网络(二)
深度学习·ai
weixin_5195357712 小时前
从ChatGPT到新质生产力:一份数据驱动的AI研究方向指南
人工智能·深度学习·机器学习·ai·chatgpt·数据分析·aigc
爱喝白开水a13 小时前
LangChain 基础系列之 Prompt 工程详解:从设计原理到实战模板_langchain prompt
开发语言·数据库·人工智能·python·langchain·prompt·知识图谱
生命是有光的13 小时前
【深度学习】神经网络基础
人工智能·深度学习·神经网络
信田君952714 小时前
瑞莎星瑞(Radxa Orion O6) 基于 Android OS 使用 NPU的图片模糊查找APP 开发
android·人工智能·深度学习·神经网络