从零开始使用GOT-OCR2.0——多模态OCR项目:微调数据集构建 + 训练(解决训练报错,成功实验微调训练)

在上一篇文章记录了GOT-OCR项目的环境配置和基于官方模型参数的基础使用。环境安装的博文快速链接:

从零开始使用GOT-OCR2.0------多模态通用型OCR(非常具有潜力的开源OCR项目):项目环境安装配置 + 测试使用-CSDN博客

本章在环境配置好的基础上,进一步研究官方给出的微调模型接口。在其官方源码论文中介绍了项目的整体架构------编码器和解码器,其中编码器使用的 ViTDet 的视觉Transformer模型,解码器使用的是通义千问的Qwen-0.5B语言模型。**官方便捷接口只支持 Post-Train 仅训练解码器。**具体可参加精读论文,快速链接地址:

【论文精读】GOT-OCR2.0源码论文------打破传统OCR流程的多模态视觉-语言大模型架构:预训练VitDet 视觉模型+ 阿里通义千问Qwen语言模型-CSDN博客

GOT-OCR项目官方Github地址:https://github.com/Ucas-HaoranWei/GOT-OCR2.0

目录

一、微调解码器的数据集格式

1.训练集的目录结构

2.训练集标签JSON内容格式

3.具体代码实现

4.样例展示

二、官方指令解码器训练

1.报错解决

2.官方指令训练(修改增添版)

3.测试结果(微调模型与官方原始模型效果对比)

4.实验中的新发现


一、微调解码器的数据集格式

节省内容时间:可以直接跳转到本节第4部分样例展示------自行编程转化到样例格式就可以开始训练了

官方指引下有三个位置是注意点,这些地方都可以在下载的源码中查看。

一是,官方给出的样例标签 json 文件格式 ,在下载的本地源码GOT-OCR2.0-main\assets\train_sample.jpg 位置就可以查看。

二和三是需要修改源码py代码文件,分别在本地下载源码的 GOT-OCR-2.0-master\GOT\utils\constants.pyGOT-OCR-2.0-master\GOT\data\conversation_dataset_qwen.py 位置。

结合这三个需要修改的地方,可以分析判断出数据集的大致构成结构。

1.训练集的目录结构

首先是训练数据集的文件结构 。通过查看 constants.py 中的参数配置信息 ,并且结合另一个文件conversation_dataset_qwen.py37 行修改内容(官方指引),可以得知GOT-OCR解码器微调支持多个数据集的输入(如包括pdf文件数据集、场景图片ocr数据集等),详见下图。

示例数据给出了三个数据集的配置,每个数据集可以在不同的地址位置,数据集应包含图片数据的总文件夹,和一个json标签文件包含对应图片的所有标签训练信息。

在另一个代码修改处展示了如何指定不同类型数据集训练,如 pdf 数据的训练包含指定的data1和data2两个数据集。

因此结合上述所有内容,可以指定,最终数据集生成后的目录结构样子应该如下。为了更清晰,这里把所有图片都放在 images 子目录下,在外面写一个 data1.json 保存标签信息。

(还应注意 到,官方这里没有做训练集和验证集 的区分,在其使用指南中,后续有专门的验证模型val介绍,因此推测此处应该全部数据都是用来训练的,实际中还应留出一部分数据用来后面验证模型有没有过拟合或其他泛化性问题)

2.训练集标签JSON内容格式

然后就要具体看其标签文件(data1.json)的内容格式组成了。官方只给了一个样例图片,可以从中大致解读出大概结构。

首先所有标签信息的数据构成一个大的列表 保存在JSON文件中,列表中每个元素是对应每一张不同图片数据的标签信息字典

字典结构中,键**"image"** 的值对应图片名,键**"conversations"** 的值对应文本结果。至于conversations中为啥有两个信息,且一个是 "from:human" ,一个**"from:gpt"** ,有两种猜测:推测一------是可能数据OCR文本结果存在两种来源,一种是人工标注的,一种是gpt生成的;推测二------论文中表示其数据集是从官方数据集中抽取的,是不是"gpt"表示官方数据中的内容,而"human"代表还需要后续人工效验?

总而言之,本文数据集构建不具体探究其中原因,因为样例数据给出中文本内容基本都是在**"gpt"的value值** 中,所以暂定将自己的数据集文本也填入,而不去改变"human"的内容

3.具体代码实现

下面给出具体实现标签文件编写的代码。实际上只需将自己的数据集标签内容改写成上述分析的JSON格式即可。

需要注意下述代码的标签文件是来自于对页面中每个字的分类目标检测结果,其原始的JSON格式如下,其中前四个是坐标信息,第五个是字的排序信息,表示该字在整段文章中对应第几个字,最后是这个字的具体字符。

因此,如果是其他数据集格式,只需最终生成上一标题记录的标签JSON内容格式即可,可自行编程。

python 复制代码
import os
import json

class GOT_Dataset_Creator:
    def __init__(self, dataset_path):
        self.dataset = dataset_path

    def generate_labels(self, image_path, label_path):
        save_list = []
        for file in os.listdir(image_path):
            if file.lower().endswith('jpg'):
                img_dict = {}
                img_dict['image'] = "images/"+file
                base = file.split('.')[0]
                labeldir = os.path.join(label_path, base+'.json')
                with open(labeldir, 'r', encoding='utf-8') as f:
                    word_list = json.load(f)
                if word_list!= []:
                    conv_list = [{"from":"human","value":"<image>\nOCR: "}]
                    textlst = [w[5] for w in word_list]
                    text = ''.join(textlst)
                    text_dict = {"from":"gpt", "value":text}
                    conv_list.append(text_dict)
                    img_dict["conversations"] = conv_list
                    save_list.append(img_dict)
                else:
                    continue
        #print(save_list)
        savedir = os.path.join(self.dataset, 'data1.json')
        with open(savedir, 'w') as f:
            json.dump(save_list,f)


if __name__=='__main__':
    image_path = \dataset\images  # 原始图片路径
    label_path = \dataset\labels # 原始标签路径
    dataset_path = \GOT-OCR_Dataset\dataset # GOT数据集保存路径

    gotdataset = GOT_Dataset_Creator(dataset_path)
    gotdataset.generate_labels(image_path, label_path)

4.样例展示

最终数据结构和之前分析的保持一致,然后还要修改代码,下图给出样例。

(特别注意)在实际训练中还有几处需要特别注意的细节。下面给出具体图片展示。

首先是构建好数据集,数据集目录格式和标签文本内容格式如下两图。

然后修改源码两个文件conversation_dataset_qwen.pyconstants.py 内容。下图标注了注意细节点。

二、官方指令解码器训练

1.报错解决

在使用官方给出的指令训练时,会出现一些报错,下面记录其解决。

首先是,官方使用的是 deepspeed 库来进行Transformer模型的微调加速训练,但是该库只支持 Linux 系统下运行, Windows系统运行会直接报错(找不到"deepspeed"指令等报错信息)。因此GOT 项目模型的微调训练只支持Linux环境!

然后是,关于CUDA环境的问题,在本文训练微调时,有一个库------**"bitsandbytes"**总会报找不到CUDA路径的问题,如下图所示。可能原因是,本机服务器安装的是CUDA12.4最新版本,而环境安装的是cu118。

一种粗糙的解决办法是,替换这个库的版本,原版GOT中要求的是0.41.0版本,我更新到0.44.1版本后,虽然训练中还是会报错,但是不影响正常训练。

python 复制代码
# 卸载原有包
pip uninstall bitsandbytes
# 安装新包
pip install bitsandbytes==0.44.1

最后其他报错,如果在训练中遇到有什么包导入失败,显示"no module ...",直接在命令行pip install安装即可。

解决以上问题后,可以开始训练。

2.官方指令训练(修改增添版)

官方给出的指令并不能完全满足实际训练要求,这里新增了指定训练GPU的操作。下面先给出实际训练中,要根据自己文件数据修改的地方。

下面一图总览。代码如下。

python 复制代码
# 微调GOT解码器修改后的指令代码

CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 deepspeed path/GOT/train/train_GOT.py --deepspeed path/zero_config/zero2.json --model_name_or_path path/GOT_weights/ --use_im_start_end True --bf16 True --gradient_accumulation_steps 2 --evaluation_strategy "no" --save_strategy "steps" --save_steps 200 --save_total_limit 1 --weight_decay 0. --warmup_ratio 0.001 --lr_scheduler_type "cosine" --logging_steps 1 --tf32 True --model_max_length 8192 --gradient_checkpointing True --dataloader_num_workers 8 --report_to none --per_device_train_batch_size 2 --num_train_epochs 1 --learning_rate 2e-5 --datasets gj-ocr --output_dir path/output

3.测试结果(微调模型与官方原始模型效果对比)

安装上述指令测试结果如下。

训练过程信息打印如下。

结果保存模型参数数据如下。可以在自己指定的文件夹下找到。

继续可以把训练好的模型参数下载到要进行测试的目录,可以对比一下官方默认模型参数和经过微调后的模型识别部署效果

本文微调使用的是7000+张繁体古文数据集(特别的,这里的古文数据集没有做任何标点分隔处理,每张图片对应一个"长文本":无标点的文本),训练轮次使用10,其余训练参数默认官方。

先看官方模型参数的OCR结果。可以看到官方的模型参数已经表现不错了,大部分以识别出来了,但是仍存在少部分字识别分类错误------对应右图红框内的字

而且官方的训练集在文本部分是有"\t,\n"等如空格换行的分隔符的,因此其识别的OCR结果也对应有换行的效果,这也再次表现了多模态模型的优越性和强大的空间感知能力。

再看经过自建数据集微调后的结果。可以看到字分类的正确率得到了明显提升------出错的为下图黄框内的字。并且语言模型的"可塑性"看起来很强,对应训练集对文本进行分割符,部署效果也是没有进行分段分行的。

4.实验中的新发现

对于密集型文字的识别效果,微调后的模型和官方模型对比差异就更加明显了。

同时,本文实验发现,对于中分辨率的大图(大于1280*1280,小于6000*6000),可以直接构建数据集训练,也能得到很好的效果。

官方模型在训练集中采用的都是1280*1280的方形图,而本文使用的数据集微调则不同------1.图片数据是不规则的,可能是竖直或平躺的矩形;2.图片数据大小不是固定的,可能是1345*2895这种的。

下图对比展示了这种由训练数据集差异导致的部署效果差异。可以看到官方模型是无法直接识别这种长宽大于1280*1280且文字密集的数据(官方解决办法是对原图进行切分小图识别)

但是经过微调后的模型就可以得到很好的效果了。如下图所示。

由此可见,GOT-OCR项目非常具有潜力,其性能还能继续深挖。ViT加语言模型的架构还有非常大的模型性能空间值得深入研究。

相关推荐
Donvink17 分钟前
Transformers在计算机视觉领域中的应用【第3篇:Swin Transformer——多层次的Vision Transformer】
人工智能·深度学习·目标检测·计算机视觉·transformer
一尘之中2 小时前
基于Transformer的编码器-解码器图像描述模型在AMD GPU上的应用
人工智能·深度学习·transformer
数据猿17 小时前
【金猿人物展】白鲸开源CEO郭炜:未来数据领域的PK是大模型Transformer vs 大数据Transform...
大数据·人工智能·深度学习·transformer
shuxunAPI19 小时前
身份证 OCR 识别 API 接口的应用场景
云计算·ocr·api·csdn开发云
次次皮19 小时前
【方案三】JAVA中使用ocr(Umi-OCR)
java·ocr
shuxunAPI20 小时前
营业执照 OCR 识别 API 的发展前景
云计算·ocr·api·csdn开发云
deephub1 天前
Transformer模型变长序列优化:解析PyTorch上的FlashAttention2与xFormers
pytorch·深度学习·transformer·变长序列
weixin_402939991 天前
【深度学习】transformer的encoder部分,多特征多变量,双头,一头回归,一头分类的代码实现,并且分开embedding的
深度学习·回归·transformer
yunmoon012 天前
一款支持80+语言,包括:拉丁文、中文、阿拉伯文、梵文等开源OCR库
开源·ocr