数据质量检查:保障 AI 训练数据的可靠性

前言
“垃圾进,垃圾出”(Garbage In, Garbage Out)。数据质量直接决定了模型性能,因此数据质量检查是构建高质量 AI 系统的关键环节。
我在多个项目中实践过数据质量检查,今天分享一些方法和经验。
基础质量检查
文本基本检查
python
import re
from typing import List, Dict
from collections import Counter
class TextQualityChecker:
    """Run basic, rule-based quality checks on a single piece of text.

    Every ``check_*`` method returns a dict with three keys: ``check`` (the
    identifier of the check), ``passed`` (bool) and ``message`` (a
    human-readable explanation). ``None`` is treated like an empty string by
    every check, so ``check_all`` never raises on a missing text.
    """

    # Matches any character outside the accepted set: ASCII, CJK unified
    # ideographs (U+4E00-U+9FFF), whitespace, and common Chinese punctuation.
    # The full-width punctuation is spelled with escapes so that the pattern
    # keeps working even if the source file is width-normalised.
    _DISALLOWED_CHAR_RE = re.compile(
        r"[^\x00-\x7F\u4e00-\u9fff\s"
        r"\uff0c\u3002\uff01\uff1f\u3001\uff1b\uff1a"
        r"\u201c\u201d\u2018\u2019\uff08\uff09\u3010\u3011]"
    )

    def __init__(self):
        self.reports = []

    @staticmethod
    def _result(check: str, passed: bool, message: str) -> Dict:
        """Build the result dict shared by all checks."""
        return {"check": check, "passed": passed, "message": message}

    def check_empty(self, text: str) -> Dict:
        """Fail when ``text`` is ``None``, empty, or whitespace only."""
        if not text or not text.strip():
            return self._result("empty_text", False, "文本为空")
        return self._result("empty_text", True, "文本非空")

    def check_length(self, text: str, min_len: int = 10, max_len: int = 10000) -> Dict:
        """Fail when the character count is outside ``[min_len, max_len]``.

        A ``None`` text counts as length 0 instead of raising ``TypeError``.
        """
        length = len(text or "")
        if length < min_len:
            return self._result("length", False, f"文本过短 ({length} < {min_len})")
        if length > max_len:
            return self._result("length", False, f"文本过长 ({length} > {max_len})")
        return self._result("length", True, f"文本长度正常 ({length})")

    def check_charset(self, text: str) -> Dict:
        """Fail when ``text`` contains characters outside the accepted set.

        The message lists the distinct characters among the first five
        offending matches.
        """
        invalid_chars = self._DISALLOWED_CHAR_RE.findall(text or "")
        if invalid_chars:
            return self._result(
                "charset", False, f"包含特殊字符: {set(invalid_chars[:5])}"
            )
        return self._result("charset", True, "字符集正常")

    def check_repetition(self, text: str, threshold: float = 0.5) -> Dict:
        """Fail when a single character makes up more than ``threshold`` of the text.

        The measurement is per character, not per word. Texts shorter than
        ten characters always pass because the ratio is not meaningful there.
        """
        chars = text or ""
        total = len(chars)
        if total < 10:
            return self._result("repetition", True, "文本过短")
        # Share of the text taken by its single most frequent character.
        top_count = Counter(chars).most_common(1)[0][1]
        repetition_ratio = top_count / total
        if repetition_ratio > threshold:
            return self._result(
                "repetition", False, f"高重复率: {repetition_ratio:.2f}"
            )
        return self._result("repetition", True, "重复率正常")

    def check_all(self, text: str) -> List[Dict]:
        """Run every check with its default settings and return the results in order."""
        checks = (
            self.check_empty,
            self.check_length,
            self.check_charset,
            self.check_repetition,
        )
        return [check(text) for check in checks]
语义质量检查
语义一致性
python
from sentence_transformers import SentenceTransformer
import numpy as np
class SemanticQualityChecker:
    """Embedding-based checks over a list of texts.

    Texts are embedded with ``self.model.encode``; the result is assumed to
    be a 2-D array-like with one row per text (TODO confirm for non-default
    models).
    """

    def __init__(self, model_name: str = "shibing624/text2vec-base-chinese"):
        self.model = SentenceTransformer(model_name)

    def check_consistency(self, text_list: List[str], threshold: float = 0.7) -> Dict:
        """Fail when the mean pairwise cosine similarity is below ``threshold``.

        Fewer than two texts always pass. A zero-norm embedding is given a
        similarity of 0 to every other text, rather than producing NaN and
        letting the check pass silently.
        """
        if len(text_list) < 2:
            return {"check": "consistency", "passed": True, "message": "样本不足"}

        vectors = np.asarray(self.model.encode(text_list), dtype=float)
        norms = np.linalg.norm(vectors, axis=1, keepdims=True)
        # Replace zero norms with 1 so that zero rows stay zero after scaling.
        unit = vectors / np.where(norms > 0, norms, 1.0)
        similarity = unit @ unit.T
        # Each unordered pair appears once in the strict upper triangle.
        upper = np.triu_indices(len(unit), k=1)
        avg_similarity = float(similarity[upper].mean())

        if avg_similarity < threshold:
            return {
                "check": "consistency",
                "passed": False,
                "message": f"一致性较低: {avg_similarity:.2f}",
            }
        return {
            "check": "consistency",
            "passed": True,
            "message": f"一致性正常: {avg_similarity:.2f}",
        }

    def check_outlier(self, text_list: List[str], threshold: float = 2.0) -> List[int]:
        """Return positions of texts whose embedding lies far from the centroid.

        A text is an outlier when the z-score of its Euclidean distance to
        the centroid exceeds ``threshold``. Fewer than three texts, or texts
        that are all equally distant, yield an empty list.
        """
        if len(text_list) < 3:
            return []

        vectors = np.asarray(self.model.encode(text_list), dtype=float)
        centroid = vectors.mean(axis=0)
        distances = np.linalg.norm(vectors - centroid, axis=1)

        spread = distances.std()
        if spread == 0:
            return []

        z_scores = (distances - distances.mean()) / spread
        return np.flatnonzero(z_scores > threshold).tolist()
语言质量
python
import jieba
class LanguageQualityChecker:
    """Surface-level language checks: vocabulary variety and sentence length.

    Each check returns a dict with ``check``, ``passed`` and ``message`` keys.
    """

    # Sentence terminators: the ideographic full stop plus both the
    # full-width and the ASCII forms of the exclamation and question marks.
    _SENTENCE_END_RE = re.compile(r"[\u3002\uff01\uff1f!?]")

    def __init__(self):
        # NOTE: stored on the instance but not consulted by the checks below.
        self.stopwords = set(["的", "了", "是", "在", "我", "有", "和", "就"])

    def check_vocabulary_richness(self, text: str, min_richness: float = 0.3) -> Dict:
        """Fail when the distinct-token ratio is below ``min_richness``.

        The text is tokenised with ``jieba.lcut``; richness is the number of
        distinct tokens divided by the total number of tokens. A ``None`` text
        is treated as empty.
        """
        tokens = jieba.lcut(text or "")
        if not tokens:
            return {"check": "vocabulary", "passed": False, "message": "无词汇"}

        richness = len(set(tokens)) / len(tokens)
        if richness < min_richness:
            return {
                "check": "vocabulary",
                "passed": False,
                "message": f"词汇丰富度低: {richness:.2f}",
            }
        return {
            "check": "vocabulary",
            "passed": True,
            "message": f"词汇丰富度正常: {richness:.2f}",
        }

    def check_sentence_structure(self, text: str, min_avg_length: float = 5) -> Dict:
        """Fail when the mean sentence length is below ``min_avg_length`` characters.

        Sentences are the non-blank pieces left after splitting on sentence
        terminators; a text with no such piece fails. A ``None`` text is
        treated as empty.
        """
        pieces = (part.strip() for part in self._SENTENCE_END_RE.split(text or ""))
        sentences = [piece for piece in pieces if piece]
        if not sentences:
            return {"check": "sentence", "passed": False, "message": "无完整句子"}

        avg_length = sum(len(sentence) for sentence in sentences) / len(sentences)
        if avg_length < min_avg_length:
            return {
                "check": "sentence",
                "passed": False,
                "message": f"句子过短: {avg_length:.1f} 字",
            }
        return {
            "check": "sentence",
            "passed": True,
            "message": f"句子结构正常: {avg_length:.1f} 字",
        }
去重与清洗
精确与模糊去重
python
import hashlib
from typing import List, Dict
from dataclasses import dataclass
@dataclass
class DuplicateCheckResult:
    """Outcome of a de-duplication pass."""

    # Records that were kept.
    unique_data: List[Dict]
    # Number of input records that were not kept.
    duplicates_count: int
    # Groups of near-duplicate entries reported by the de-duplication routine;
    # exact de-duplication leaves this empty.
    duplicate_groups: List[List[int]]
class DuplicateChecker:
    """Detect exact and near-duplicate records by their text field."""

    # Width in bits of the SimHash fingerprints used by ``fuzzy_dedup``.
    _FINGERPRINT_BITS = 64

    def __init__(self):
        # MD5 digests of every text accepted by ``exact_dedup`` so far. The set
        # lives on the instance, so later calls also de-duplicate against the
        # texts seen in earlier calls.
        self.seen = set()

    def exact_dedup(self, data: List[Dict], text_key: str = "text") -> DuplicateCheckResult:
        """Drop records whose text is byte-for-byte identical to one already seen.

        Args:
            data: Records to filter; each must contain ``text_key``.
            text_key: Key under which each record stores its text.

        Returns:
            A result holding the first occurrence of every text (in input
            order) and the number of dropped records; ``duplicate_groups`` is
            always empty.
        """
        unique = []
        duplicates = 0
        for item in data:
            digest = hashlib.md5(item[text_key].encode()).hexdigest()
            if digest in self.seen:
                duplicates += 1
                continue
            self.seen.add(digest)
            unique.append(item)
        return DuplicateCheckResult(
            unique_data=unique,
            duplicates_count=duplicates,
            duplicate_groups=[],
        )

    def fuzzy_dedup(self, data: List[Dict], text_key: str = "text",
                    threshold: float = 0.95) -> DuplicateCheckResult:
        """Drop near-duplicate records using SimHash fingerprints.

        Args:
            data: Records to filter; each must contain ``text_key``.
            text_key: Key under which each record stores its text.
            threshold: Minimum fraction of matching fingerprint bits for two
                texts to count as near-duplicates. It is converted to the
                maximum Hamming distance ``round((1 - threshold) * 64)``.

        Returns:
            A result whose ``unique_data`` keeps, in input order, every record
            that has no earlier near-duplicate, and whose ``duplicate_groups``
            lists the positions (into ``data``) of each group of two or more
            near-duplicates, kept record first.
        """
        from simhash import Simhash, SimhashIndex

        if not data:
            return DuplicateCheckResult(
                unique_data=[], duplicates_count=0, duplicate_groups=[]
            )

        bits = self._FINGERPRINT_BITS
        tolerance = min(bits - 1, max(0, int(round((1.0 - threshold) * bits))))

        fingerprints = [Simhash(item[text_key], f=bits) for item in data]
        # SimhashIndex expects (obj_id, simhash) pairs and reports matches by
        # obj_id as strings, so positions are round-tripped through str/int.
        index = SimhashIndex(
            [(str(pos), fp) for pos, fp in enumerate(fingerprints)],
            f=bits,
            k=tolerance,
        )

        unique_data = []
        duplicate_groups = []
        processed = set()
        for pos, fingerprint in enumerate(fingerprints):
            if pos in processed:
                continue
            # Near-duplication is not transitive: ignore matches that already
            # belong to an earlier group so no position is listed twice.
            members = {int(obj_id) for obj_id in index.get_near_dups(fingerprint)}
            members -= processed
            members.add(pos)
            processed.update(members)

            # Every earlier position is already processed, so ``pos`` is the
            # first occurrence within its group and is the one to keep.
            unique_data.append(data[pos])
            if len(members) > 1:
                duplicate_groups.append(sorted(members))

        return DuplicateCheckResult(
            unique_data=unique_data,
            duplicates_count=len(data) - len(unique_data),
            duplicate_groups=duplicate_groups,
        )
完整质量检查流水线
python
class DataQualityPipeline:
    """End-to-end quality pipeline: de-duplicate, check each record, sample semantics."""

    def __init__(self):
        self.text_checker = TextQualityChecker()
        # Built with its default model name.
        self.semantic_checker = SemanticQualityChecker()
        self.language_checker = LanguageQualityChecker()
        self.duplicate_checker = DuplicateChecker()

    def run_pipeline(self, data: List[Dict]) -> Dict:
        """Run all checks and return a summary report.

        The report has the keys ``total`` (input size), ``passed``, ``failed``,
        ``failures`` (one entry per failing record holding its ``index`` and
        the full list of ``checks``), ``duplicates`` and ``statistics``.
        Failure indices refer to positions in the de-duplicated data, not in
        the original input.
        """
        report, _ = self._inspect(data)
        return report

    def clean_data(self, data: List[Dict]) -> List[Dict]:
        """Return the de-duplicated records that passed every per-record check.

        Failure indices are applied to the de-duplicated list they were
        computed on, so they always remove the record that actually failed.
        """
        report, unique_items = self._inspect(data)
        rejected = {failure["index"] for failure in report["failures"]}
        return [
            item for position, item in enumerate(unique_items)
            if position not in rejected
        ]

    def _inspect(self, data: List[Dict]) -> tuple:
        """Run the pipeline once and return ``(report, de-duplicated records)``."""
        report = {
            "total": len(data),
            "passed": 0,
            "failed": 0,
            "failures": [],
            "duplicates": 0,
            "statistics": {},
        }

        # 1. Exact de-duplication.
        dedup_result = self.duplicate_checker.exact_dedup(data)
        report["duplicates"] = dedup_result.duplicates_count
        unique_items = dedup_result.unique_data

        # 2. Per-record text and language checks.
        for position, item in enumerate(unique_items):
            text = item.get("text", "")
            all_checks = self.text_checker.check_all(text) + [
                self.language_checker.check_vocabulary_richness(text),
                self.language_checker.check_sentence_structure(text),
            ]
            if all(check["passed"] for check in all_checks):
                report["passed"] += 1
            else:
                report["failed"] += 1
                report["failures"].append({"index": position, "checks": all_checks})

        # 3. Semantic outlier count on the first 100 records at most; skipped
        #    for ten records or fewer.
        if len(unique_items) > 10:
            sample = [record["text"] for record in unique_items[:100]]
            outliers = self.semantic_checker.check_outlier(sample)
            report["statistics"]["semantic_outliers"] = len(outliers)

        return report, unique_items
总结
数据质量检查要点:
- 基础检查:空值、长度、字符集
- 语义检查:一致性、离群点
- 语言检查:词汇、句子结构
- 去重:精确去重 + 模糊去重
- 清洗:移除低质量数据
实践建议:
- 建立数据质量标准
- 定期检查和清洗
- 保留清洗前后的数据
- 持续优化检查规则