一. 前言
大模型与OCR(光学字符识别)技术的结合为文档处理和图像识别带来了革命性的进步。目前市面上有多种支持OCR功能的大模型,可以根据不同需求进行选择和使用。
二.示例代码
下面直接给出示例代码。该示例调用的是合合(TextIn)的文档解析接口,可供参照;修改其中的 IP、端口等配置后即可直接运行。
python
import asyncio
from concurrent.futures import ThreadPoolExecutor
import requests
from bs4 import BeautifulSoup
import json
import zipfile
import io
import shutil
import os
import logging
import traceback
import time
# Holds the connection settings for an OCR (optical character recognition) service.
class OCR_CONFIG:
    """Connection settings for an OCR service endpoint."""

    def __init__(self, host, port, timeout, ocr_cache):
        """Store the endpoint address plus timeout (seconds) and cache directory."""
        self.host = host
        self.port = port
        # Overall request/polling timeout in seconds.
        self.timeout = timeout
        # Directory where OCR artifacts are cached.
        self.ocr_cache = ocr_cache
# Default connection settings for the HeHe OCR service.
OCR_HOST = '192.168.1.127'
OCR_PORT = 43109
# Overall parse timeout in seconds.
TIMEOUT = 300
# Directory where OCR output artifacts are cached.
OCR_CACHE = './test_output'
# Shared configuration instance consumed by OcrDetectorHH.
OCR_CONFIG_HH = OCR_CONFIG(OCR_HOST, OCR_PORT, TIMEOUT, OCR_CACHE)
# Convenience aliases for the configured host and port
# (HH_PORT was previously misspelled HH_POST).
HH_IP = OCR_CONFIG_HH.host
HH_PORT = OCR_CONFIG_HH.port
# OCR client for the HeHe (TextIn) document-parsing HTTP API.
class OcrDetectorHH(object):
    """Client for the HeHe OCR document-parsing service.

    Workflow: upload a PDF, poll for the parse result, download the
    generated markdown/JSON artifacts, and convert any HTML tables in
    the output into Markdown tables.
    """

    def __init__(self, cache_dir: str = None) -> None:
        """Build the API endpoint URLs from the module-level OCR_CONFIG_HH.

        cache_dir: optional override for the artifact cache directory;
        defaults to OCR_CONFIG_HH.ocr_cache.
        """
        host = OCR_CONFIG_HH.host
        port = OCR_CONFIG_HH.port
        # Allow only one concurrent OCR job through detect_async().
        self.semaphore = asyncio.Semaphore(1)
        # Cache directory: configured default unless the caller overrides it.
        self.cache_dir = cache_dir if cache_dir else OCR_CONFIG_HH.ocr_cache
        base_url = f"http://{host}:{port}/api/v3/parser/external"
        self.upload_documents_url = f"{base_url}/task/create"
        self.get_parse_result_url = f"{base_url}/result"
        self.download_zip_url = f"{base_url}/md_file/export"
        self.download_json_url = f"{base_url}/json_file/export"
        self.download_excel_url = f"{base_url}/excel_file/export"
        self.download_source_url = f"{base_url}/source_file/export"

    def upload_documents(self, pdf_path):
        """Upload *pdf_path* to the OCR service and return the created task id.

        Raises Exception on any HTTP or service-level error.
        (The original had an unreachable ``return None`` after ``raise``.)
        """
        params = {"parse_type": "document", "merge_images": 1}
        try:
            with open(pdf_path, "rb") as pdf_file:
                files = {"documents": (pdf_file.name, pdf_file, "application/octet-stream")}
                headers = {"accept": "application/json"}
                print(f"\n{'-'*55}ocr{'-'*55}\n 正在上传文件: {pdf_path},生成ID.......")
                response = requests.post(self.upload_documents_url, params=params,
                                         headers=headers, files=files)
            if response.status_code == 200:
                result = response.json()
                if result["code"] == 200 and "task_ids" in result["data"]:
                    return result["data"]["task_ids"][0]
                errinfo = f'获取ocr任务id失败-原因: 状态码200 但-{result.get("msg", "Unknown error")}'
                print(errinfo)
                raise Exception(errinfo)
            errinfo = f'获取ocr任务id失败-原因: 状态码异常 {response.status_code}: {response.text}'
            print(errinfo)
            raise Exception(errinfo)
        except Exception:
            print(f'获取ocr任务id失败-原因: {traceback.format_exc()}', end='\n\n')
            raise

    def get_parse_result(self, task_id):
        """Fetch the parse result for *task_id* as a dict.

        Returns None while the task is still running (service code 10702)
        or on any request error; raises when the task failed (code 10703).
        """
        headers = {"accept": "application/json", "Content-Type": "application/json"}
        payload = {"task_id": task_id}
        try:
            response = requests.post(self.get_parse_result_url, json=payload, headers=headers)
            if response.status_code == 200:
                result = response.json()
                # 10702: task not finished yet -> caller keeps polling.
                if result.get("code") == 10702:
                    return None
                # 10703: task failed on the server side.
                if result.get("code") == 10703:
                    raise Exception(f'ocr-10703-failed-{result}')
                return result
            print(f'获取ocr结果请求异常-原因: 状态码异常 {response.status_code}: {response.text}', end='\n\n')
        except Exception:
            print(f'获取ocr结果请求异常-原因: {traceback.format_exc()}', end='\n\n')
        return None

    def download_file(self, task_id, url, save_path):
        """Download one artifact of *task_id* from *url* to *save_path*.

        Raises on HTTP or I/O failure. (The original except clause
        referenced ``response`` even when ``requests.post`` itself raised,
        turning network errors into a NameError; it also used a bare
        ``raise`` with no active exception.)
        """
        headers = {"accept": "application/octet-stream", "Content-Type": "application/json"}
        body = {"task_ids": [task_id]}
        try:
            response = requests.post(url, json=body, headers=headers)
            if response.status_code != 200:
                print(f'从合合接口下载文件异常- 合合任务id- {task_id} - 原因: 状态码异常 {response.status_code}: {response.text}', end='\n\n')
                raise Exception(f'从合合接口下载文件异常 {response.status_code}: {response.content}', 10009)
            # Create the destination directory only when the path has one.
            parent = os.path.dirname(save_path)
            if parent:
                os.makedirs(parent, exist_ok=True)
            with open(save_path, "wb") as file:
                file.write(response.content)
            print(f"\n{'-'*55}ocr{'-'*55}\n OCR识别完成,识别输出内容从合合接口下载文件: {save_path}")
        except Exception:
            print(f'从合合接口下载文件异常- 合合任务id- {task_id} - 原因: {traceback.format_exc()}', end='\n\n')
            raise

    # Download the Markdown ZIP archive for the task.
    def download_zip(self, task_id, save_path):
        self.download_file(task_id, self.download_zip_url, save_path)

    # Download the JSON result file for the task.
    def download_json(self, task_id, save_path):
        self.download_file(task_id, self.download_json_url, save_path)

    # Download the Excel result file for the task.
    def download_excel(self, task_id, save_path):
        self.download_file(task_id, self.download_excel_url, save_path)

    # Download the original source file for the task.
    def download_source(self, task_id, save_path):
        self.download_file(task_id, self.download_source_url, save_path)

    def get_result_with_id(self, task_id, save_path, filename, unzip_path, retry_interval=5):
        """Poll for the result of *task_id*, then download and post-process it.

        Downloads the markdown ZIP into *save_path*, extracts it into
        *unzip_path*, converts table HTML files to Markdown and removes
        the ZIP. ``retry_interval`` is the polling period in seconds.
        Bug fix: the ZIP name was a garbled literal; per the inline
        comment it must be ``{filename}.zip``.
        """

        # Convert the first <table> of an HTML fragment to a Markdown table,
        # expanding rowspan/colspan by duplicating the cell text.
        def html_table_to_markdown(html_str):
            soup = BeautifulSoup(html_str, 'html.parser')
            table = soup.find('table')
            # No <table> in the fragment: nothing to convert (the original
            # crashed with AttributeError here).
            if table is None:
                return ''
            rows = table.find_all('tr')
            # Pending merged-cell fills: {(row, col): (text, rows_left, cols_left)}.
            span_map = {}
            grid = []       # fully expanded 2-D cell matrix
            max_cols = 0    # widest row seen, used to pad short rows
            for r, row in enumerate(rows):
                grid_row = []
                c = 0
                cells = row.find_all(['th', 'td'])
                for cell in cells:
                    # Fill positions occupied by cells spanning from rows above.
                    while (r, c) in span_map:
                        text, rem_rows, rem_cols = span_map.pop((r, c))
                        grid_row.append(text)
                        if rem_rows > 1:
                            span_map[(r + 1, c)] = (text, rem_rows - 1, rem_cols)
                        c += 1
                    text = cell.get_text(strip=True)
                    colspan = int(cell.get('colspan', 1))
                    rowspan = int(cell.get('rowspan', 1))
                    # Horizontal merge: repeat the text colspan times.
                    for _ in range(colspan):
                        grid_row.append(text)
                    # Vertical merge: register fills for the following rows.
                    if rowspan > 1:
                        for i in range(1, rowspan):
                            for j in range(colspan):
                                span_map[(r + i, c + j)] = (text, rowspan - i, colspan)
                    c += colspan
                # Flush any trailing merged cells at the end of the row.
                while (r, c) in span_map:
                    text, rem_rows, rem_cols = span_map.pop((r, c))
                    grid_row.append(text)
                    if rem_rows > 1:
                        span_map[(r + 1, c)] = (text, rem_rows - 1, rem_cols)
                    c += 1
                max_cols = max(max_cols, len(grid_row))
                grid.append(grid_row)
            # Empty table body: nothing to render (original raised IndexError).
            if not grid:
                return ''
            # Pad short rows to a uniform width by repeating the last cell.
            for row in grid:
                if len(row) < max_cols:
                    row.extend([row[-1] if row else ''] * (max_cols - len(row)))
            header = grid[0]
            md_lines = ['| ' + ' | '.join(header) + ' |',
                        '|' + ' --- |' * len(header)]
            for row in grid[1:]:
                md_lines.append('| ' + ' | '.join(row) + ' |')
            return '\n'.join(md_lines)

        # Convert one HTML file to Markdown and save it next to the source.
        def process_html_file(html_path):
            with open(html_path, 'r', encoding='utf-8') as f:
                html_str = f.read()
            md_text = html_table_to_markdown(html_str)
            # e.g. page9_table0.html -> page9_table0.md in the same directory.
            md_filename = os.path.splitext(os.path.basename(html_path))[0] + '.md'
            md_path = os.path.join(os.path.dirname(html_path), md_filename)
            with open(md_path, 'w', encoding='utf-8') as f:
                f.write(md_text)
            print(f"\n{'-'*55}ocr{'-'*55}\n [INFO] OCR初步识别结果清洗为 .json 格式文件后,保存为为文件-->: {md_path}")

        # Walk *root_dir* and convert every table HTML file to Markdown.
        def merger_table_md(root_dir='mds'):
            print(f"\n{'-'*55}ocr{'-'*55}\n 遍历 mds 文件夹,每个文件清洗后保存为 .json 格式")
            for dirpath, _, filenames in os.walk(root_dir):
                for fname in filenames:
                    # Case-insensitive match on name and extension.
                    lower = fname.lower()
                    if 'table' in lower and lower.endswith(('.html', '.htm')):
                        process_html_file(os.path.join(dirpath, fname))

        if not task_id:
            print(f'获取合合ocr结果异常 - 无效的task_id', end='\n\n')
            return None
        # Poll until the task finishes or the configured timeout elapses,
        # e.g. 300 s / 5 s = 60 attempts.
        max_retries = int(OCR_CONFIG_HH.timeout) // retry_interval
        print(f"\n{'-'*55}ocr{'-'*55}\n OCR识别正在根据 ID-->{task_id} 发起多次ocr识别")
        for attempt in range(max_retries):
            print(f"\n{'-'*55}ocr{'-'*55}\n 第{attempt}次发起OCR解析结果请求------有data则输出结果")
            parse_result = self.get_parse_result(task_id)
            if parse_result and "data" in parse_result:
                # e.g. ./test_output/<filename>/<filename>.zip
                zip_path = os.path.join(save_path, f"{filename}.zip")
                self.download_zip(task_id, zip_path)
                with zipfile.ZipFile(zip_path, 'r') as zip_ref:
                    zip_ref.extractall(unzip_path)
                # The archive stores the per-page markdown under 'mds'.
                mds_path = os.path.join(unzip_path, 'mds')
                print(f"\n{'-'*55}ocr{'-'*55}\n 解压后OCR初步识别结果存贮路径为:{mds_path}")
                merger_table_md(mds_path)
                # The extracted content supersedes the archive.
                os.remove(zip_path)
                print(f"\n{'-'*55}ocr{'-'*55}\n OCR识别的初步结果清洗储存已完成,原压缩文件zip_path已移除-->:{zip_path}")
                break
            time.sleep(retry_interval)

    async def detect_async(self, pdf_path, dataset_name="file-ocr"):
        """Serialized wrapper around detect(): at most one job at a time.

        Bug fix: the original awaited ``asyncio.to_thread(self.detect, ...)``;
        since detect is a coroutine function, to_thread only returned the
        coroutine object without running it. Await the coroutine directly.
        """
        async with self.semaphore:
            return await self.detect(pdf_path, dataset_name)

    async def detect(self, pdf_path, dataset_name="file-ocr"):
        """Run the full OCR pipeline for *pdf_path*.

        Uploads the file, waits for the parse result, downloads the
        artifacts, merges the main markdown plus every table markdown
        into one dict, writes it to ``<filename>.json`` and returns
        ``(filename, merged_dict)``. ``dataset_name`` is accepted for
        interface compatibility. Bug fix: several path f-strings were
        garbled literals; per the inline comments they must all use
        ``{filename}``.
        """
        basename = os.path.basename(pdf_path)
        filename, _suffix = os.path.splitext(basename)
        task_id = self.upload_documents(pdf_path)
        save_path = os.path.join(self.cache_dir, filename)
        os.makedirs(save_path, exist_ok=True)
        unzip_path = os.path.join(save_path, filename)
        # Poll and post-process the OCR output (also removes the temp ZIP).
        self.get_result_with_id(task_id, save_path, filename, unzip_path)
        # Raw JSON result, e.g. ./test_output/<filename>/<filename>/<filename>.json.hho
        hho_json_path = os.path.join(unzip_path, f'{filename}.json.hho')
        self.download_json(task_id, hho_json_path)
        # Keep a copy of the source document next to the results.
        hho_file_path = os.path.join(unzip_path, f'{basename}.hho')
        shutil.copy(pdf_path, hho_file_path)
        md_save_path = os.path.join(unzip_path, f'{filename}.md')
        json_save_path = os.path.join(unzip_path, f'{filename}.json')
        # The service's format differs from ours: merge everything into one
        # {markdown filename: markdown text} dict.
        with open(md_save_path, 'r', encoding='utf-8') as f:
            filename_json = {f'{filename}.md': f.read()}
        mds_dir = os.path.join(unzip_path, 'mds')
        if os.path.isdir(mds_dir):
            for table_filename in os.listdir(mds_dir):
                if os.path.splitext(table_filename)[1] == '.md':
                    with open(os.path.join(mds_dir, table_filename), 'r', encoding='utf-8') as f:
                        filename_json[table_filename] = f.read()
        with open(json_save_path, 'w', encoding='utf-8') as f:
            json.dump(filename_json, f, ensure_ascii=False)
        return filename, filename_json
if __name__ == '__main__':
    # Record wall-clock start time for the timing report below.
    t1 = time.time()
    # Detector picks up host/port/cache from the module-level OCR_CONFIG_HH.
    MinerU_model = OcrDetectorHH()
    doc = 'test.pdf'
    # detect() is a coroutine: the original called it without awaiting, so
    # unpacking its return value failed. Run it to completion with
    # asyncio.run(); it returns (generated filename, OCR result dict).
    filename, filename_json = asyncio.run(MinerU_model.detect(doc, "test_minio_only"))
    # Persist the OCR result dict as UTF-8 JSON (non-ASCII such as Chinese
    # is stored as-is); the output filename was garbled - the comment names
    # it hh_ocr_res.json.
    with open('./hh_ocr_res.json', 'w', encoding='utf-8') as f:
        json.dump(filename_json, f, ensure_ascii=False)
    print(f"\n{'-'*55}ocr{'-'*55}\n OCR识别{doc}文件结果:type为-->:{type(filename_json)},keys为-->:{filename_json.keys()},len为-->:{len(filename_json)}")
    print(f"\n{'-'*55}ocr{'-'*55}\n OCR识别结果文件保存为:./hh_ocr_res.json")
    t2 = time.time()
    print(f"\n{'-'*55}ocr{'-'*55}\n OCR识别{doc}文件用时-->:{t2 - t1}")
三. 总结
大模型OCR技术正在快速发展,为文档数字化和智能信息提取提供了强大的技术支撑,开发者可以根据具体需求选择合适的模型进行部署和应用。
以上就是关于【大模型与OCR】配合应用的示例demo使用,希望对你有所帮助!