1. Technical Solution Overview
This document describes an implementation for semantically chunking OCR-produced TXT files. The core goal is to split unstructured OCR text (often garbled and inconsistently formatted) into high-quality text chunks according to semantic coherence and token-length constraints, with complete performance timing built in. The design balances chunking quality with engineering practicality and addresses the main pain points of chunking OCR text.
1.1 Core Design Principles
- Semantics first: detect topic boundaries from sentence-level semantic similarity so that the text inside each chunk stays coherent;
- Controllable length: constrain chunk length by token count (target 1024 tokens, minimum 300 tokens) to fit LLM context windows;
- Robustness: built-in OCR text cleaning handles garbled characters and messy formatting;
- Observable performance: timing statistics at three levels: overall, per stage, and per chunk;
- Efficiency: model caching, batched computation, and similar measures reduce repeated overhead.
1.2 Technology Stack
| Component | Version / Choice | Purpose | Why |
|---|---|---|---|
| spaCy | zh_core_web_sm / en_core_web_sm | Sentence splitting | Accurate Chinese/English sentence segmentation; unneeded pipeline components can be disabled for speed |
| SentenceTransformer | all-MiniLM-L6-v2 | Semantic embeddings | Lightweight general-purpose model for Chinese/English text; good balance of embedding quality and speed |
| HuggingFace Tokenizer | Same source as the embedding model | Token counting | Keeps the token-counting rules consistent with the LLM |
| PyTorch | Latest stable release | Model acceleration | GPU inference support lowers embedding latency |
| Python time module | perf_counter | Timing statistics | High-resolution timer, well suited to performance profiling |
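As a quick sanity check of this stack, the hedged snippet below loads each component once and exercises it end to end. It is a minimal sketch, assuming the spaCy model has already been downloaded (e.g. `python -m spacy download zh_core_web_sm`) and that the sentence-transformers weights are available; the sample text is a placeholder.

```python
import time
import spacy
import torch
from sentence_transformers import SentenceTransformer
from transformers import AutoTokenizer

t0 = time.perf_counter()
nlp = spacy.load("zh_core_web_sm", disable=["ner", "tagger"])  # keep only what sentence splitting needs
embedder = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")
tokenizer = AutoTokenizer.from_pretrained("sentence-transformers/all-MiniLM-L6-v2")
device = "cuda" if torch.cuda.is_available() else "cpu"

sentences = [s.text for s in nlp("OCR文本通常包含乱码。分句是语义分块的第一步。").sents]
embeddings = embedder.encode(sentences, device=device, show_progress_bar=False)
tokens = len(tokenizer.encode(sentences[0], add_special_tokens=False))
print(f"{len(sentences)} sentences, embedding shape {embeddings.shape}, {tokens} tokens in the first sentence")
print(f"stack loaded and exercised in {time.perf_counter() - t0:.2f}s")
```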
2. Core Implementation
2.1 Full Implementation Code
```python
import spacy
import numpy as np
import warnings
import time
from pathlib import Path
from typing import List, Dict
from sentence_transformers import SentenceTransformer, util
from transformers import AutoTokenizer
import torch
import re
# Suppress irrelevant warnings
warnings.filterwarnings("ignore")
class OCRTextSemanticChunker:
"""适配OCR TXT文件的语义分块器(带完整语义分块+耗时统计)"""
# 模型缓存(避免重复加载)
_model_cache = {}
_tokenizer_cache = {}
_nlp_cache = {}
def __init__(
self,
model_name: str = "all-MiniLM-L6-v2",
spacy_model: str = "zh_core_web_sm",
target_chunk_tokens: int = 1024,
min_chunk_tokens: int = 300,
overlap_sentences: int = 2,
similarity_threshold: float = 0.7,
device: str = "auto" # 自动选择cpu/gpu
):
# 1. 设备初始化(优先GPU)
self.device = torch.device("cuda" if torch.cuda.is_available() and device == "auto" else "cpu")
# 2. 模型加载(带缓存)
self.nlp = self._get_spacy_model(spacy_model)
self.embed_model = self._get_embed_model(model_name)
self.tokenizer = self._get_tokenizer(model_name)
# 3. 核心配置参数
self.target_tokens = target_chunk_tokens # Chunk目标Token数
self.min_tokens = min_chunk_tokens # Chunk最小Token数(避免碎片化)
self.overlap_sent = overlap_sentences # Chunk重叠句子数
self.sim_threshold = similarity_threshold # 语义边界相似度阈值
# 4. 优化项:预计算空Token长度+建立Token缓存
self._empty_token_len = len(self.tokenizer.encode("", add_special_tokens=False))
self._token_cache = {} # Token计数缓存:{文本: Token数}
# 5. 耗时统计初始化
self._timing_data = self._init_timing_data()
@classmethod
def _init_timing_data(cls) -> Dict:
"""初始化耗时统计字典"""
return {
"total_time": 0.0, # 总耗时
"parse_time": 0.0, # 解析+清洗耗时
"embedding_time": 0.0, # 语义嵌入耗时
"chunking_time": 0.0, # 分块逻辑耗时
"chunk_times": [] # 单Chunk耗时列表
}
@classmethod
def _get_spacy_model(cls, model_name: str) -> spacy.Language:
"""获取spaCy模型(缓存+组件裁剪)"""
if model_name not in cls._nlp_cache:
# 禁用NER/Tagger等冗余组件,仅保留分句能力
disable_components = ["ner", "tagger"]
if "zh" not in model_name:
disable_components.append("attribute_ruler")
cls._nlp_cache[model_name] = spacy.load(model_name, disable=disable_components)
return cls._nlp_cache[model_name]
@classmethod
def _get_embed_model(cls, model_name: str) -> SentenceTransformer:
"""获取语义嵌入模型(缓存)"""
if model_name not in cls._model_cache:
cls._model_cache[model_name] = SentenceTransformer(model_name)
return cls._model_cache[model_name]
@classmethod
def _get_tokenizer(cls, model_name: str) -> AutoTokenizer:
"""获取Token计算器(缓存)"""
if model_name not in cls._tokenizer_cache:
cls._tokenizer_cache[model_name] = AutoTokenizer.from_pretrained(model_name)
return cls._tokenizer_cache[model_name]
    # ------------------------------ Core module 1: OCR text preprocessing ------------------------------
def _validate_file(self, file_path: str) -> Path:
"""文件有效性验证"""
file_path = Path(file_path)
if not file_path.exists():
raise FileNotFoundError(f"文件不存在: {file_path}")
if file_path.suffix.lower() != ".txt":
raise ValueError(f"仅支持TXT格式,当前格式: {file_path.suffix}")
return file_path
def _clean_ocr_text(self, text: str) -> str:
"""OCR文本专用清洗逻辑(优化版:批量清理控制字符)"""
# 1. 正则批量匹配:所有ASCII控制字符(0x00-0x1F)+ Unicode替换字符(�)
# [\x00-\x1F] 匹配所有ASCII控制字符,\ufffd 匹配�
text = re.sub(r'[\x00-\x1F\ufffd]', '', text)
# 2. 保留原有的格式化逻辑
text = re.sub(r'\n+', '\n', text)
text = re.sub(r' +', ' ', text)
text = re.sub(r'^ +| +$', '', text, flags=re.MULTILINE)
# 3. 修复中英文混排空格
text = re.sub(r'([\u4e00-\u9fff])([a-zA-Z0-9])', r'\1 \2', text)
text = re.sub(r'([a-zA-Z0-9])([\u4e00-\u9fff])', r'\1 \2', text)
return text.strip()
def _parse_txt_document(self, file_path: str) -> List[Dict]:
"""解析OCR TXT文件(按空行分段落)"""
start_time = time.perf_counter()
        # 1. Validate the file and read it (with multi-encoding fallback)
file_path = self._validate_file(file_path)
try:
with open(file_path, 'r', encoding='utf-8') as f:
raw_text = f.read()
except UnicodeDecodeError:
with open(file_path, 'r', encoding='gbk', errors='ignore') as f:
raw_text = f.read()
        # 2. Clean the OCR text
        clean_text = self._clean_ocr_text(raw_text)
        if not clean_text:
            raise ValueError("The TXT file has no usable content after cleaning")
        # 3. Split into paragraphs on blank lines (the typical OCR text layout)
        paragraphs = re.split(r'\n{2,}', clean_text)
        # 4. Filter out invalid paragraphs and build structured records
structured_elements = []
for idx, para_text in enumerate(paragraphs):
para_text = para_text.strip()
            if len(para_text) < 5:  # drop paragraphs too short to be useful
continue
structured_elements.append({
"type": "Paragraph",
"text": para_text,
"paragraph_id": idx
})
        # 5. Record parsing time
self._timing_data["parse_time"] = time.perf_counter() - start_time
return structured_elements
    # ------------------------------ Core module 2: semantic embeddings and boundary detection ------------------------------
def _split_sentences(self, text: str) -> List[str]:
"""句子拆分(语义分块的基础单元)"""
doc = self.nlp(text)
# 过滤过短句子(至少3字符),保证语义单元完整性
sentences = [sent.text.strip() for sent in doc.sents if len(sent.text.strip()) >= 3]
return sentences
def _count_tokens(self, text: str) -> int:
"""Token计数(带缓存优化)"""
if text in self._token_cache:
return self._token_cache[text]
token_count = len(self.tokenizer.encode(text, add_special_tokens=False))
        self._token_cache[text] = token_count  # cache the result
return token_count
def _calculate_batch_embeddings(self, sentences_list: List[List[str]]) -> List[np.ndarray]:
"""批量计算句子语义嵌入(核心优化:减少模型调用次数)"""
start_time = time.perf_counter()
# 1. 扁平化句子列表,记录段落偏移
all_sentences = []
offsets = [0]
for sentences in sentences_list:
all_sentences.extend(sentences)
offsets.append(offsets[-1] + len(sentences))
        # 2. Generate the semantic embedding vectors in a single batch
embeddings = []
if all_sentences:
embeddings = self.embed_model.encode(
all_sentences,
convert_to_tensor=True,
device=self.device,
show_progress_bar=False
).cpu().numpy()
        # 3. Split the embeddings back into the original paragraph structure
result = []
for i in range(len(offsets)-1):
start, end = offsets[i], offsets[i+1]
result.append(embeddings[start:end] if all_sentences else np.array([]))
        # 4. Record embedding time
self._timing_data["embedding_time"] = time.perf_counter() - start_time
return result
def _find_topic_boundaries(self, sentences: List[str], embeddings: np.ndarray) -> List[int]:
"""语义边界检测(核心逻辑:基于相似度识别主题切换)"""
if len(sentences) <= 1:
return [] # 单句无边界
# 1. 计算相邻句子的余弦相似度
similarities = []
for i in range(len(embeddings) - 1):
sim = util.cos_sim(
torch.tensor(embeddings[i]).to(self.device),
torch.tensor(embeddings[i+1]).to(self.device)
).item()
similarities.append(sim)
        # 2. Similarity below the threshold -> treat as a topic boundary
        boundaries = [i+1 for i, sim in enumerate(similarities) if sim < self.sim_threshold]
        # 3. Add forced boundaries so a single chunk cannot exceed the token limit
token_counts = [self._count_tokens(s) for s in sentences]
cumulative_tokens = 0
current_start = 0
for i, token_count in enumerate(token_counts):
cumulative_tokens += token_count
            # Force a boundary once we pass 80% of the target length with more than 3 sentences
if cumulative_tokens > self.target_tokens * 0.8 and (i - current_start) > 3:
if i+1 not in boundaries:
boundaries.append(i+1)
cumulative_tokens = 0
current_start = i+1
        # 4. Deduplicate and sort so the boundaries are ordered
boundaries = sorted(list(set(boundaries)))
return boundaries
    # ------------------------------ Core module 3: chunk packing and optimization ------------------------------
def _finalize_chunk(self, current_chunk: Dict, chunks: List[Dict], chunk_start_time: float):
"""Chunk最终化(过滤短Chunk+记录单Chunk耗时)"""
# 1. 计算单Chunk耗时
chunk_elapsed = time.perf_counter() - chunk_start_time
self._timing_data["chunk_times"].append(chunk_elapsed)
        # 2. Join the sentences and count tokens
content_text = " ".join(current_chunk["sentences"])
actual_tokens = self._count_tokens(content_text)
        # 3. Handle short chunks: merge into the previous chunk (avoids fragmentation)
if actual_tokens < self.min_tokens and chunks:
prev_chunk = chunks[-1]
merged_text = prev_chunk["text"] + " " + content_text
merged_tokens = self._count_tokens(merged_text)
            # Allow the merge if the result stays within 1.5x of the target length
            if merged_tokens <= self.target_tokens * 1.5:
                chunks[-1]["text"] = merged_text
                chunks[-1]["tokens"] = merged_tokens
                chunks[-1]["time_seconds"] = round(chunks[-1]["time_seconds"] + chunk_elapsed, 3)
                chunks[-1]["time_ms"] = round(chunks[-1]["time_seconds"] * 1000, 2)
                return
        # 4. Emit the final chunk (including timing information)
chunks.append({
"text": content_text,
"tokens": actual_tokens,
"time_seconds": round(chunk_elapsed, 3),
"time_ms": round(chunk_elapsed * 1000, 2)
})
def _add_overlap(self, chunks: List[Dict]) -> List[Dict]:
"""Chunk重叠处理(解决语义断裂问题)"""
if len(chunks) <= 1 or self.overlap_sent <= 0:
return chunks
# 1. 预拆分所有Chunk的句子(仅执行一次)
chunk_sentences = []
for chunk in chunks:
chunk_sentences.append(self._split_sentences(chunk["text"]))
# 2. 为每个Chunk添加前一个Chunk的末尾N句重叠
for i in range(1, len(chunks)):
prev_sent = chunk_sentences[i-1]
current_chunk = chunks[i]
            # Take the last N sentences of the previous chunk as the overlap
overlap_count = min(self.overlap_sent, len(prev_sent))
overlap_sentences = prev_sent[-overlap_count:] if overlap_count > 0 else []
if not overlap_sentences:
continue
            # Prepend the overlap text to the current chunk
            overlap_text = " ".join(overlap_sentences)
            new_text = f"{overlap_text} {current_chunk['text']}"
            # Update the chunk text and token count
chunks[i]["text"] = new_text
chunks[i]["tokens"] = self._count_tokens(new_text)
return chunks
def _pack_chunks(self, structured_elements: List[Dict]) -> List[Dict]:
"""核心分块逻辑:语义边界+Token约束打包Chunk"""
start_time = time.perf_counter()
chunks = []
current_chunk = {"sentences": [], "total_tokens": 0}
chunk_start_time = time.perf_counter()
        # 1. Preprocessing: split every paragraph into sentences (done only once)
element_sentences = []
for elem in structured_elements:
element_sentences.append(self._split_sentences(elem["text"]))
        # 2. Compute semantic embeddings for all sentences in one batch
all_embeddings = self._calculate_batch_embeddings(element_sentences)
        # 3. Process paragraph by paragraph: split at semantic boundaries and pack under the token budget
for idx, elem in enumerate(structured_elements):
sentences = element_sentences[idx]
if not sentences:
continue
            # Embeddings for the current paragraph
            embeddings = all_embeddings[idx]
            # Detect semantic boundaries (the paragraph end is always a forced boundary)
            topic_boundaries = self._find_topic_boundaries(sentences, embeddings)
            topic_boundaries.append(len(sentences))
            topic_boundaries = sorted(list(set(topic_boundaries)))
            # Split the sentences into groups at the boundaries and pack them into chunks
start_idx = 0
for boundary in topic_boundaries:
sentence_group = sentences[start_idx:boundary]
if not sentence_group:
start_idx = boundary
continue
                # Token count of the current sentence group
group_text = " ".join(sentence_group)
group_tokens = self._count_tokens(group_text)
                # Within the target budget -> add to the current chunk
if current_chunk["total_tokens"] + group_tokens <= self.target_tokens:
current_chunk["sentences"].extend(sentence_group)
current_chunk["total_tokens"] += group_tokens
else:
                    # Over budget -> finalize the current chunk and start a new one
self._finalize_chunk(current_chunk, chunks, chunk_start_time)
current_chunk = {
"sentences": sentence_group,
"total_tokens": group_tokens
}
chunk_start_time = time.perf_counter()
start_idx = boundary
        # Finalize the last chunk
if current_chunk["sentences"]:
self._finalize_chunk(current_chunk, chunks, chunk_start_time)
        # Add overlap between chunks
chunks = self._add_overlap(chunks)
        # Record chunking-logic time
self._timing_data["chunking_time"] = time.perf_counter() - start_time
return chunks
    # ------------------------------ Core module 4: public interface ------------------------------
def get_timing_summary(self) -> Dict:
"""获取耗时统计汇总"""
chunk_times = self._timing_data["chunk_times"]
return {
"total_time_seconds": round(self._timing_data["total_time"], 3),
"total_time_ms": round(self._timing_data["total_time"] * 1000, 1),
"parse_time_seconds": round(self._timing_data["parse_time"], 3),
"embedding_time_seconds": round(self._timing_data["embedding_time"], 3),
"chunking_time_seconds": round(self._timing_data["chunking_time"], 3),
"avg_chunk_time_seconds": round(np.mean(chunk_times) if chunk_times else 0, 3),
"avg_chunk_time_ms": round(np.mean(chunk_times) * 1000 if chunk_times else 0, 1),
"max_chunk_time_ms": round(np.max(chunk_times) * 1000 if chunk_times else 0, 1),
"min_chunk_time_ms": round(np.min(chunk_times) * 1000 if chunk_times else 0, 1),
"chunk_count": len(chunk_times)
}
def chunk_document(self, file_path: str) -> List[Dict]:
"""主接口:OCR TXT文件语义分块"""
# 重置耗时统计
self._timing_data = self._init_timing_data()
total_start_time = time.perf_counter()
try:
            # 1. Parse the OCR text
structured_elements = self._parse_txt_document(file_path)
            # 2. Core chunking logic
chunks = self._pack_chunks(structured_elements)
            # 3. Record total elapsed time
self._timing_data["total_time"] = time.perf_counter() - total_start_time
return chunks
except Exception as e:
self._timing_data["total_time"] = time.perf_counter() - total_start_time
print(f"分块失败: {str(e)}")
return []
# ------------------------------ Test example ------------------------------
if __name__ == "__main__":
    # Initialize the chunker
chunker = OCRTextSemanticChunker(
model_name="sentence-transformers/all-MiniLM-L6-v2",
spacy_model="zh_core_web_sm",
target_chunk_tokens=1024,
min_chunk_tokens=300,
overlap_sentences=2,
similarity_threshold=0.7,
device="auto" # 自动选择CPU/GPU
)
    # Run chunking (replace with the path to a real OCR TXT file)
file_path = "ocr_output.txt"
chunks = chunker.chunk_document(file_path)
    # Fetch the timing summary
timing_summary = chunker.get_timing_summary()
    # Print the results
    print("=== Semantic chunking summary ===")
    print(f"File processed: {file_path}")
    print(f"Chunks generated: {timing_summary['chunk_count']}")
    print(f"Total time: {timing_summary['total_time_seconds']} s")
    print(f"Average time per chunk: {timing_summary['avg_chunk_time_ms']} ms")
    # Preview the first 3 chunks
for i, chunk in enumerate(chunks[:3], 1):
print(f"\n--- Chunk {i}({chunk['tokens']} Token,耗时{chunk['time_ms']}ms)---")
print(f"内容预览: {chunk['text'][:300]}...")
```
2.2 Core Semantic Chunking Workflow
```mermaid
flowchart TD
    A["Input OCR TXT file"] --> B["File validation + multi-encoding read"]
    B --> C["OCR text cleaning: strip garbled chars, normalize spacing"]
    C --> D["Split into paragraphs on blank lines, filter invalid ones"]
    D --> E["Sentence splitting per paragraph: spaCy"]
    E --> F["Batch sentence embeddings: SentenceTransformer"]
    F --> G["Semantic boundary detection: cosine similarity below 0.7 marks a topic switch"]
    G --> H["Token-length constraint: forced boundaries near the 1024-token target"]
    H --> I["Pack chunks: merge chunks shorter than 300 tokens"]
    I --> J["Chunk overlap: 2-sentence overlap window"]
    J --> K["Output chunks with text, token count and timing"]
```
2.3 Key Techniques in Detail
2.3.1 Semantic Boundary Detection
Semantic boundary detection determines chunking quality. The logic (a sketch follows this list):
- Generate a semantic embedding for every sentence in each paragraph (all-MiniLM-L6-v2);
- Compute the cosine similarity of adjacent sentences; a similarity below 0.7 is treated as a topic switch (semantic boundary);
- Combine this with the token-length constraint: add a forced boundary wherever the cumulative token count exceeds 80% of the target, so no chunk grows too large;
- Deduplicate and sort the boundaries, then use them to split the sentences into groups.
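A minimal, self-contained sketch of the similarity rule. The helper name `adjacent_boundaries`, the sample sentences and the reuse of the 0.7 threshold are illustrative; the implementation in section 2.1 remains authoritative.

```python
from sentence_transformers import SentenceTransformer, util

def adjacent_boundaries(sentences, threshold=0.7,
                        model_name="sentence-transformers/all-MiniLM-L6-v2"):
    model = SentenceTransformer(model_name)
    emb = model.encode(sentences, convert_to_tensor=True, show_progress_bar=False)
    # Cosine similarity between each sentence and its successor
    sims = util.cos_sim(emb[:-1], emb[1:]).diagonal()
    # A similarity drop below the threshold marks the start of a new topic
    return [i + 1 for i, s in enumerate(sims.tolist()) if s < threshold]

print(adjacent_boundaries([
    "Revenue grew 12% this quarter.",
    "Margins remained stable.",
    "Next we describe the new product architecture.",
    "The system is deployed as microservices.",
]))
```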
2.3.2 Chunk Quality Control
- Minimum-length constraint: chunks under 300 tokens are not emitted on their own but merged into the previous chunk (avoids fragmentation);
- Elastic maximum length: a merged chunk may grow up to 1.5x the target length (1536 tokens), balancing completeness against the length constraint (see the sketch after this list);
- Overlap window: every chunk except the first is prefixed with the last 2 sentences of the previous chunk to mitigate semantic breaks;
- Token-count cache: token counts for identical text are cached rather than recomputed, which saves time.
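A minimal sketch of the merge rule under the stated 300/1024-token limits. `merge_short_chunks` and the whitespace token counter in the example are hypothetical stand-ins for the `_finalize_chunk` logic above.

```python
def merge_short_chunks(chunks, count_tokens, min_tokens=300, target_tokens=1024):
    """chunks: list of dicts with a "text" key; returns a new list with short chunks folded back."""
    merged = []
    for chunk in chunks:
        tokens = count_tokens(chunk["text"])
        if merged and tokens < min_tokens:
            candidate = merged[-1]["text"] + " " + chunk["text"]
            # Only merge if the result stays within 1.5x of the target length
            if count_tokens(candidate) <= target_tokens * 1.5:
                merged[-1]["text"] = candidate
                merged[-1]["tokens"] = count_tokens(candidate)
                continue
        merged.append({"text": chunk["text"], "tokens": tokens})
    return merged

# Example with a trivial whitespace token counter
print(merge_short_chunks([{"text": "a b"}, {"text": "c"}], count_tokens=lambda t: len(t.split())))
```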
2.3.3 OCR Text Adaptation
Several adjustments target the quirks of OCR output (a condensed sketch follows the list):
- Multi-encoding support (UTF-8 with a GBK fallback) handles the encoding inconsistencies of OCR text;
- Common OCR garbage characters are removed (e.g. the U+FFFD replacement character and control characters);
- Spacing between mixed Chinese and English text is repaired, which improves sentence splitting;
- Paragraphs are split on blank lines, matching the typical layout of OCR text.
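A condensed, standalone restatement of the cleanup steps. The helper `read_ocr_txt` is illustrative only; the regex patterns mirror `_clean_ocr_text` in section 2.1.

```python
import re
from pathlib import Path

def read_ocr_txt(path: str) -> str:
    raw = Path(path).read_bytes()
    try:
        text = raw.decode("utf-8")                      # primary encoding
    except UnicodeDecodeError:
        text = raw.decode("gbk", errors="ignore")       # fallback for legacy OCR exports
    text = re.sub(r"[\x00-\x09\x0B-\x1F\ufffd]", "", text)             # strip control chars and U+FFFD, keep \n
    text = re.sub(r"([\u4e00-\u9fff])([A-Za-z0-9])", r"\1 \2", text)   # CJK -> Latin spacing
    text = re.sub(r"([A-Za-z0-9])([\u4e00-\u9fff])", r"\1 \2", text)   # Latin -> CJK spacing
    return text.strip()

print(read_ocr_txt("ocr_output.txt")[:200])  # preview the cleaned text
```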
3. Performance Analysis and Optimization Recommendations
3.1 Test Result Analysis
| Metric | Value | Interpretation |
|---|---|---|
| Total time | 29.109 s | Overall runtime is high and needs optimization |
| Parsing + cleaning time | 0.03 s | Preprocessing is lightweight; not a bottleneck |
| Embedding time | 8.56 s | 29.4% of the total; GPU acceleration could roughly halve it |
| Chunking-logic time | 29.079 s | 99.9% of the total; the main performance bottleneck |
| Chunk count | 165 | Chunk granularity matches the 1024-token target |
| Average time per chunk | 117.2 ms | Reasonable on average |
| Maximum time per chunk | 19040.7 ms | A few extreme outliers need to be tracked down |

Note that the embedding time is measured inside `_pack_chunks`, so it is nested within (not additive to) the chunking-logic time.
3.2 Performance Bottlenecks and Optimization Plan
3.2.1 Main bottleneck: chunking logic
Causes: repeated sentence splitting, uncached token counting, and sharply rising boundary-detection cost for very long paragraphs.
Optimizations (a sketch of the pre-splitting idea follows the list):
- Cache sentence-split results: split all paragraph sentences once inside `_pack_chunks` and reuse them, instead of re-splitting repeatedly;
- Batch boundary detection: for very long paragraphs (more than 50 sentences), pre-split into 20-sentence windows before running semantic boundary detection;
- Parallel computation: run embedding and boundary detection for multiple paragraphs in parallel (e.g. with a thread pool).
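A sketch of the pre-splitting idea from the second bullet. `window_long_paragraphs`, the 50-sentence threshold and the 20-sentence window are illustrative values taken from the bullet, not a tested configuration.

```python
from typing import List

def window_long_paragraphs(paragraph_sentences: List[List[str]],
                           max_sentences: int = 50,
                           window: int = 20) -> List[List[str]]:
    windowed = []
    for sentences in paragraph_sentences:
        if len(sentences) <= max_sentences:
            windowed.append(sentences)
            continue
        # Cut an over-long paragraph into fixed-size windows before boundary detection
        for start in range(0, len(sentences), window):
            windowed.append(sentences[start:start + window])
    return windowed

# 120 dummy sentences -> six 20-sentence windows
print([len(g) for g in window_long_paragraphs([[f"句子{i}。" for i in range(120)]])])
```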
3.2.2 Secondary bottleneck: embedding computation
Optimizations (see the sketch after this list):
- GPU acceleration: make sure PyTorch actually runs on the GPU; embedding time could drop from 8.56 s to under 4 s;
- Lighter model: switching to the lighter paraphrase-multilingual-MiniLM-L3-v2 model speeds up embedding by roughly 30%;
- Batch-size control: encode in batches (batch size 64) to keep memory usage stable.
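A sketch of the batching and device knobs on `SentenceTransformer.encode`. The batch size of 64 comes from the bullet above; the model name and sentence list are placeholders, and the speed-up figures are not verified here.

```python
import torch
from sentence_transformers import SentenceTransformer

device = "cuda" if torch.cuda.is_available() else "cpu"   # fall back to CPU if no GPU is visible
model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2", device=device)

sentences = ["第一句。", "第二句。"] * 1000                 # placeholder workload
embeddings = model.encode(
    sentences,
    batch_size=64,              # bounded batches keep peak memory stable
    convert_to_numpy=True,
    show_progress_bar=False,
)
print(embeddings.shape)         # (2000, 384) for this model
```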
4. Chunking Quality Evaluation
4.1 Strengths
- Semantic coherence: similarity-based boundary detection keeps each chunk on a single topic;
- Length fit: 1024-token chunks fit the context windows of mainstream LLMs (e.g. GPT-3.5);
- Robustness: dedicated handling of garbled and badly formatted OCR text makes chunking stable;
- Observability: timing statistics at every level make tuning and troubleshooting easier.
4.2 Applicable Scenarios
- Chunking TXT files produced by OCR from PDFs or images;
- Semantic chunking of unstructured plain-text documents;
- LLM-oriented knowledge-base construction, text retrieval, and similar use cases.
5. Summary
The key value of this OCR TXT semantic chunking solution:
- Technical completeness: it covers the full pipeline of text preprocessing, semantic embedding, boundary detection and chunk optimization, adapted to the quirks of OCR text;
- Controllable quality: the combination of semantic similarity and token-length constraints keeps chunks both coherent and appropriately sized;
- Engineering practicality: built-in model caching, timing statistics and exception handling make it ready for production use;
- Room for optimization: the current bottleneck is the chunking logic; caching, GPU acceleration and parallelization should cut total runtime by more than 50%.
This solution can be applied directly to semantic chunking of OCR text, and can be extended to chunk structured documents such as PDF and Word files.