深度学习系列71:表格检测和识别

1. pdf处理

如果是可编辑的pdf格式,那么可以直接用pdfplumber进行处理:

import pdfplumber
import pandas as pd

with pdfplumber.open("中新科技:2015年年度报告摘要.PDF") as pdf:
    page = pdf.pages[1]   # 第一页的信息
    text = page.extract_text()
    print(text)
    table = page.extract_tables()
    for t in table:
        # 得到的table是嵌套list类型,转化成DataFrame更加方便查看和分析
        df = pd.DataFrame(t[1:], columns=t[0])
        print(df)

如果是图片格式的pdf,可以使用pdf2image库将pdf转为图片后再继续后面的流程:

from pdf2image import convert_from_path
img = np.array(convert_from_path(path, dpi=800, use_cropbox=True)[0])

2. 表格位置检测

2.1 使用ppstructure

使用paddleocr库中的ppstructure可以方便获取表格位置,参考代码:

from paddleocr import PPStructure
structure = table_engine(source_img)

2.2 使用tabledetector

import tabledetector as td
result = td.detect(pdf_path="pdf_path", type="bordered", rotation=False, method='detect')

2.3 使用cv2的图形学方法

调试简单,具体代码如下:

  1. 二值化去除水印
  2. 使用getStructuringElement获取纵线和横线
  3. 两者合并,使用findContours获取表格外边框和内部单元格

3. 位置确认

获取所有单元格后,使用下面的函数获取单元格的相对位置关系:

from typing import Dict, List, Tuple
import numpy as np

class TableRecover:
    def __init__(
        self,
    ):
        pass

    def __call__(self, polygons: np.ndarray) -> Dict[int, Dict]:
        rows = self.get_rows(polygons)
        longest_col, each_col_widths, col_nums = self.get_benchmark_cols(rows, polygons)
        each_row_heights, row_nums = self.get_benchmark_rows(rows, polygons)
        table_res = self.get_merge_cells(
            polygons,
            rows,
            row_nums,
            col_nums,
            longest_col,
            each_col_widths,
            each_row_heights,
        )
        return table_res

    @staticmethod
    def get_rows(polygons: np.array) -> Dict[int, List[int]]:
        """对每个框进行行分类,框定哪个是一行的"""
        y_axis = polygons[:, 0, 1]
        if y_axis.size == 1:
            return {0: [0]}

        concat_y = np.array(list(zip(y_axis, y_axis[1:])))
        minus_res = concat_y[:, 1] - concat_y[:, 0]

        result = {}
        thresh = 5.0
        split_idxs = np.argwhere(minus_res > thresh).squeeze()
        if split_idxs.ndim == 0:
            split_idxs = split_idxs[None, ...]

        if max(split_idxs) != len(minus_res):
            split_idxs = np.append(split_idxs, len(minus_res))

        start_idx = 0
        for row_num, idx in enumerate(split_idxs):
            if row_num != 0:
                start_idx = split_idxs[row_num - 1] + 1
            result.setdefault(row_num, []).extend(range(start_idx, idx + 1))

        # 计算每一行相邻cell的iou,如果大于0.2,则合并为同一个cell
        return result

    def get_benchmark_cols(
        self, rows: Dict[int, List], polygons: np.ndarray
    ) -> Tuple[np.ndarray, List[float], int]:
        longest_col = max(rows.values(), key=lambda x: len(x))
        longest_col_points = polygons[longest_col]
        longest_x = longest_col_points[:, 0, 0]

        theta = 10
        for row_value in rows.values():
            cur_row = polygons[row_value][:, 0, 0]

            range_res = {}
            for idx, cur_v in enumerate(cur_row):
                start_idx, end_idx = None, None
                for i, v in enumerate(longest_x):
                    if cur_v - theta <= v <= cur_v + theta:
                        break

                    if cur_v > v:
                        start_idx = i
                        continue

                    if cur_v < v:
                        end_idx = i
                        break

                range_res[idx] = [start_idx, end_idx]

            sorted_res = dict(
                sorted(range_res.items(), key=lambda x: x[0], reverse=True)
            )
            for k, v in sorted_res.items():
                if v[0]==None or v[1]==None:
                    continue

                longest_x = np.insert(longest_x, v[1], cur_row[k])
                longest_col_points = np.insert(
                    longest_col_points, v[1], polygons[row_value[k]], axis=0
                )

        # 求出最右侧所有cell的宽,其中最小的作为最后一列宽度
        rightmost_idxs = [v[-1] for v in rows.values()]
        rightmost_boxes = polygons[rightmost_idxs]
        min_width = min([self.compute_L2(v[3, :], v[0, :]) for v in rightmost_boxes])

        each_col_widths = (longest_x[1:] - longest_x[:-1]).tolist()
        each_col_widths.append(min_width)

        col_nums = longest_x.shape[0]
        return longest_col_points, each_col_widths, col_nums

    def get_benchmark_rows(
        self, rows: Dict[int, List], polygons: np.ndarray
    ) -> Tuple[np.ndarray, List[float], int]:
        leftmost_cell_idxs = [v[0] for v in rows.values()]
        benchmark_x = polygons[leftmost_cell_idxs][:, 0, 1]

        theta = 10
        # 遍历其他所有的框,按照y轴进行区间划分
        range_res = {}
        for cur_idx, cur_box in enumerate(polygons):
            if cur_idx in benchmark_x:
                continue

            cur_y = cur_box[0, 1]

            start_idx, end_idx = None, None
            for i, v in enumerate(benchmark_x):
                if cur_y - theta <= v <= cur_y + theta:
                    break

                if cur_y > v:
                    start_idx = i
                    continue

                if cur_y < v:
                    end_idx = i
                    break

            range_res[cur_idx] = [start_idx, end_idx]

        sorted_res = dict(sorted(range_res.items(), key=lambda x: x[0], reverse=True))
        for k, v in sorted_res.items():
            if v[0]==None or v[1]==None:
                continue

            benchmark_x = np.insert(benchmark_x, v[1], polygons[k][0, 1])

        each_row_widths = (benchmark_x[1:] - benchmark_x[:-1]).tolist()

        # 求出最后一行cell中,最大的高度作为最后一行的高度
        bottommost_idxs = list(rows.values())[-1]
        bottommost_boxes = polygons[bottommost_idxs]
        max_height = max([self.compute_L2(v[3, :], v[0, :]) for v in bottommost_boxes])
        each_row_widths.append(max_height)

        row_nums = benchmark_x.shape[0]
        return each_row_widths, row_nums

    @staticmethod
    def compute_L2(a1: np.ndarray, a2: np.ndarray) -> float:
        return np.linalg.norm(a2 - a1)

    def get_merge_cells(
        self,
        polygons: np.ndarray,
        rows: Dict,
        row_nums: int,
        col_nums: int,
        longest_col: np.ndarray,
        each_col_widths: List[float],
        each_row_heights: List[float],
    ) -> Dict[int, Dict[int, int]]:
        col_res_merge, row_res_merge = {}, {}
        merge_thresh = 20
        for cur_row, col_list in rows.items():
            one_col_result, one_row_result = {}, {}
            for one_col in col_list:
                box = polygons[one_col]
                box_width = self.compute_L2(box[3, :], box[0, :])

                # 不一定是从0开始的,应该综合已有值和x坐标位置来确定起始位置
                loc_col_idx = np.argmin(np.abs(longest_col[:, 0, 0] - box[0, 0]))
                merge_col_cell = max(sum(one_col_result.values()), loc_col_idx)

                # 计算合并多少个列方向单元格
                for i in range(merge_col_cell, col_nums):
                    col_cum_sum = sum(each_col_widths[merge_col_cell : i + 1])
                    if i == merge_col_cell and col_cum_sum > box_width:
                        one_col_result[one_col] = 1
                        break
                    elif abs(col_cum_sum - box_width) <= merge_thresh:
                        one_col_result[one_col] = i + 1 - merge_col_cell
                        break
                else:
                    one_col_result[one_col] = i + 1 - merge_col_cell + 1

                box_height = self.compute_L2(box[1, :], box[0, :])
                merge_row_cell = cur_row
                for j in range(merge_row_cell, row_nums):
                    row_cum_sum = sum(each_row_heights[merge_row_cell : j + 1])
                    # box_height 不确定是几行的高度,所以要逐个试验,找一个最近的几行的高
                    # 如果第一次row_cum_sum就比box_height大,那么意味着?丢失了一行
                    if j == merge_row_cell and row_cum_sum > box_height:
                        one_row_result[one_col] = 1
                        break

                    elif abs(box_height - row_cum_sum) <= merge_thresh:
                        one_row_result[one_col] = j + 1 - merge_row_cell
                        break
                else:
                    one_row_result[one_col] = j + 1 - merge_row_cell + 1

            col_res_merge[cur_row] = one_col_result
            row_res_merge[cur_row] = one_row_result

        res = {}
        for i, (c, r) in enumerate(zip(col_res_merge.values(), row_res_merge.values())):
            res[i] = {k: [cc, r[k]] for k, cc in c.items()}
        return res

调用代码如下:

h_min = 10
h_max = 5000

def sortContours(cnts, method='left-to-right'):
    reverse = False
    i = 0
    if method == "right-to-left" or method == "bottom-to-top":
        reverse = True
    if method == "top-to-bottom" or method == "bottom-to-top":
        i = 1
    boundingBoxes = [cv2.boundingRect(c) for c in cnts]
    (cnts, boundingBoxes) = zip(*sorted(zip(cnts, boundingBoxes),key=lambda b: b[1][i], reverse=reverse))
    return (cnts, boundingBoxes)

def sorted_boxes(dt_boxes):
    num_boxes = dt_boxes.shape[0]
    dt_boxes = sorted(dt_boxes, key=lambda x: (x[0][1], x[0][0]))
    _boxes = list(dt_boxes)
    for i in range(num_boxes - 1):
        for j in range(i, -1, -1):
            if (
                abs(_boxes[j + 1][0][1] - _boxes[j][0][1]) < 10
                and _boxes[j + 1][0][0] < _boxes[j][0][0]
            ):
                _boxes[j], _boxes[j + 1] = _boxes[j + 1], _boxes[j]
            else:
                break
    return _boxes

def getBboxDtls(raw):
    ######### 1. 获得表格的边框,确保merge正确展示了图中的表格边框
    gray = cv2.cvtColor(raw, cv2.COLOR_BGR2GRAY)
    binary = 255-cv2.threshold(gray, 200, 255, cv2.THRESH_BINARY)[1]
    rows, cols = binary.shape
    scale = 30
    kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (cols // scale, 1))
    eroded = cv2.erode(binary, kernel, iterations=1)
    dilated_col = cv2.dilate(eroded, kernel, iterations=1)
    scale = 20
    kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (1, rows // scale))
    eroded = cv2.erode(binary, kernel, iterations=1)
    dilated_row = cv2.dilate(eroded, kernel, iterations=1)
    merge = cv2.add(dilated_col, dilated_row)
    kernel = np.ones((3,3),np.uint8)
    merge = cv2.erode(cv2.dilate(merge, kernel, iterations=3), kernel, iterations=3)
    plt.figure(figsize=(60,30))
    io.imshow(merge[1500:2500])
    
    ########## 2. 获取表格坐标
    tableData = []
    contours = cv2.findContours(merge, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    contours = contours[0] if len(contours) == 2 else contours[1]
    contours, boundingBoxes = sortContours(contours, method='top-to-bottom')
    # 获取表格外边框
    for c in contours:
        x, y, w, h = cv2.boundingRect(c)
        if (h>h_min):
            tableData.append((x, y, w, h))        
    # 获取表格内部的单元格
    contours, hierarchy = cv2.findContours(merge, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)
    contours, boundingBoxes = sortContours(contours, method="top-to-bottom")
    boxes = []
    for c in contours:
        x, y, w, h = cv2.boundingRect(c)
        if (h>h_min) and (h<h_max):
            boxes.append([x, y, w, h])

    ########## 3. 计算表格单元格位置关系
    bboxDtls = {}
    for tableBox1 in tableData:
        key = tableBox1
        values = []
        for tableBox2 in boxes:
            x2, y2, w2, h2 = tableBox2
            if tableBox1[0] <= x2 <= tableBox1[0] + tableBox1[2] and tableBox1[1] <= y2 <= tableBox1[1] + tableBox1[3]:
                values.append(tableBox2)
        bboxDtls[key] = values
    for key, values in bboxDtls.items():
        x_tab, y_tab, w_tab, h_tab = key
        for box in values:
            x_box, y_box, w_box, h_box = box
    return bboxDtls

4. 表格文字识别

4.1 wired_table_rec

可以尝试使用wired_table_rec进行识别:

from wired_table_rec import WiredTableRecognition
table_rec = WiredTableRecognition()
table_str = table_rec(cv2.imread(img_path))[0]
HTML(table_str)

4.2 rapidocr_onnxruntime或者pytessect

或者可以使用更原子化的ocr服务,逐个单元格进行ocr识别,完整代码如下:

"""
首先安装pdf2image和rapidocr_onnxruntime两个库。
图像处理部分的参数和代码可以自行调整:
1. pad参数用于去除图片的边框
2. 转pdf时,有时候800dpi会失败,因此需要加入try except
3. 图片太小时ocr效果不好,因此做了resize。这里的3000可以自行调整。
4. 大图片做了二值化处理,目的是去除水印的干扰。这里的180也可以尝试自行调整。
5. 只处理第一个单元格总数大于50的表格。如果要识别图片中所有表格,可修改代码。
6. 返回的是html格式的表格,可以用pd.read_html函数转为dataframe
"""
rocr = RapidOCR()
rocr.text_det.preprocess_op = DetPreProcess(736, 'max')
def getResult(path,pad = 20, resize_thresh=3000, binary_thresh=180):
    if 'pdf' in path:
        try:
            source_img = np.array(convert_from_path(path, dpi=800, use_cropbox=True)[0])[pad:-pad,pad:-pad]
        except:
            source_img = np.array(convert_from_path(path, dpi=300, use_cropbox=True)[0])[pad:-pad,pad:-pad]
    else:
        source_img = cv2.imread(path)[pad:-pad,pad:-pad]
    if source_img.shape[1] < resize_thresh:
        source_img =cv2.resize(source_img,(resize_thresh,int(source_img.shape[0]/source_img.shape[1]*resize_thresh)))
    img = cv2.threshold(cv2.cvtColor(source_img, cv2.COLOR_BGR2GRAY), binary_thresh, 255, cv2.THRESH_BINARY)[1] 
    bboxDtls = getBboxDtls(source_img)
    boxes = []
    table = None
    # 寻找到第一个单元格数大于50的表后停止
    for k,v in bboxDtls.items():
        if len(v)>50:
            table = k
            for r in tqdm(v[1:]):
                res = rocr(img[r[1]: r[1]+r[3],r[0]:r[2]+r[0]])[0]
                if res!=None:
                    res.sort(key = lambda x:(x[0][0][1]//(img.shape[1]//20),x[0][0][0]//(img.shape[0]//20)))
                    boxes.append([[r[0], r[1]], [r[0], r[1]+r[3]],[r[0]+r[2], r[1]+r[3]], [r[0]+r[2], r[1]], ''.join([t[1].replace('\n','').replace(' ','') for t in res])])
                else:
                    boxes.append([[r[0], r[1]], [r[0], r[1]+r[3]],[r[0]+r[2], r[1]+r[3]], [r[0]+r[2], r[1]], ''])  
            break
    polygons = sorted_boxes(np.array(boxes))
    texts = [p[4] for p in polygons]
    tr = TableRecover()
    table_res = tr(np.array([[np.array(p[0]),np.array(p[1]),np.array(p[2]),np.array(p[3])] for p in polygons]))
    table_html = """<table border="1" cellspacing="0">"""
    for vs in table_res.values():
        table_html+="<tr>"
        for i,v in vs.items():
            table_html+=f"""<td colspan="{v[0]}" rowspan="{v[1]}">{texts[i]}</td>"""
        table_html+="</tr>"
    table_html+="""</table>"""
    return table_html

原图为:https://www.95598.cn/omg-static/99107281818076039603801539578309.jpg

最终识别出来的结果如下:

相关推荐
安逸sgr6 分钟前
1、CycleGAN
pytorch·深度学习·神经网络·生成对抗网络
FL162386312918 分钟前
[数据集][目标检测]俯拍航拍森林火灾检测数据集VOC+YOLO格式6116张2类别
人工智能·深度学习·目标检测
华清远见成都中心18 分钟前
哪些人适合学习人工智能?
人工智能·学习
qq_5503379924 分钟前
研1日记14
人工智能·深度学习·机器学习
i嗑盐の小F31 分钟前
【IEEE&ACM Fellow、CCF组委】第三届人工智能与智能信息处理国际学术会议(AIIIP 2024)
人工智能·深度学习·算法·机器学习·自然语言处理·信号处理
nfgo1 小时前
Apollo自动驾驶项目(二:cyber框架分析)
人工智能·自动驾驶·unix
h177113472051 小时前
基于区块链的相亲交易系统源码解析
大数据·人工智能·安全·系统架构·交友
HPC_fac130520678161 小时前
RTX 4090 系列即将停产,RTX 5090 系列蓄势待发
服务器·人工智能·gpu算力
xuehaisj2 小时前
论文内容分类与检测系统源码分享
人工智能·分类·数据挖掘
大耳朵爱学习2 小时前
大模型预训练的降本增效之路——从信息密度出发
人工智能·深度学习·机器学习·自然语言处理·大模型·llm·大语言模型