数据增强方法：提升模型泛化能力的利器

前言
数据是模型的燃料。数据增强通过对现有数据进行变换和扩充，能在不额外采集数据的情况下提升模型的泛化能力、降低过拟合风险。
我在多个项目中实践过数据增强，今天分享一些实用方法。
文本数据增强
同义词替换
python
import random
from typing import List
import jieba
class SynonymReplacer:
    """Randomly replace words with synonyms taken from a lookup dictionary.

    Text is segmented with ``jieba.lcut``; every segmented word that has a
    non-empty entry in ``synonym_dict`` is replaced, with probability ``p``,
    by a uniformly chosen synonym.
    """

    def __init__(self, synonym_dict=None):
        """Create a replacer.

        Args:
            synonym_dict: Mapping from a word to a list of candidate synonyms.
                ``None`` selects the small built-in dictionary. An explicitly
                passed empty dict is respected (no replacements happen) instead
                of being silently swapped for the default.
        """
        if synonym_dict is None:
            synonym_dict = self._build_default_dict()
        self.synonym_dict = synonym_dict

    def _build_default_dict(self):
        """Return the built-in demo synonym dictionary.

        Each candidate list holds distinct entries so that ``random.choice``
        picks every synonym with equal probability.
        """
        return {
            "好": ["不错", "棒", "优秀", "出色"],
            "快": ["迅速", "快速", "快捷", "飞快"],
            "大": ["巨大", "庞大", "宏大", "广大"],
            "小": ["微小", "细小", "娇小", "渺小"],
        }

    def replace(self, text: str, p: float = 0.1) -> str:
        """Return *text* with some words swapped for synonyms.

        Args:
            text: Input text; segmented with ``jieba.lcut``.
            p: Probability of replacing each word that has synonyms.

        Returns:
            The (possibly replaced) words re-joined with no separator.
        """
        words = jieba.lcut(text)
        new_words = []
        for word in words:
            synonyms = self.synonym_dict.get(word)
            # An empty candidate list is treated like a missing entry, so
            # random.choice is never called on an empty sequence.
            if random.random() < p and synonyms:
                new_words.append(random.choice(synonyms))
            else:
                new_words.append(word)
        return "".join(new_words)
回译法
python
from transformers import pipeline
class BackTranslation:
    """Back-translation augmentation: Chinese -> English -> Chinese."""

    def __init__(self, model_en_zh: str = "Helsinki-NLP/opus-mt-en-zh",
                 model_zh_en: str = "Helsinki-NLP/opus-mt-zh-en"):
        """Load the two ``translation`` pipelines.

        Args:
            model_en_zh: Model id used for the English -> Chinese pass.
            model_zh_en: Model id used for the Chinese -> English pass.
        """
        self.translator_en_zh = pipeline("translation", model=model_en_zh)
        self.translator_zh_en = pipeline("translation", model=model_zh_en)

    def augment(self, text: str) -> str:
        """Return a paraphrase of *text* obtained by round-trip translation.

        Blank input (empty or whitespace-only) is returned unchanged without
        invoking either model, since there is nothing to paraphrase.

        Args:
            text: Chinese source text.

        Returns:
            The Chinese text produced by translating *text* to English and
            back.
        """
        if not text.strip():
            return text
        # zh -> en
        english = self.translator_zh_en(text)[0]["translation_text"]
        # en -> zh
        chinese = self.translator_en_zh(english)[0]["translation_text"]
        return chinese
随机插入删除
python
class TextRandomEditor:
    """Character-level random edits: insertion, deletion and swapping.

    Every operation splits the input with ``list(text)``, i.e. it works on
    individual characters rather than on segmented words.
    """

    def __init__(self):
        # Stateless: all behaviour is controlled by per-call arguments.
        pass

    def random_insert(self, text: str, n: int = 2, insert_words=None) -> str:
        """Insert *n* randomly chosen filler tokens at random positions.

        Args:
            text: Input text.
            n: Number of insertions to perform.
            insert_words: Candidate tokens to insert. ``None`` selects the
                built-in set of Chinese particles.

        Returns:
            The text with ``n`` tokens inserted.
        """
        if insert_words is None:
            insert_words = ["的", "了", "啊", "吧", "呢"]
        chars = list(text)
        for _ in range(n):
            # randint is inclusive, so len(chars) allows appending at the end.
            pos = random.randint(0, len(chars))
            chars.insert(pos, random.choice(insert_words))
        return "".join(chars)

    def random_delete(self, text: str, p: float = 0.1) -> str:
        """Drop each character independently with probability *p*.

        If every character of a non-empty input would be dropped, one
        randomly chosen original character is kept so the result is never
        an empty sample.

        Args:
            text: Input text.
            p: Per-character deletion probability.

        Returns:
            The text with the surviving characters in their original order.
        """
        kept = [ch for ch in text if random.random() > p]
        if text and not kept:
            kept = [random.choice(text)]
        return "".join(kept)

    def random_swap(self, text: str, n: int = 2) -> str:
        """Perform *n* swaps, each between two distinct random positions.

        Inputs shorter than two characters are returned unchanged.

        Args:
            text: Input text.
            n: Number of swaps to perform.

        Returns:
            A permutation of the characters of *text*.
        """
        chars = list(text)
        if len(chars) < 2:
            return text
        for _ in range(n):
            # sample() returns two different indices, so a swap is never a
            # self-swap no-op.
            i, j = random.sample(range(len(chars)), 2)
            chars[i], chars[j] = chars[j], chars[i]
        return "".join(chars)
语义增强
Embedding 扰动
python
import numpy as np
from sentence_transformers import SentenceTransformer
class EmbeddingPerturbation:
    """Augment a sentence in embedding space by adding Gaussian noise."""

    def __init__(self, model_name: str = "shibing624/text2vec-base-chinese"):
        """Load the sentence-embedding model.

        Args:
            model_name: Model id passed to ``SentenceTransformer``.
        """
        self.model = SentenceTransformer(model_name)

    def augment(self, text: str, noise_scale: float = 0.05,
                num_variants: int = 5) -> List[np.ndarray]:
        """Return the original embedding followed by noisy copies of it.

        Args:
            text: Sentence to embed.
            noise_scale: Standard deviation of the zero-mean Gaussian noise
                added independently to every element of the embedding.
            num_variants: Number of perturbed copies to generate.

        Returns:
            A list of ``1 + num_variants`` arrays: the unperturbed embedding
            first, then the perturbed ones. The results are vectors, not
            text, so they are only usable by models that consume embeddings.
        """
        original_embedding = self.model.encode([text])[0]
        augmented_embeddings = [
            original_embedding
            + np.random.normal(0, noise_scale, size=original_embedding.shape)
            for _ in range(num_variants)
        ]
        return [original_embedding] + augmented_embeddings
上下文扩展
python
from transformers import pipeline
class ContextExtender:
    """Extend a text by letting a causal language model continue it."""

    def __init__(self, model_name: str = "gpt2"):
        """Load a ``text-generation`` pipeline.

        Args:
            model_name: Model id for the generation pipeline.
                NOTE(review): the default ``gpt2`` checkpoint is presumably
                English-oriented; confirm a suitable model for Chinese text.
        """
        self.generator = pipeline("text-generation", model=model_name)

    def extend(self, text: str, max_length: int = 200,
               temperature: float = 0.7) -> str:
        """Generate one sampled continuation of *text*.

        Args:
            text: Prompt to continue.
            max_length: Forwarded as the pipeline's ``max_length``.
                NOTE(review): in Transformers this presumably bounds the
                whole sequence including the prompt; consider
                ``max_new_tokens`` for long prompts.
            temperature: Sampling temperature.

        Returns:
            The ``generated_text`` of the single returned sequence
            (presumably prompt + continuation with default pipeline
            settings — confirm).
        """
        outputs = self.generator(
            text,
            max_length=max_length,
            num_return_sequences=1,
            # temperature only influences decoding when sampling is enabled;
            # without do_sample=True generation is greedy and the
            # temperature setting is ignored.
            do_sample=True,
            temperature=temperature,
        )
        return outputs[0]["generated_text"]
组合增强策略
python
class CombinedAugmenter:
    """Chain several text augmenters, each applied with its own probability."""

    def __init__(self):
        """Create the built-in augmenters and an empty list of steps."""
        self.synonym_replacer = SynonymReplacer()
        self.text_editor = TextRandomEditor()
        self.pipeline = []

    def add_step(self, augmenter, probability: float = 1.0):
        """Append *augmenter* (a ``str -> str`` callable) to the chain.

        Whenever a variant is built, this step runs if a fresh
        ``random.random()`` draw is below *probability*.
        """
        step = (augmenter, probability)
        self.pipeline.append(step)

    def augment(self, text: str, num_variants: int = 3) -> List[str]:
        """Return *text* followed by *num_variants* augmented versions.

        The result has ``num_variants + 1`` entries and the first one is
        always the unmodified input. Every variant is built independently
        from the original text, running the steps in the order they were
        added.
        """
        def _one_variant():
            current = text
            for func, chance in self.pipeline:
                if random.random() < chance:
                    current = func(current)
            return current

        return [text, *(_one_variant() for _ in range(num_variants))]
# Example: synonym replacement fires on ~80% of variants, a random
# character swap on ~30%; five variants plus the original are produced.
augmenter = CombinedAugmenter()
augmenter.add_step(augmenter.synonym_replacer.replace, probability=0.8)
augmenter.add_step(augmenter.text_editor.random_swap, probability=0.3)
texts = augmenter.augment(text="这个产品很好用", num_variants=5)
数据增强最佳实践
python
class AugmentationPipeline:
    """Ordered collection of named augmenters, each applied with a probability."""

    def __init__(self):
        """Start with no registered augmenters."""
        self.augmenters = []

    def register_augmenter(self, name, augmenter, probability=0.5):
        """Register *augmenter* (a ``str -> str`` callable) under *name*.

        Augmenters run in registration order; each one fires when a
        ``random.random()`` draw is below *probability*.
        """
        self.augmenters.append({
            "name": name,
            "func": augmenter,
            "probability": probability,
        })

    def apply(self, text):
        """Run the registered augmenters over *text*.

        An augmenter that raises is skipped (the text is passed on unchanged
        to the next one). The failure is logged as a warning with its
        traceback instead of being silently discarded.

        Returns:
            ``(result, applied)`` where *applied* lists, in order, the names
            of the augmenters that fired and succeeded.
        """
        import logging

        result = text
        applied = []
        for aug in self.augmenters:
            if random.random() < aug["probability"]:
                try:
                    result = aug["func"](result)
                except Exception:
                    # Best-effort: one broken augmenter must not abort the
                    # whole pipeline, but the failure has to stay visible.
                    logging.getLogger(__name__).warning(
                        "augmenter %r failed; skipping", aug["name"],
                        exc_info=True,
                    )
                else:
                    applied.append(aug["name"])
        return result, applied

    def generate_synthetic_data(self, original_data, multiplier=3):
        """Expand a labelled dataset with augmented copies.

        Args:
            original_data: Iterable of dicts with at least ``"text"`` and
                ``"label"`` keys.
            multiplier: Rows emitted per original item: the item itself plus
                ``multiplier - 1`` augmented copies. Values below 2 yield only
                the originals.

        Returns:
            A list where each original dict (the same, unmodified object) is
            followed by its augmented copies. A copy keeps every field of the
            original, overrides ``"text"``, and adds ``"augmented": True`` and
            ``"methods"`` (names of the augmenters applied). A copy's text can
            equal the original's when no augmenter fired.
        """
        synthetic = []
        for item in original_data:
            synthetic.append(item)
            for _ in range(multiplier - 1):
                augmented, applied = self.apply(item["text"])
                synthetic.append({
                    # Keep extra metadata (ids, sources, ...) on the copies.
                    **item,
                    "text": augmented,
                    "label": item["label"],
                    "augmented": True,
                    "methods": applied,
                })
        return synthetic
总结
数据增强方法要点：
- 多样性：使用多种增强方法
- 语义保持：增强后的样本应保持原有语义和标签不变（随机删除、交换等操作可能改变句意）
- 适度原则：过度增强反而有害
- 任务适配：根据任务选择方法
- 质量控制：验证增强后的数据
实践建议：
- 从简单的方法开始
- 逐步引入复杂的增强
- 监控增强后的效果，对比增强前后模型在验证集上的表现
- 验证集只使用未经增强的原始数据，并确保同一原始样本的增强版本不会同时出现在训练集和验证集中