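Contents
[Part 1 Principles in Detail](#Part 1 Principles in Detail)
[1.1 Background and Problem Definition](#1.1 Background and Problem Definition)
[1.1.1 Challenges of Safety Alignment](#1.1.1 Challenges of Safety Alignment)
[1.1.2 Multi-dimensional Harmfulness Taxonomy](#1.1.2 Multi-dimensional Harmfulness Taxonomy)
[1.1.3 Class Imbalance as a Technical Obstacle](#1.1.3 Class Imbalance as a Technical Obstacle)
[1.2 Data Construction and Preprocessing](#1.2 Data Construction and Preprocessing)
[1.2.1 Parsing HH-RLHF Red Teaming Data](#1.2.1 Parsing HH-RLHF Red Teaming Data)
[1.2.2 GlaDOS Dataset Fusion Strategy](#1.2.2 GlaDOS Dataset Fusion Strategy)
[1.3 Model Architecture](#1.3 Model Architecture)
[1.3.1 RoBERTa Encoder](#1.3.1 RoBERTa Encoder)
[1.3.2 Multi-task Classification Heads](#1.3.2 Multi-task Classification Heads)
[1.4 Training Optimization Strategies](#1.4 Training Optimization Strategies)
[1.4.1 Focal Loss Implementation](#1.4.1 Focal Loss Implementation)
[1.4.2 Adversarial Training](#1.4.2 Adversarial Training)
[1.5 Evaluation Framework](#1.5 Evaluation Framework)
[1.5.1 Multi-dimensional Performance Metrics](#1.5.1 Multi-dimensional Performance Metrics)
[1.5.2 Threshold Optimization](#1.5.2 Threshold Optimization)
[Part 2 Structured Pseudocode](#Part 2 Structured Pseudocode)
[2.1 Data Preprocessing Algorithm](#2.1 Data Preprocessing Algorithm)
[2.2 Focal Loss Computation](#2.2 Focal Loss Computation)
[2.3 Training Loop](#2.3 Training Loop)
[2.4 Threshold Optimization Algorithm](#2.4 Threshold Optimization Algorithm)
[Part 3 Code Implementation](#Part 3 Code Implementation)
[Script 1: Data Preprocessing and Multi-label Construction](#Script 1: Data Preprocessing and Multi-label Construction)
[Script 2: Focal Loss and RoBERTa Model Definition](#Script 2: Focal Loss and RoBERTa Model Definition)
[Script 3: Model Training and Validation](#Script 3: Model Training and Validation)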
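Part 1 Principles in Detail
1.1 Background and Problem Definition
1.1.1 Challenges of Safety Alignment
Safety alignment of large language models requires the system to reliably detect and refuse harmful prompts in user input. Traditional filters based on rules or keyword matching fall short on adversarial inputs whose harmful intent is semantically complex or only implied. The automated red teaming framework proposed by Perez and Ribeiro in 2022 exposed concrete paths by which language models can be led to produce harmful content. This underscores the need for data-driven harmfulness detection.
Anthropic's HH-RLHF (Helpful and Harmless Reinforcement Learning from Human Feedback) dataset provides high-quality red teaming samples, including human-annotated adversarial dialogues. Its red teaming subset records prompts that human annotators wrote specifically to elicit harmful output, together with the model's corresponding safe or unsafe responses.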
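1.1.2 Multi-dimensional Harmfulness Taxonomy
A single binary harmful/harmless label is too coarse for fine-grained safety control. Modern safety systems need to know which category of harm is present, so that they can apply a different response policy to each category. Combining the taxonomy of the GlaDOS (Generative Language Assessment Dataset for Offensive Speech) dataset with the HH-RLHF annotation guidelines gives a four-dimensional classification framework:
- Violence: physical harm, use of weapons, descriptions of gory scenes
- Hate Speech: discriminatory, insulting, or inciting statements targeting protected groups
- Sexual Content: explicit sexual descriptions, non-consensual content, child sexual exploitation material
- Self-Harm: content that encourages suicide or self-harm, or gives instructions for it
Each dimension is an independent binary classification task, so the overall problem is multi-label classification. A single prompt can be positive for several categories at once, for example both violence and hate speech.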
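1.1.3 Class Imbalance as a Technical Obstacle
Training a safety classifier runs into severe class imbalance. In real conversations, harmful prompts typically make up less than 1% of traffic, so a model trained with plain cross-entropy drifts toward predicting the negative (safe) class. Focal Loss counters this by reweighting samples. It modifies the cross-entropy loss as follows:
$$L_{FL} = -\alpha_t (1 - p_t)^{\gamma} \log(p_t)$$
The symbols are defined as follows:
- $p_t$ is the probability the model assigns to the true class.
- $\gamma \ge 0$ is the focusing parameter.
- $\alpha_t \in [0, 1]$ is the class weighting coefficient.

When $\gamma > 0$, the factor $(1 - p_t)^{\gamma}$ shrinks the loss of well-classified samples ($p_t \gg 0.5$). This forces the model to concentrate on hard samples, which mostly come from the minority class. With $\gamma = 0$ and $\alpha_t = 1$, the expression reduces to ordinary cross-entropy.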
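A minimal sketch in Python (the language of the scripts in Part 3) makes the down-weighting concrete. It evaluates the formula above for a single sample, using $\alpha_t = 0.25$ and $\gamma = 2$, the defaults adopted later in Algorithm 2.2.

The ratio FL/CE equals $\alpha_t (1 - p_t)^{\gamma}$:
- An easy sample at $p_t = 0.9$ keeps $0.25 \times 0.01 = 0.25\%$ of its cross-entropy.
- A hard sample at $p_t = 0.1$ keeps $0.25 \times 0.81 \approx 20\%$.

```python
import math


def focal_loss(p_t: float, alpha_t: float = 0.25, gamma: float = 2.0) -> float:
    """Focal loss of one sample, given the probability p_t assigned to its true class."""
    return -alpha_t * (1.0 - p_t) ** gamma * math.log(p_t)


def cross_entropy(p_t: float) -> float:
    """Plain cross-entropy of one sample."""
    return -math.log(p_t)


# FL / CE = alpha_t * (1 - p_t)^gamma, so easy samples are suppressed far more than hard ones
for p_t in (0.1, 0.5, 0.9, 0.99):
    ratio = focal_loss(p_t) / cross_entropy(p_t)
    print(f"p_t={p_t:.2f}  CE={cross_entropy(p_t):.4f}  FL={focal_loss(p_t):.6f}  FL/CE={ratio:.6f}")
```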
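1.2 Data Construction and Preprocessing
1.2.1 Parsing HH-RLHF Red Teaming Data
The HH-RLHF red teaming data is treated as a set of dialogue trees. Each sample is assumed to contain:
- user_prompt: the potentially harmful query entered by the user
- harmful_categories: the annotated harm categories (multi-label)
- severity_score: a human-assessed harm severity score (1-5)

The preprocessing pipeline consists of three steps:
- Flattening each dialogue tree and extracting the user prompt at the root node
- Multi-label (multi-hot) encoding, which converts the category labels into a four-dimensional binary vector $y \in \{0, 1\}^4$
- Text cleaning: removing PII (personally identifiable information) and normalizing Unicode characters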
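The encoding and cleaning steps above can be sketched as follows. The two regular expressions are hypothetical and deliberately simple: they only catch e-mail addresses and phone-number-like digit runs. Production PII scrubbing needs much broader coverage. NFKC normalization folds compatibility characters, such as full-width letters, into their standard forms.

```python
import re
import unicodedata

# Hypothetical, intentionally minimal PII patterns (illustration only)
EMAIL_RE = re.compile(r"[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\.[A-Za-z]{2,}")
PHONE_RE = re.compile(r"\+?\d[\d\s().-]{7,}\d")

CATEGORY_MAPPING = {"violence": 0, "hate_speech": 1, "sexual_content": 2, "self_harm": 3}


def clean_text(text: str, max_chars: int = 1000) -> str:
    """NFKC-normalize, mask e-mails and phone numbers, collapse whitespace, truncate."""
    text = unicodedata.normalize("NFKC", text)
    text = EMAIL_RE.sub("[EMAIL]", text)
    text = PHONE_RE.sub("[PHONE]", text)
    text = re.sub(r"\s+", " ", text).strip()
    return text[:max_chars]


def multi_hot(categories, mapping=CATEGORY_MAPPING):
    """Turn a list of category names into a {0,1}^K vector; unknown names are ignored."""
    vec = [0] * len(mapping)
    for c in categories:
        if c in mapping:
            vec[mapping[c]] = 1
    return vec


print(clean_text("Ｃontact me at bob@example.com or +1 (555) 123-4567"))
print(multi_hot(["hate_speech", "self_harm", "spam"]))
```

Running it prints `Contact me at [EMAIL] or [PHONE]` and `[0, 1, 0, 1]`. The unknown label `spam` is silently dropped.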
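1.2.2 GlaDOS Dataset Fusion Strategy
The GlaDOS dataset provides fine-grained toxicity annotations and includes harmful content in multiple languages. The two sources are mixed so that 80% of the combined samples come from HH-RLHF and 20% from GlaDOS:
$$D_{train} = D_{\text{HH-RLHF}}^{80\%} \cup D_{\text{GlaDOS}}^{20\%}$$
The combined set is then split with stratified sampling. This keeps the per-category proportions consistent across the training, validation, and test sets and avoids distribution shift between splits.

Stratification preserves the existing class ratios but does not by itself balance rare categories. Balancing is handled by the loss reweighting described in Section 1.1.3.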
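The sketch below shows one way to realize the 80/20 mixture, assuming the percentages are shares of the final mixed set:
- If GlaDOS is large enough, keep all HH-RLHF samples and subsample GlaDOS to a quarter of that size.
- Otherwise, keep all of GlaDOS and subsample HH-RLHF to four times its size.

The function name `mix_sources` is hypothetical.

```python
import random


def mix_sources(hh, glados, hh_ratio=0.8, seed=42):
    """Mix two sample lists so that about hh_ratio of the result comes from `hh`."""
    rng = random.Random(seed)
    n_glados_needed = round(len(hh) * (1 - hh_ratio) / hh_ratio)
    if n_glados_needed <= len(glados):
        # Enough GlaDOS data: keep every HH-RLHF sample, subsample GlaDOS
        hh_part, glados_part = list(hh), rng.sample(glados, n_glados_needed)
    else:
        # GlaDOS is the bottleneck: keep all of it, subsample HH-RLHF instead
        n_hh = round(len(glados) * hh_ratio / (1 - hh_ratio))
        hh_part, glados_part = rng.sample(hh, n_hh), list(glados)
    mixed = hh_part + glados_part
    rng.shuffle(mixed)
    return mixed


hh = [("hh", i) for i in range(800)]
glados = [("glados", i) for i in range(1000)]
mixed = mix_sources(hh, glados)
print(len(mixed), sum(1 for src, _ in mixed if src == "glados"))
```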
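1.3 Model Architecture
1.3.1 RoBERTa Encoder
RoBERTa-large (355M parameters) serves as the base encoder. Its bidirectional contextual modeling captures harmful semantics that are only implicit in a prompt. The encoder outputs the final-layer hidden states $H \in \mathbb{R}^{n \times d}$, where $n$ is the sequence length and $d = 1024$.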
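1.3.2 Multi-task Classification Heads
Four independent linear classification heads are used, one per harmfulness dimension:
$$z_i = \text{Dropout}_i(\text{Pooler}(H))$$
$$\hat{y}_i = \sigma(W_i z_i + b_i), \quad i \in \{1, 2, 3, 4\}$$
The components are:
- Pooler takes the final hidden state of the first token ([CLS], written `<s>` in RoBERTa's vocabulary).
- $\text{Dropout}_i$ is the dropout layer of head $i$, so each head may use its own rate.
- $\sigma$ is the sigmoid function.

Each head outputs an independent probability for its category. The four probabilities are not normalized against each other, so they do not form a distribution and can all be high at once.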
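The dependency-free sketch below shows the key property of the four heads: each applies its own $W_i, b_i$ and sigmoid, so the outputs are independent probabilities rather than a softmax distribution. It assumes a toy hidden size $d = 8$ and uses random vectors in place of RoBERTa's pooled output. Dropout is the identity at inference time, so it is omitted.

```python
import math
import random


def sigmoid(x: float) -> float:
    return 1.0 / (1.0 + math.exp(-x))


def multi_head_predict(z, heads):
    """Apply K independent linear heads (w_i, b_i) to a pooled vector z.
    Each head yields its own probability; the K outputs are not normalized to sum to 1."""
    return [sigmoid(sum(w * x for w, x in zip(w_i, z)) + b_i) for w_i, b_i in heads]


rng = random.Random(0)
d = 8  # toy hidden size; RoBERTa-large uses d = 1024
z = [rng.gauss(0, 1) for _ in range(d)]  # stands in for Pooler(H) at inference time
heads = [([rng.gauss(0, 0.1) for _ in range(d)], 0.0) for _ in range(4)]
probs = multi_head_predict(z, heads)
print([round(p, 3) for p in probs])
```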
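1.4 Training Optimization Strategies
1.4.1 Focal Loss Implementation
Focal Loss is computed independently for each classification head:
$$L_i = -\alpha (1 - \hat{y}_i)^{\gamma} \log(\hat{y}_i) \cdot y_i - (1 - \alpha)\, \hat{y}_i^{\gamma} \log(1 - \hat{y}_i) \cdot (1 - y_i)$$
This is the formula from Section 1.1.3 with the following substitutions:
- For positive samples: $p_t = \hat{y}_i$ and $\alpha_t = \alpha$.
- For negative samples: $p_t = 1 - \hat{y}_i$ and $\alpha_t = 1 - \alpha$.

The total loss is a weighted sum over the four dimensions:
$$L_{total} = \sum_{i=1}^{4} \lambda_i L_i$$
The weights $\lambda_i$ are set from the sample frequency of each category.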
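The per-head loss and the weighted sum can be written directly as below. `lambdas_from_frequency` is a hypothetical choice of $\lambda_i$: inverse positive frequency, rescaled so the weights average to 1. The text only says the weights follow the sample frequencies, so other schemes are equally valid.

Two sanity checks follow from the formula:
- With $\gamma = 0$ and $\alpha = 0.5$, each $L_i$ reduces to half the binary cross-entropy.
- The scripts' `FocalLoss` with `reduction='mean'` averages over samples and classes, which corresponds to $\lambda_i = 1/4$ here.

```python
import math


def multilabel_focal_loss(probs, targets, alpha, gamma=2.0, lambdas=None, eps=1e-7):
    """probs, targets: N x K nested lists; alpha: per-class alpha_i in [0, 1].
    Returns sum_i lambda_i * L_i, where L_i is the mean focal loss of head i."""
    k = len(alpha)
    lambdas = lambdas if lambdas is not None else [1.0] * k
    total = 0.0
    for i in range(k):
        head_losses = []
        for row_p, row_y in zip(probs, targets):
            p = min(max(row_p[i], eps), 1.0 - eps)  # clamp to avoid log(0)
            if row_y[i] == 1:  # positive: alpha * (1 - p)^gamma * -log(p)
                head_losses.append(-alpha[i] * (1.0 - p) ** gamma * math.log(p))
            else:  # negative: (1 - alpha) * p^gamma * -log(1 - p)
                head_losses.append(-(1.0 - alpha[i]) * p ** gamma * math.log(1.0 - p))
        total += lambdas[i] * sum(head_losses) / len(head_losses)
    return total


def lambdas_from_frequency(targets, eps=1e-6):
    """Hypothetical lambda_i: inverse positive frequency, rescaled to average 1."""
    k = len(targets[0])
    freq = [sum(row[i] for row in targets) / len(targets) for i in range(k)]
    inv = [1.0 / (f + eps) for f in freq]
    return [k * v / sum(inv) for v in inv]


Y = [[1, 0], [0, 0], [1, 1], [1, 0]]
P = [[0.9, 0.2], [0.3, 0.1], [0.6, 0.4], [0.8, 0.3]]
lam = lambdas_from_frequency(Y)  # rarer class 1 gets the larger weight
print(lam, multilabel_focal_loss(P, Y, alpha=[0.25, 0.25], lambdas=lam))
```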
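1.4.2 Adversarial Training
Adversarial samples are added to improve robustness. Following the automated red teaming approach of Perez & Ribeiro, a language model generates semantic variants of harmful prompts:
$$D_{aug} = \{\text{Paraphrase}(x) \mid x \in D_{harmful}\}$$
The variants are produced by synonym substitution and syntactic restructuring. Training on them forces the model to learn harmfulness at the semantic level instead of relying on surface keywords. Each variant inherits the labels of the prompt it was derived from.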
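A toy version of synonym substitution is sketched below. The synonym table and function names are hypothetical. In the approach described above, the variants would come from a language-model paraphraser, which this sketch does not attempt. Words with attached punctuation are not matched, because lookup is done on whitespace-split tokens.

```python
import random

# Hypothetical synonym table, for illustration only
SYNONYMS = {
    "make": ["build", "construct"],
    "hurt": ["harm", "injure"],
    "someone": ["a person", "somebody"],
}


def paraphrase(text: str, synonyms: dict, rng: random.Random, p: float = 1.0) -> str:
    """Replace each word that has synonyms with probability p (case-insensitive lookup)."""
    out = []
    for word in text.split():
        key = word.lower()
        if key in synonyms and rng.random() < p:
            out.append(rng.choice(synonyms[key]))
        else:
            out.append(word)
    return " ".join(out)


def augment(harmful_texts, synonyms, seed=0, n_variants=2):
    """D_aug: up to n_variants paraphrases per prompt; duplicates and unchanged copies dropped."""
    rng = random.Random(seed)
    variants = set()
    for x in harmful_texts:
        for _ in range(n_variants):
            v = paraphrase(x, synonyms, rng)
            if v != x:
                variants.add(v)
    return sorted(variants)


print(augment(["how to hurt someone"], SYNONYMS))
```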
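1.5 Evaluation Framework
1.5.1 Multi-dimensional Performance Metrics
Multi-label performance is evaluated with the macro-averaged F1 score (Macro-F1):
$$\text{Precision}_i = \frac{TP_i}{TP_i + FP_i}, \quad \text{Recall}_i = \frac{TP_i}{TP_i + FN_i}$$
$$F1_i = \frac{2 \cdot \text{Precision}_i \cdot \text{Recall}_i}{\text{Precision}_i + \text{Recall}_i}, \quad \text{Macro-F1} = \frac{1}{4} \sum_{i=1}^{4} F1_i$$
The target is $F1_i > 0.85$ for every category. Macro-F1 weights all categories equally, so a rare category with poor recall lowers the score as much as a common one.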
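The formulas translate into a few lines of dependency-free Python. It follows scikit-learn's `zero_division=0` convention, which `compute_metrics` in Script 3 also uses: an undefined precision, recall, or F1 is reported as 0.

```python
def per_class_f1(y_true, y_pred):
    """y_true, y_pred: N x K nested lists of 0/1. Returns [(precision, recall, f1)] per class."""
    k = len(y_true[0])
    scores = []
    for i in range(k):
        tp = sum(1 for t, p in zip(y_true, y_pred) if t[i] == 1 and p[i] == 1)
        fp = sum(1 for t, p in zip(y_true, y_pred) if t[i] == 0 and p[i] == 1)
        fn = sum(1 for t, p in zip(y_true, y_pred) if t[i] == 1 and p[i] == 0)
        prec = tp / (tp + fp) if tp + fp else 0.0
        rec = tp / (tp + fn) if tp + fn else 0.0
        f1 = 2 * prec * rec / (prec + rec) if prec + rec else 0.0
        scores.append((prec, rec, f1))
    return scores


def macro_f1(y_true, y_pred):
    """Unweighted mean of the per-class F1 scores."""
    scores = per_class_f1(y_true, y_pred)
    return sum(f1 for _, _, f1 in scores) / len(scores)


y_true = [[1, 0], [1, 1], [0, 1], [0, 0]]
y_pred = [[1, 0], [0, 1], [0, 0], [1, 0]]
print(per_class_f1(y_true, y_pred), macro_f1(y_true, y_pred))
```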
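1.5.2 Threshold Optimization
The decision threshold $\tau_i$ is optimized independently for each category instead of using the default 0.5:
$$\tau_i^* = \arg\max_{\tau} F1_i(\tau)$$
The optimal decision boundary is determined by analyzing the precision-recall (PR) curve on the validation set.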
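A minimal grid-search sketch of $\tau_i^*$ for one category is shown below. It uses the 0.1-step grid of Algorithm 2.4 and keeps the smallest threshold when scores tie. The example scores are made up.

The example shows why 0.5 is not a safe default:
- At $\tau = 0.5$, only one of the four positives is caught, and F1 is 0.4.
- At $\tau = 0.1$, all four positives are caught at the cost of one false positive, and F1 is $8/9 \approx 0.889$.

```python
def f1_binary(y_true, y_pred):
    """F1 = 2TP / (2TP + FP + FN); 0 when there are no true positives."""
    tp = sum(1 for t, p in zip(y_true, y_pred) if t and p)
    fp = sum(1 for t, p in zip(y_true, y_pred) if not t and p)
    fn = sum(1 for t, p in zip(y_true, y_pred) if t and not p)
    return 2 * tp / (2 * tp + fp + fn) if tp else 0.0


def best_threshold(scores, labels, grid=None):
    """Return (tau*, F1) maximizing F1 over the grid; ties keep the smallest tau."""
    grid = grid or [round(0.1 * k, 1) for k in range(1, 10)]
    best_tau, best_f1 = 0.5, -1.0
    for tau in grid:
        f1 = f1_binary(labels, [s >= tau for s in scores])
        if f1 > best_f1:
            best_tau, best_f1 = tau, f1
    return best_tau, best_f1


scores = [0.15, 0.35, 0.42, 0.80, 0.30, 0.05]
labels = [1, 1, 1, 1, 0, 0]
print(best_threshold(scores, labels))
```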
Part 2 Structured Pseudocode
2.1 Data Preprocessing Algorithm
\begin{algorithm}
\caption{Multi-label Dataset Construction}
\begin{algorithmic}[1]
\Require Raw dialogue trees $\mathcal{T}$, label mapping $M$
\Ensure Processed dataset $D = \{(x_j, y_j)\}_{j=1}^N$
\State $D \leftarrow \emptyset$
\For{each tree $T \in \mathcal{T}$}
\State $x \leftarrow \text{ExtractRootPrompt}(T)$
\State $y \leftarrow [0, 0, 0, 0]$
\For{each category $c \in \text{GetLabels}(T)$}
\State $idx \leftarrow M[c]$
\State $y[idx] \leftarrow 1$
\EndFor
\State $D \leftarrow D \cup \{(x, y)\}$
\EndFor
\State \Return $\text{StratifiedSplit}(D, \text{ratios}=[0.8, 0.1, 0.1])$
\end{algorithmic}
\end{algorithm}
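Algorithm 2.1 maps onto Python roughly as follows. The record layout (`conversation`, `prompt`, `harmful_categories`) is the one Script 1 expects, not a documented HH-RLHF schema. `stratify_keys` is a hypothetical helper that produces the label-combination keys a stratified split needs: rare combinations fall back to their first positive category.

```python
from collections import Counter

CATEGORY_MAPPING = {"violence": 0, "hate_speech": 1, "sexual_content": 2, "self_harm": 3}


def extract_root_prompt(record: dict) -> str:
    """Root user turn of an (assumed) {'conversation': [{'text': ...}, ...]} record."""
    conv = record.get("conversation") or []
    return conv[0]["text"] if conv else record.get("prompt", "")


def build_dataset(records, mapping=CATEGORY_MAPPING):
    """Lines 3-12 of Algorithm 2.1: (x, y) pairs with a multi-hot y."""
    data = []
    for rec in records:
        y = [0] * len(mapping)
        for c in rec.get("harmful_categories", []):
            if c in mapping:
                y[mapping[c]] = 1
        data.append((extract_root_prompt(rec), y))
    return data


def stratify_keys(data, min_count=3):
    """One stratification key per sample: the label combination if it is frequent enough,
    otherwise the first positive class (or 'clean')."""
    sigs = [tuple(y) for _, y in data]
    counts = Counter(sigs)
    keys = []
    for sig in sigs:
        if counts[sig] >= min_count:
            keys.append("".join(map(str, sig)))
        else:
            first = next((i for i, v in enumerate(sig) if v), None)
            keys.append("clean" if first is None else f"class_{first}")
    return keys
```

The keys can be passed as the `stratify` argument of scikit-learn's `train_test_split` to realize StratifiedSplit.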
2.2 Focal Loss Computation
\begin{algorithm}
\caption{Multi-label Focal Loss Computation}
\begin{algorithmic}[1]
\Require Predictions $\hat{Y} \in [0, 1]^{N \times 4}$, Targets $Y \in \{0, 1\}^{N \times 4}$, class weights $\lambda \in \mathbb{R}^4$, $\alpha = 0.25, \gamma = 2.0$
\Ensure Scalar loss $L$
\State $L \leftarrow 0$
\For{$i \leftarrow 1$ to 4}
\State $\hat{y}_i \leftarrow \hat{Y}[:, i]$
\State $y_i \leftarrow Y[:, i]$
\State $p_t \leftarrow y_i \odot \hat{y}_i + (1 - y_i) \odot (1 - \hat{y}_i)$
\State $\alpha_t \leftarrow \alpha \cdot y_i + (1 - \alpha) \cdot (1 - y_i)$
\State $L_i \leftarrow -\sum_{j=1}^N \alpha_{t, j} (1 - p_{t, j})^\gamma \log(p_{t, j})$
\State $L \leftarrow L + \lambda_i \cdot L_i$
\EndFor
\State \Return $L$
\end{algorithmic}
\end{algorithm}
2.3 Training Loop
\begin{algorithm}
\caption{Harmfulness Classifier Training}
\begin{algorithmic}[1]
\Require Model $M_\theta$, Train set $D_{train}$, Val set $D_{val}$, Epochs $E$, Batch size $B$
\Ensure Trained parameters $\theta^*$
\For{$e \leftarrow 1$ to $E$}
\State $\mathcal{B} \leftarrow \text{Batchify}(D_{train}, B)$
\For{each batch $(X, Y) \in \mathcal{B}$}
\State $\hat{Y} \leftarrow \sigma(M_\theta(X))$
\State $L \leftarrow \text{FocalLoss}(\hat{Y}, Y)$
\State $g \leftarrow \nabla_\theta L$
\State $\theta \leftarrow \text{AdamW}(\theta, g, \eta = 2 \times 10^{-5})$
\EndFor
\State $F1_{val} \leftarrow \text{Evaluate}(M_\theta, D_{val})$
\If{$F1_{val}$ improved}
\State $\text{SaveCheckpoint}(\theta)$
\EndIf
\EndFor
\State \Return \text{LoadBestCheckpoint}()
\end{algorithmic}
\end{algorithm}
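The "save if improved" step of Algorithm 2.3, combined with the early stopping used in Script 3, can be sketched as a small tracker. The sketch assumes a higher-is-better metric such as macro-F1. It stores a deep copy of the parameters instead of writing a checkpoint to disk. `BestCheckpoint` is a hypothetical name.

```python
import copy


class BestCheckpoint:
    """Keep the parameters with the highest validation score and signal early stopping
    after `patience` epochs without an improvement larger than `min_delta`."""

    def __init__(self, patience=3, min_delta=1e-3):
        self.patience, self.min_delta = patience, min_delta
        self.best_score = None
        self.best_params = None
        self.bad_epochs = 0

    def update(self, score, params) -> bool:
        """Record one epoch; return True when training should stop."""
        if self.best_score is None or score > self.best_score + self.min_delta:
            self.best_score = score
            self.best_params = copy.deepcopy(params)  # SaveCheckpoint(theta)
            self.bad_epochs = 0
        else:
            self.bad_epochs += 1
        return self.bad_epochs >= self.patience


tracker = BestCheckpoint(patience=3, min_delta=1e-3)
stopped_at = None
for epoch, f1 in enumerate([0.70, 0.78, 0.7805, 0.779, 0.77, 0.90], start=1):
    if tracker.update(f1, {"epoch": epoch}):
        stopped_at = epoch
        break
print(stopped_at, tracker.best_score, tracker.best_params)  # LoadBestCheckpoint()
```

Here the gain from 0.78 to 0.7805 is below `min_delta`, so it counts as a non-improving epoch. Training stops at epoch 5 with the epoch-2 parameters as the best checkpoint.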
2.4 Threshold Optimization Algorithm
\begin{algorithm}
\caption{Per-class Threshold Optimization}
\begin{algorithmic}[1]
\Require Validation predictions $\hat{Y}_{val}$, true labels $Y_{val}$, threshold grid $\mathcal{T} = \{0.1, 0.2, \dots, 0.9\}$
\Ensure Optimal thresholds $\tau^* \in \mathbb{R}^4$
\State $\tau^* \leftarrow [0.5, 0.5, 0.5, 0.5]$
\For{$i \leftarrow 1$ to 4}
\State $BestF1 \leftarrow 0$
\For{each $\tau \in \mathcal{T}$}
\State $\hat{y}_{bin} \leftarrow \mathbb{I}(\hat{Y}_{val}[:, i] \ge \tau)$
\State $F1 \leftarrow \text{ComputeF1}(\hat{y}_{bin}, Y_{val}[:, i])$
\If{$F1 > BestF1$}
\State $BestF1 \leftarrow F1$
\State $\tau^*[i] \leftarrow \tau$
\EndIf
\EndFor
\EndFor
\State \Return $\tau^*$
\end{algorithmic}
\end{algorithm}
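Part 3 Code Implementation
Script 1: Data Preprocessing and Multi-label Construction
Description: this script loads and cleans the HH-RLHF and GlaDOS datasets, encodes the multi-label targets, and performs stratified splitting. It writes the training, validation, and test sets as JSON files, together with a visualization of the class distribution.
Usage: python 01_data_preparation.py --hh_path ./hh-rlhf/red_teaming --glados_path ./glados.csv --output ./processed_data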
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Script 1: Data Preprocessing and Multi-label Construction
Multi-dimensional safety classification data preparation based on HH-RLHF Red Teaming data and the GlaDOS dataset
Usage:
    python 01_data_preparation.py \
        --hh_path ./hh-rlhf/red_teaming \
        --glados_path ./glados.csv \
        --output ./processed_data \
        --seed 42
"""
import os
import re
import json
import argparse
import random
import unicodedata
import numpy as np
import pandas as pd
from collections import Counter
from typing import List, Dict, Tuple
import matplotlib
matplotlib.use('Agg')  # headless backend: figures are only written to disk
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split


# Set random seeds
def set_seed(seed: int):
    random.seed(seed)
    np.random.seed(seed)


# Category mapping
CATEGORY_MAPPING = {
    'violence': 0,
    'hate_speech': 1,
    'sexual_content': 2,
    'self_harm': 3
}
REVERSE_MAPPING = {v: k for k, v in CATEGORY_MAPPING.items()}

# Basic PII patterns (e-mail addresses, phone numbers); extend them for production use
EMAIL_PATTERN = re.compile(r'[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\.[A-Za-z]{2,}')
PHONE_PATTERN = re.compile(r'\+?\d[\d\s().-]{7,}\d')


def parse_hh_rlhf_red_teaming(file_path: str) -> List[Dict]:
    """
    Parse HH-RLHF red teaming data.
    Expected format: JSON Lines, each line containing a dialogue tree and harmfulness labels.
    """
    data = []
    with open(file_path, 'r', encoding='utf-8') as f:
        for line in f:
            line = line.strip()
            if not line:  # skip blank lines
                continue
            record = json.loads(line)
            # Extract the user prompt at the root node
            if record.get('conversation'):
                root_prompt = record['conversation'][0]['text']
            else:
                root_prompt = record.get('prompt', '')
            # Multi-hot encode the labels
            labels = record.get('harmful_categories', [])
            label_vector = [0, 0, 0, 0]
            for cat in labels:
                if cat in CATEGORY_MAPPING:
                    label_vector[CATEGORY_MAPPING[cat]] = 1
            data.append({
                'text': root_prompt,
                'labels': label_vector,
                'source': 'hh_rlhf',
                'severity': record.get('severity_score', 1)
            })
    return data


def parse_glados_data(file_path: str) -> List[Dict]:
    """
    Parse the GlaDOS dataset.
    Expected format: CSV with `text` and `category` columns (multiple categories comma-separated).
    """
    df = pd.read_csv(file_path)
    data = []
    for _, row in df.iterrows():
        label_vector = [0, 0, 0, 0]
        # A missing category (NaN) means "no harmful category"
        categories = str(row['category']).split(',') if pd.notna(row['category']) else []
        for cat in categories:
            cat = cat.strip().lower()
            if cat in CATEGORY_MAPPING:
                label_vector[CATEGORY_MAPPING[cat]] = 1
        severity = row['severity'] if 'severity' in row and pd.notna(row['severity']) else 1
        data.append({
            'text': str(row['text']),
            'labels': label_vector,
            'source': 'glados',
            'severity': int(severity)  # numpy scalars are not JSON-serializable
        })
    return data


def clean_text(text: str, max_chars: int = 1000) -> str:
    """
    Text cleaning: normalize Unicode, mask basic PII, limit length.
    """
    # Unicode normalization (e.g. full-width characters -> standard forms)
    text = unicodedata.normalize('NFKC', text)
    # Mask basic PII
    text = EMAIL_PATTERN.sub('[EMAIL]', text)
    text = PHONE_PATTERN.sub('[PHONE]', text)
    text = text.strip()
    # Limit length (RoBERTa accepts at most 512 tokens; keep a character-level margin)
    return text[:max_chars]


def stratified_multi_label_split(
    data: List[Dict],
    test_size: float = 0.1,
    val_size: float = 0.1,
    seed: int = 42
) -> Tuple[List[Dict], List[Dict], List[Dict]]:
    """
    Stratified split (80/10/10 by default) so every category is represented in the
    train/val/test sets. Samples are stratified by their label combination; rare
    combinations are merged into their first positive category. This is a simple
    approximation of multi-label stratification, not full iterative stratification,
    and every resulting key still needs at least 2 samples.
    """
    # Label-combination signature of each sample
    signatures = [tuple(d['labels']) for d in data]
    signature_counts = Counter(signatures)
    stratify_labels = []
    for sig in signatures:
        if signature_counts[sig] >= 3:
            # Frequent combination (including all-zero "clean" samples): use it directly
            stratify_labels.append(''.join(map(str, sig)))
        else:
            # Rare combination (<3 occurrences): merge into its first positive category
            dominant = next((i for i, v in enumerate(sig) if v == 1), None)
            stratify_labels.append('clean' if dominant is None else f'rare_{REVERSE_MAPPING[dominant]}')
    indices = list(range(len(data)))
    # Split off the test set first
    train_val_idx, test_idx = train_test_split(
        indices,
        test_size=test_size,
        stratify=stratify_labels,
        random_state=seed
    )
    # Split the validation set from the remaining data
    train_val_labels = [stratify_labels[i] for i in train_val_idx]
    train_idx, val_idx = train_test_split(
        train_val_idx,
        test_size=val_size / (1 - test_size),
        stratify=train_val_labels,
        random_state=seed
    )
    train_data = [data[i] for i in train_idx]
    val_data = [data[i] for i in val_idx]
    test_data = [data[i] for i in test_idx]
    return train_data, val_data, test_data


def visualize_class_distribution(data: List[Dict], save_path: str):
    """
    Visualize the class distribution and the multi-label co-occurrence matrix.
    """
    labels_matrix = np.array([d['labels'] for d in data])
    category_names = list(CATEGORY_MAPPING.keys())
    fig, axes = plt.subplots(2, 2, figsize=(14, 12))
    # 1. Positive-sample count per category
    ax1 = axes[0, 0]
    counts = labels_matrix.sum(axis=0)
    colors = ['#e74c3c', '#e67e22', '#f39c12', '#c0392b']
    bars = ax1.bar(range(4), counts, color=colors, alpha=0.8, edgecolor='black')
    ax1.set_xticks(range(4))
    ax1.set_xticklabels(category_names, rotation=15, ha='right')
    ax1.set_ylabel('Positive Samples Count')
    ax1.set_title('Per-Class Distribution')
    for bar, count in zip(bars, counts):
        height = bar.get_height()
        ax1.text(bar.get_x() + bar.get_width() / 2., height,
                 f'{int(count)}', ha='center', va='bottom')
    # 2. Label co-occurrence heatmap
    ax2 = axes[0, 1]
    cooccurrence = labels_matrix.T @ labels_matrix
    im = ax2.imshow(cooccurrence, cmap='YlOrRd', aspect='auto')
    ax2.set_xticks(range(4))
    ax2.set_yticks(range(4))
    ax2.set_xticklabels(category_names, rotation=45, ha='right')
    ax2.set_yticklabels(category_names)
    ax2.set_title('Label Co-occurrence Matrix')
    for i in range(4):
        for j in range(4):
            ax2.text(j, i, int(cooccurrence[i, j]),
                     ha="center", va="center", color="black")
    plt.colorbar(im, ax=ax2)
    # 3. Frequency of multi-label combinations
    ax3 = axes[1, 0]
    label_combos = [tuple(row) for row in labels_matrix]
    combo_counts = Counter(label_combos)
    top_combos = combo_counts.most_common(8)
    combo_labels = ['+'.join([category_names[i] for i, v in enumerate(combo) if v])
                    if any(combo) else 'clean' for combo, _ in top_combos]
    combo_values = [count for _, count in top_combos]
    y_pos = np.arange(len(combo_labels))
    ax3.barh(y_pos, combo_values, color='steelblue', alpha=0.7)
    ax3.set_yticks(y_pos)
    ax3.set_yticklabels(combo_labels, fontsize=9)
    ax3.set_xlabel('Frequency')
    ax3.set_title('Top Label Combinations')
    ax3.invert_yaxis()
    # 4. Sample length distribution (whitespace-separated words, not RoBERTa tokens)
    ax4 = axes[1, 1]
    lengths = [len(d['text'].split()) for d in data]
    ax4.hist(lengths, bins=50, color='green', alpha=0.6, edgecolor='black')
    ax4.axvline(np.mean(lengths), color='red', linestyle='--',
                label=f'Mean: {np.mean(lengths):.1f}')
    ax4.set_xlabel('Text Length (words)')
    ax4.set_ylabel('Frequency')
    ax4.set_title('Text Length Distribution')
    ax4.legend()
    plt.tight_layout()
    plt.savefig(save_path, dpi=300, bbox_inches='tight')
    plt.close(fig)
    print(f"[INFO] Distribution visualization saved to {save_path}")


def main():
    parser = argparse.ArgumentParser(description='Prepare harmfulness classification dataset')
    parser.add_argument('--hh_path', type=str, required=True, help='Path to HH-RLHF red teaming data (JSON Lines)')
    parser.add_argument('--glados_path', type=str, required=True, help='Path to GlaDOS CSV')
    parser.add_argument('--output', type=str, default='./processed_data', help='Output directory')
    parser.add_argument('--seed', type=int, default=42, help='Random seed')
    args = parser.parse_args()
    set_seed(args.seed)
    os.makedirs(args.output, exist_ok=True)
    print("[INFO] Loading HH-RLHF red teaming data...")
    hh_data = parse_hh_rlhf_red_teaming(args.hh_path)
    print(f"[INFO] Loaded {len(hh_data)} samples from HH-RLHF")
    print("[INFO] Loading GlaDOS data...")
    glados_data = parse_glados_data(args.glados_path)
    print(f"[INFO] Loaded {len(glados_data)} samples from GlaDOS")
    # Data fusion: HH-RLHF makes up 80% of the mix and GlaDOS 20%, i.e. a 4:1 ratio.
    # Keep as much HH-RLHF data as possible and subsample whichever source is in excess.
    random.shuffle(hh_data)
    random.shuffle(glados_data)
    glados_keep = min(len(glados_data), len(hh_data) // 4)
    hh_keep = min(len(hh_data), glados_keep * 4)
    combined_data = hh_data[:hh_keep] + glados_data[:glados_keep]
    random.shuffle(combined_data)
    # Text cleaning
    for item in combined_data:
        item['text'] = clean_text(item['text'])
    print(f"[INFO] Combined dataset size: {len(combined_data)}")
    # Split the dataset (80/10/10)
    train, val, test = stratified_multi_label_split(combined_data, seed=args.seed)
    # Save the splits
    splits = {'train': train, 'validation': val, 'test': test}
    for split_name, split_data in splits.items():
        output_file = os.path.join(args.output, f'{split_name}.json')
        with open(output_file, 'w', encoding='utf-8') as f:
            json.dump(split_data, f, ensure_ascii=False, indent=2)
        print(f"[INFO] {split_name}: {len(split_data)} samples -> {output_file}")
    # Visualize the training-set distribution
    viz_path = os.path.join(args.output, 'class_distribution.png')
    visualize_class_distribution(train, viz_path)
    # Save the category mapping
    with open(os.path.join(args.output, 'category_mapping.json'), 'w', encoding='utf-8') as f:
        json.dump(CATEGORY_MAPPING, f, indent=2)
    print("[INFO] Data preparation completed successfully")


if __name__ == '__main__':
    main()
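Script 2: Focal Loss and RoBERTa Model Definition
Description: implements the multi-label Focal Loss, an asymmetric-loss variant, the RoBERTa-based multi-dimensional classifier, and a helper that derives per-class Focal Loss weights from the training data.
Usage: run python 02_model_definition.py to execute the built-in smoke test. Alternatively, load the module from the training script; because its name starts with a digit, use importlib.import_module('02_model_definition') rather than an import statement.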
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Script 2: Focal Loss and Model Architecture
PyTorch implementation of the multi-label harmfulness classifier, including Focal Loss and the RoBERTa architecture
Usage (the module name starts with a digit, so load it through importlib):
    import importlib
    model_def = importlib.import_module('02_model_definition')
    HarmfulnessClassifier, FocalLoss = model_def.HarmfulnessClassifier, model_def.FocalLoss
"""
import contextlib
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import RobertaModel, RobertaTokenizer, RobertaConfig
from typing import Dict, List, Optional, Tuple
import numpy as np


class FocalLoss(nn.Module):
    """
    Multi-label Focal Loss.
    Addresses class imbalance by down-weighting easy samples and focusing on hard ones.
    Formula: FL = -alpha_t * (1 - p_t)^gamma * log(p_t)
    Expects probabilities. It uses F.binary_cross_entropy, which autocast rejects, so call it
    outside an autocast region on float32 inputs.
    """
    def __init__(
        self,
        alpha: Optional[torch.Tensor] = None,
        gamma: float = 2.0,
        reduction: str = 'mean',
        num_classes: int = 4
    ):
        super().__init__()
        self.gamma = gamma
        self.reduction = reduction
        self.num_classes = num_classes
        # Per-class weight of positives (negatives get 1 - alpha); values must lie in [0, 1]
        if alpha is None:
            # Default weights; can be tuned after analyzing the data
            alpha = torch.full((num_classes,), 0.25)
        # Registered as a buffer so that criterion.to(device) moves it with the module
        self.register_buffer('alpha', torch.as_tensor(alpha, dtype=torch.float32))

    def forward(
        self,
        inputs: torch.Tensor,
        targets: torch.Tensor
    ) -> torch.Tensor:
        """
        Args:
            inputs: predicted probabilities [batch_size, num_classes], after sigmoid
            targets: ground-truth labels [batch_size, num_classes], 0 or 1
        """
        # Base BCE loss, i.e. -log(p_t)
        bce_loss = F.binary_cross_entropy(inputs, targets, reduction='none')
        # p_t: probability of the true class (p for positives, 1 - p for negatives)
        pt = torch.where(targets == 1, inputs, 1 - inputs)
        # alpha_t: alpha for positives, 1 - alpha for negatives
        alpha = self.alpha.to(inputs.device).unsqueeze(0).expand_as(targets)
        alpha_t = torch.where(targets == 1, alpha, 1 - alpha)
        # Focal modulation
        focal_weight = alpha_t * (1.0 - pt) ** self.gamma
        loss = focal_weight * bce_loss
        if self.reduction == 'mean':
            return loss.mean()
        elif self.reduction == 'sum':
            return loss.sum()
        else:
            return loss


class HarmfulnessClassifier(nn.Module):
    """
    RoBERTa-based multi-dimensional harmfulness classifier.
    Architecture:
    1. RoBERTa encoder extracts contextual features
    2. Independent classification heads (one per category)
    3. Per-class decision thresholds at inference time
    4. Category-specific dropout rates
    """
    def __init__(
        self,
        model_name: str = 'roberta-large',
        num_classes: int = 4,
        dropout_rate: float = 0.1,
        class_dropouts: Optional[List[float]] = None,
        freeze_layers: int = 6  # freeze the bottom 6 layers, fine-tune only the upper ones
    ):
        super().__init__()
        self.config = RobertaConfig.from_pretrained(model_name)
        self.roberta = RobertaModel.from_pretrained(model_name)
        # Selective freezing: embeddings and the lower transformer layers
        if freeze_layers > 0:
            # Freeze embeddings
            for param in self.roberta.embeddings.parameters():
                param.requires_grad = False
            # Freeze the first freeze_layers layers
            for layer in self.roberta.encoder.layer[:freeze_layers]:
                for param in layer.parameters():
                    param.requires_grad = False
        self.hidden_size = self.config.hidden_size
        self.num_classes = num_classes
        # Category-specific dropout (different regularization strength per risk category)
        if class_dropouts is None:
            class_dropouts = [dropout_rate] * num_classes
        self.dropouts = nn.ModuleList([nn.Dropout(p) for p in class_dropouts])
        # Independent heads: one linear layer per category
        # Motivation: different categories rely on different semantic features
        self.classifiers = nn.ModuleList([
            nn.Linear(self.hidden_size, 1) for _ in range(num_classes)
        ])
        # Initialize head weights
        for classifier in self.classifiers:
            nn.init.xavier_uniform_(classifier.weight)
            nn.init.zeros_(classifier.bias)

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        return_hidden: bool = False
    ) -> Dict[str, torch.Tensor]:
        """
        Forward pass.
        Returns:
            dict with:
            - logits: [batch_size, num_classes]
            - probs: probabilities after sigmoid
            - hidden_states: optional last-layer hidden states
        """
        # RoBERTa encoding
        outputs = self.roberta(
            input_ids=input_ids,
            attention_mask=attention_mask,
            return_dict=True
        )
        # Representation of the first token (<s>, RoBERTa's [CLS])
        pooled_output = outputs.last_hidden_state[:, 0, :]
        # Independent prediction per category
        logits = [
            classifier(dropout(pooled_output))
            for dropout, classifier in zip(self.dropouts, self.classifiers)
        ]
        logits = torch.cat(logits, dim=1)  # [batch_size, num_classes]
        probs = torch.sigmoid(logits)
        result = {
            'logits': logits,
            'probs': probs
        }
        if return_hidden:
            result['hidden_states'] = outputs.last_hidden_state
        return result

    @torch.no_grad()
    def predict_with_threshold(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        thresholds: Optional[List[float]] = None
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Predict with optimized thresholds (inference only, no gradients).
        Args:
            thresholds: per-class decision thresholds, 0.5 by default
        """
        if thresholds is None:
            thresholds = [0.5] * self.num_classes
        outputs = self.forward(input_ids, attention_mask)
        probs = outputs['probs']
        # Apply per-class thresholds (broadcast over the batch dimension)
        thresholds_tensor = torch.tensor(thresholds, dtype=probs.dtype, device=probs.device)
        predictions = (probs >= thresholds_tensor).float()
        return predictions, probs


class AsymmetricLoss(nn.Module):
    """
    Asymmetric loss: different gamma for positive (harmful) and negative (safe) labels.
    Negatives use the larger gamma (gamma_neg) so that the many easy safe samples are strongly
    down-weighted; positives use the smaller gamma (gamma_pos) so that their gradient is kept.
    `clip` shifts negative probabilities by a margin, so very easy negatives (p < clip)
    contribute zero loss.
    """
    def __init__(
        self,
        gamma_pos: float = 1.0,
        gamma_neg: float = 4.0,
        clip: float = 0.05,
        eps: float = 1e-8,
        disable_torch_grad_focal_loss: bool = False
    ):
        super().__init__()
        self.gamma_pos = gamma_pos
        self.gamma_neg = gamma_neg
        self.clip = clip
        self.disable_torch_grad_focal_loss = disable_torch_grad_focal_loss
        self.eps = eps

    def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: predicted probabilities (after sigmoid)
            y: target labels
        """
        # Clamp probabilities for numerical stability
        x = torch.clamp(x, min=self.eps, max=1 - self.eps)
        xs_pos = x
        xs_neg = 1 - x
        # Probability margin: negatives with p < clip end up with zero loss
        if self.clip is not None and self.clip > 0:
            xs_neg = (xs_neg + self.clip).clamp(max=1)
        # Cross-entropy terms for positives and negatives
        los_pos = y * torch.log(xs_pos)
        los_neg = (1 - y) * torch.log(xs_neg.clamp(min=self.eps))
        loss = los_pos + los_neg
        # Asymmetric focusing
        if self.gamma_neg > 0 or self.gamma_pos > 0:
            # Optionally compute the focusing weight without tracking gradients
            ctx = torch.no_grad() if self.disable_torch_grad_focal_loss else contextlib.nullcontext()
            with ctx:
                pt = xs_pos * y + xs_neg * (1 - y)
                gamma = self.gamma_pos * y + self.gamma_neg * (1 - y)
                weight = torch.pow(1 - pt, gamma)
            loss = loss * weight
        return -loss.mean()


def get_class_weights(dataset_samples: List[Dict]) -> torch.Tensor:
    """
    Compute the per-class Focal Loss alpha from dataset statistics.
    alpha_c = 1 - (positive frequency of class c): the rarer the positives, the larger their
    weight. Values are clipped to [0.05, 0.95] so that both alpha and 1 - alpha stay valid
    weights in [0, 1].
    """
    labels = np.array([d['labels'] for d in dataset_samples], dtype=np.float32)
    pos_freq = labels.mean(axis=0)
    weights = np.clip(1.0 - pos_freq, 0.05, 0.95)
    return torch.tensor(weights, dtype=torch.float32)


def test_model():
    """Smoke test: check the forward pass and the loss computation"""
    print("[TEST] Initializing model...")
    model = HarmfulnessClassifier(
        model_name='roberta-base',  # the base model is used for testing
        num_classes=4,
        dropout_rate=0.1
    )
    tokenizer = RobertaTokenizer.from_pretrained('roberta-base')
    # Simulated inputs
    texts = [
        "How to make a bomb?",  # violence
        "I love this product",  # safe
        "Kill all people",      # violence + hate
    ]
    encodings = tokenizer(
        texts,
        padding=True,
        truncation=True,
        max_length=512,
        return_tensors='pt'
    )
    print("[TEST] Forward pass...")
    with torch.no_grad():
        outputs = model(
            input_ids=encodings['input_ids'],
            attention_mask=encodings['attention_mask']
        )
    print(f"[TEST] Output shape: {outputs['probs'].shape}")
    print(f"[TEST] Probabilities:\n{outputs['probs']}")
    # Test Focal Loss
    labels = torch.tensor([
        [1, 0, 0, 0],  # violence
        [0, 0, 0, 0],  # safe
        [1, 1, 0, 0],  # violence + hate
    ], dtype=torch.float32)
    criterion = FocalLoss(alpha=torch.tensor([0.3, 0.25, 0.25, 0.2]), gamma=2.0)
    loss = criterion(outputs['probs'], labels)
    print(f"[TEST] Focal Loss: {loss.item():.4f}")
    print("[TEST] All tests passed!")


if __name__ == '__main__':
    test_model()
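Script 3: Model Training and Validation
Description: implements the full training loop, including linear learning-rate scheduling with warmup, early stopping, validation-set F1 monitoring, checkpoint saving, and TensorBoard logging.
Usage: python 03_training.py --data_dir ./processed_data --output_dir ./models --epochs 10 --batch_size 16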
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Script 3: Model Training and Validation
Full training pipeline for the harmfulness classifier, with optional mixed-precision training
Usage:
    python 03_training.py \
        --data_dir ./processed_data \
        --output_dir ./models \
        --model_name roberta-large \
        --epochs 10 \
        --batch_size 16 \
        --lr 2e-5 \
        --gamma 2.0 \
        --use_amp
"""
import os
import json
import argparse
import importlib
import logging
from datetime import datetime
from typing import Dict, List, Optional, Tuple
import numpy as np
from tqdm import tqdm
from sklearn.metrics import f1_score, precision_score, recall_score, classification_report
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torch.cuda.amp import autocast, GradScaler
from transformers import RobertaTokenizer, get_linear_schedule_with_warmup
from torch.utils.tensorboard import SummaryWriter

# Import the local module. "from 02_model_definition import ..." is a SyntaxError because the
# module name starts with a digit, so it is loaded through importlib instead.
_model_def = importlib.import_module('02_model_definition')
HarmfulnessClassifier = _model_def.HarmfulnessClassifier
FocalLoss = _model_def.FocalLoss
get_class_weights = _model_def.get_class_weights

# Logging setup
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
    handlers=[
        logging.FileHandler('training.log'),
        logging.StreamHandler()
    ]
)
logger = logging.getLogger(__name__)


class HarmfulnessDataset(Dataset):
    """Dataset wrapper for harmfulness classification"""
    def __init__(
        self,
        data: List[Dict],
        tokenizer: RobertaTokenizer,
        max_length: int = 512
    ):
        self.data = data
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self) -> int:
        return len(self.data)

    def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
        item = self.data[idx]
        text = item['text']
        labels = torch.tensor(item['labels'], dtype=torch.float32)
        encoding = self.tokenizer(
            text,
            max_length=self.max_length,
            padding='max_length',
            truncation=True,
            return_tensors='pt'
        )
        return {
            'input_ids': encoding['input_ids'].squeeze(0),
            'attention_mask': encoding['attention_mask'].squeeze(0),
            'labels': labels,
            'text': text  # kept for error analysis
        }


def collate_fn(batch: List[Dict]) -> Dict[str, torch.Tensor]:
    """Batch collation function"""
    input_ids = torch.stack([item['input_ids'] for item in batch])
    attention_mask = torch.stack([item['attention_mask'] for item in batch])
    labels = torch.stack([item['labels'] for item in batch])
    return {
        'input_ids': input_ids,
        'attention_mask': attention_mask,
        'labels': labels,
        'texts': [item['text'] for item in batch]
    }


class EarlyStopping:
    """Early stopping to prevent overfitting.
    mode='max' for metrics where higher is better (e.g. macro-F1), mode='min' for losses."""
    def __init__(self, patience: int = 3, min_delta: float = 0.001, mode: str = 'max'):
        self.patience = patience
        self.min_delta = min_delta
        self.mode = mode
        self.counter = 0
        self.best_score = None
        self.early_stop = False

    def __call__(self, val_score: float) -> bool:
        if self.best_score is None:
            self.best_score = val_score
            return self.early_stop
        if self.mode == 'max':
            improved = val_score > self.best_score + self.min_delta
        else:
            improved = val_score < self.best_score - self.min_delta
        if improved:
            self.best_score = val_score
            self.counter = 0
        else:
            self.counter += 1
            if self.counter >= self.patience:
                self.early_stop = True
        return self.early_stop


def compute_metrics(
    predictions: np.ndarray,
    labels: np.ndarray,
    threshold: float = 0.5
) -> Dict[str, float]:
    """
    Compute multi-label classification metrics.
    Returns per-class precision, recall, and F1, plus their macro averages.
    """
    binary_preds = (predictions >= threshold).astype(int)
    metrics = {}
    category_names = ['violence', 'hate_speech', 'sexual_content', 'self_harm']
    # Per-class metrics
    for i, name in enumerate(category_names):
        metrics[f'{name}_precision'] = precision_score(labels[:, i], binary_preds[:, i], zero_division=0)
        metrics[f'{name}_recall'] = recall_score(labels[:, i], binary_preds[:, i], zero_division=0)
        metrics[f'{name}_f1'] = f1_score(labels[:, i], binary_preds[:, i], zero_division=0)
    # Macro averages
    metrics['macro_f1'] = np.mean([metrics[f'{name}_f1'] for name in category_names])
    metrics['macro_precision'] = np.mean([metrics[f'{name}_precision'] for name in category_names])
    metrics['macro_recall'] = np.mean([metrics[f'{name}_recall'] for name in category_names])
    return metrics


def train_epoch(
    model: nn.Module,
    dataloader: DataLoader,
    optimizer: torch.optim.Optimizer,
    criterion: nn.Module,
    device: torch.device,
    scheduler: torch.optim.lr_scheduler._LRScheduler,
    scaler: Optional[GradScaler],
    use_amp: bool,
    epoch: int,
    writer: SummaryWriter,
    global_step: int
) -> Tuple[float, int]:
    """One training epoch"""
    model.train()
    total_loss = 0
    num_batches = len(dataloader)
    progress = tqdm(dataloader, desc=f'Epoch {epoch} [Train]')
    for batch in progress:
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        labels = batch['labels'].to(device)
        optimizer.zero_grad()
        # Mixed-precision forward pass
        with autocast(enabled=use_amp):
            outputs = model(input_ids, attention_mask)
        # FocalLoss uses binary_cross_entropy, which autocast rejects: compute it in float32
        loss = criterion(outputs['probs'].float(), labels)
        # Backward pass
        if use_amp:
            scaler.scale(loss).backward()
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            scaler.step(optimizer)
            scaler.update()
        else:
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()
        if scheduler is not None:
            scheduler.step()
        total_loss += loss.item()
        global_step += 1
        # Log learning rate and step loss
        writer.add_scalar('Train/Learning_Rate', optimizer.param_groups[0]['lr'], global_step)
        writer.add_scalar('Train/Loss_Step', loss.item(), global_step)
        progress.set_postfix({'loss': f'{loss.item():.4f}'})
    avg_loss = total_loss / num_batches
    return avg_loss, global_step


def validate(
    model: nn.Module,
    dataloader: DataLoader,
    criterion: nn.Module,
    device: torch.device,
    epoch: int,
    writer: SummaryWriter
) -> Tuple[float, Dict[str, float], np.ndarray, np.ndarray]:
    """Validation-set evaluation"""
    model.eval()
    total_loss = 0
    all_preds = []
    all_labels = []
    with torch.no_grad():
        progress = tqdm(dataloader, desc=f'Epoch {epoch} [Val]')
        for batch in progress:
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            labels = batch['labels'].to(device)
            outputs = model(input_ids, attention_mask)
            loss = criterion(outputs['probs'], labels)
            total_loss += loss.item()
            all_preds.append(outputs['probs'].cpu().numpy())
            all_labels.append(labels.cpu().numpy())
            progress.set_postfix({'loss': f'{loss.item():.4f}'})
    all_preds = np.concatenate(all_preds, axis=0)
    all_labels = np.concatenate(all_labels, axis=0)
    avg_loss = total_loss / len(dataloader)
    metrics = compute_metrics(all_preds, all_labels)
    # Log to TensorBoard
    writer.add_scalar('Val/Loss', avg_loss, epoch)
    writer.add_scalar('Val/Macro_F1', metrics['macro_f1'], epoch)
    for name in ['violence', 'hate_speech', 'sexual_content', 'self_harm']:
        writer.add_scalar(f'Val/F1/{name}', metrics[f'{name}_f1'], epoch)
    return avg_loss, metrics, all_preds, all_labels


def find_optimal_thresholds(
    val_probs: np.ndarray,
    val_labels: np.ndarray,
    num_steps: int = 100
) -> List[float]:
    """
    Search the optimal threshold of each category on the validation set,
    maximizing that category's F1 score.
    """
    thresholds = []
    category_names = ['violence', 'hate_speech', 'sexual_content', 'self_harm']
    for i in range(4):
        best_threshold = 0.5
        best_f1 = 0
        # Search within [0.1, 0.9]
        for threshold in np.linspace(0.1, 0.9, num_steps):
            preds = (val_probs[:, i] >= threshold).astype(int)
            f1 = f1_score(val_labels[:, i], preds, zero_division=0)
            if f1 > best_f1:
                best_f1 = f1
                best_threshold = threshold
        thresholds.append(float(best_threshold))
        logger.info(f"[Threshold] {category_names[i]}: {best_threshold:.3f} (F1: {best_f1:.4f})")
    return thresholds


def save_checkpoint(
    model: nn.Module,
    tokenizer: RobertaTokenizer,
    optimizer: torch.optim.Optimizer,
    epoch: int,
    metrics: Dict,
    thresholds: List[float],
    output_dir: str,
    is_best: bool = False
):
    """Save a model checkpoint"""
    checkpoint_dir = os.path.join(output_dir, f'checkpoint-epoch-{epoch}')
    if is_best:
        checkpoint_dir = os.path.join(output_dir, 'best_model')
    os.makedirs(checkpoint_dir, exist_ok=True)
    # Save the model. HarmfulnessClassifier is a plain nn.Module without save_pretrained,
    # so store its state_dict plus the RoBERTa config
    torch.save(model.state_dict(), os.path.join(checkpoint_dir, 'model_state.pt'))
    model.config.save_pretrained(checkpoint_dir)
    tokenizer.save_pretrained(checkpoint_dir)
    # Save the training state
    state = {
        'epoch': epoch,
        'optimizer_state_dict': optimizer.state_dict(),
        'metrics': metrics,
        'optimal_thresholds': thresholds
    }
    torch.save(state, os.path.join(checkpoint_dir, 'training_state.pt'))
    # Save the classifier configuration
    config = {
        'num_classes': 4,
        'categories': ['violence', 'hate_speech', 'sexual_content', 'self_harm'],
        'optimal_thresholds': thresholds,
        'metrics': metrics
    }
    with open(os.path.join(checkpoint_dir, 'classifier_config.json'), 'w', encoding='utf-8') as f:
        json.dump(config, f, indent=2)
    logger.info(f"[SAVE] Checkpoint saved to {checkpoint_dir}")


def main():
    parser = argparse.ArgumentParser(description='Train harmfulness classifier')
    parser.add_argument('--data_dir', type=str, required=True, help='Processed data directory')
    parser.add_argument('--output_dir', type=str, default='./models', help='Model output directory')
    parser.add_argument('--model_name', type=str, default='roberta-large', help='Pretrained model name')
    parser.add_argument('--epochs', type=int, default=10, help='Number of training epochs')
    parser.add_argument('--batch_size', type=int, default=16, help='Batch size')
    parser.add_argument('--lr', type=float, default=2e-5, help='Learning rate')
    parser.add_argument('--warmup_ratio', type=float, default=0.1, help='Warmup ratio')
    parser.add_argument('--gamma', type=float, default=2.0, help='Focal loss gamma')
    parser.add_argument('--alpha', type=str, default='auto', help='Focal loss alpha (auto or comma-separated)')
    parser.add_argument('--max_length', type=int, default=512, help='Max sequence length')
    parser.add_argument('--use_amp', action='store_true', help='Use mixed precision training')
    parser.add_argument('--seed', type=int, default=42, help='Random seed')
    args = parser.parse_args()
    # Random seeds
    torch.manual_seed(args.seed)
    np.random.seed(args.seed)
    # Device
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    logger.info(f"[INFO] Using device: {device}")
    # Output directories
    os.makedirs(args.output_dir, exist_ok=True)
    timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
    run_dir = os.path.join(args.output_dir, f'run_{timestamp}')
    os.makedirs(run_dir, exist_ok=True)
    # TensorBoard
    writer = SummaryWriter(os.path.join(run_dir, 'logs'))
    # Load data
    logger.info("[INFO] Loading datasets...")
    with open(os.path.join(args.data_dir, 'train.json'), 'r', encoding='utf-8') as f:
        train_data = json.load(f)
    with open(os.path.join(args.data_dir, 'validation.json'), 'r', encoding='utf-8') as f:
        val_data = json.load(f)
    # Tokenizer and model
    logger.info(f"[INFO] Loading model: {args.model_name}")
    tokenizer = RobertaTokenizer.from_pretrained(args.model_name)
    model = HarmfulnessClassifier(
        model_name=args.model_name,
        num_classes=4,
        dropout_rate=0.1,
        freeze_layers=6  # freeze the first 6 layers
    )
    model.to(device)
    # Class weights (alpha of Focal Loss)
    if args.alpha == 'auto':
        alpha_weights = get_class_weights(train_data)
        logger.info(f"[INFO] Computed class weights: {alpha_weights.tolist()}")
    else:
        alpha_weights = torch.tensor([float(x) for x in args.alpha.split(',')])
    criterion = FocalLoss(alpha=alpha_weights, gamma=args.gamma)
    criterion.to(device)
    # Datasets and loaders
    train_dataset = HarmfulnessDataset(train_data, tokenizer, args.max_length)
    val_dataset = HarmfulnessDataset(val_data, tokenizer, args.max_length)
    train_loader = DataLoader(
        train_dataset,
        batch_size=args.batch_size,
        shuffle=True,
        collate_fn=collate_fn,
        num_workers=4,
        pin_memory=True
    )
    val_loader = DataLoader(
        val_dataset,
        batch_size=args.batch_size * 2,
        shuffle=False,
        collate_fn=collate_fn,
        num_workers=4
    )
    # Optimizer and scheduler (linear warmup, then linear decay)
    optimizer = torch.optim.AdamW(model.parameters(), lr=args.lr, weight_decay=0.01)
    total_steps = len(train_loader) * args.epochs
    warmup_steps = int(total_steps * args.warmup_ratio)
    scheduler = get_linear_schedule_with_warmup(
        optimizer,
        num_warmup_steps=warmup_steps,
        num_training_steps=total_steps
    )
    # Mixed-precision training
    scaler = GradScaler() if args.use_amp else None
    # Early stopping on validation macro-F1 (higher is better)
    early_stopping = EarlyStopping(patience=3, min_delta=0.001, mode='max')
    best_macro_f1 = 0
    # Training loop
    global_step = 0
    for epoch in range(1, args.epochs + 1):
        logger.info(f"\n[INFO] Starting Epoch {epoch}/{args.epochs}")
        # Training
        train_loss, global_step = train_epoch(
            model, train_loader, optimizer, criterion, device,
            scheduler, scaler, args.use_amp, epoch, writer, global_step
        )
        # Validation
        val_loss, metrics, val_probs, val_labels = validate(
            model, val_loader, criterion, device, epoch, writer
        )
        # Epoch-level metrics
        logger.info(f"[Epoch {epoch}] Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}")
        logger.info(f"[Epoch {epoch}] Macro F1: {metrics['macro_f1']:.4f}")
        for name in ['violence', 'hate_speech', 'sexual_content', 'self_harm']:
            logger.info(f"  - {name}: F1={metrics[f'{name}_f1']:.4f}, "
                        f"P={metrics[f'{name}_precision']:.4f}, "
                        f"R={metrics[f'{name}_recall']:.4f}")
        # Threshold optimization (only in the last two epochs or when the model improves)
        if epoch > args.epochs - 2 or metrics['macro_f1'] > best_macro_f1:
            thresholds = find_optimal_thresholds(val_probs, val_labels)
else:
thresholds = [0.5, 0.5, 0.5, 0.5]
# 保存最佳模型
is_best = metrics['macro_f1'] > best_macro_f1
if is_best:
best_macro_f1 = metrics['macro_f1']
save_checkpoint(
model, tokenizer, optimizer, epoch, metrics,
thresholds, run_dir, is_best=True
)
# 早停检查
if early_stopping(metrics['macro_f1']):
logger.info("[INFO] Early stopping triggered")
break
# 保存最终检查点
save_checkpoint(
model, tokenizer, optimizer, epoch, metrics,
thresholds, run_dir, is_best=False
)
writer.close()
logger.info("[INFO] Training completed!")
if __name__ == '__main__':
main()
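Script 3 sets `total_steps = len(train_loader) * args.epochs` and `warmup_steps = int(total_steps * args.warmup_ratio)`, then passes both to `get_linear_schedule_with_warmup`. That scheduler multiplies the base learning rate by a factor that ramps linearly from 0 to 1 during warmup and then decays linearly to 0 at `total_steps`. The framework-free sketch below reproduces that multiplier so you can check the schedule by hand. The helper name `linear_warmup_factor` and the step counts are illustrative assumptions, not values from an actual training run.

```python
def linear_warmup_factor(step: int, warmup_steps: int, total_steps: int) -> float:
    """LR multiplier: linear warmup from 0 to 1, then linear decay to 0 at total_steps."""
    if step < warmup_steps:
        # Warmup phase: ramp linearly toward 1
        return step / max(1, warmup_steps)
    # Decay phase: 1 at the end of warmup, 0 at total_steps, never negative
    return max(0.0, (total_steps - step) / max(1, total_steps - warmup_steps))


# Illustrative numbers only: 1000 batches per epoch, 10 epochs, warmup_ratio = 0.1
steps_per_epoch, epochs, warmup_ratio, base_lr = 1000, 10, 0.1, 2e-5
total_steps = steps_per_epoch * epochs          # same formula as main()
warmup_steps = int(total_steps * warmup_ratio)  # same formula as main()

for step in (0, 500, 1000, 5500, 10000):
    factor = linear_warmup_factor(step, warmup_steps, total_steps)
    print(f"step {step:>5}: lr = {base_lr * factor:.2e}")
```

With these illustrative numbers, the learning rate peaks at `2e-5` at step 1000. By step 5500 it has decayed back to half that value.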
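Script 4: Evaluation and Visualization Analysis
Description: runs test-set evaluation and produces several outputs: per-category confusion matrices, ROC and precision-recall curves, a threshold-sensitivity analysis, a comparison of true vs. predicted label distributions, and an error-case report. Attention-based interpretability is provided separately by `get_attention_weights` in Script 5.
Usage: python 04_evaluation.py --model_path ./models/best_model --test_data ./processed_data/test.json --output_dir ./evaluation_results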
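Script 4 turns the `[N, 4]` probability matrix into hard labels using one threshold per category. The expression `y_probs >= np.array(thresholds)` broadcasts the threshold vector across every row. The script then reports per-category F1, Hamming loss, and the macro Jaccard index. The NumPy sketch below spells out that arithmetic on a hand-made toy example so you can sanity-check the sklearn numbers. `multilabel_summary` and all array values are hypothetical, not results from this classifier.

```python
import numpy as np


def multilabel_summary(y_true: np.ndarray, y_probs: np.ndarray, thresholds):
    """Binarize [N, C] probabilities column-wise and compute simple multi-label metrics."""
    # Broadcasting: the length-C threshold vector is compared against every row
    y_pred = (y_probs >= np.asarray(thresholds)).astype(int)
    tp = ((y_pred == 1) & (y_true == 1)).sum(axis=0)
    fp = ((y_pred == 1) & (y_true == 0)).sum(axis=0)
    fn = ((y_pred == 0) & (y_true == 1)).sum(axis=0)
    # Per-category F1 = 2TP / (2TP + FP + FN), 0 when undefined (like zero_division=0)
    f1_denom = 2 * tp + fp + fn
    f1 = np.where(f1_denom > 0, 2 * tp / np.maximum(f1_denom, 1), 0.0)
    # Hamming loss: fraction of individual (sample, category) labels that are wrong
    hamming = float((y_pred != y_true).mean())
    # Macro Jaccard: per-category TP / (TP + FP + FN), averaged over categories
    j_denom = tp + fp + fn
    jaccard = np.where(j_denom > 0, tp / np.maximum(j_denom, 1), 0.0)
    return y_pred, f1, hamming, float(jaccard.mean())


# Toy data: 4 samples x 4 categories (violence, hate_speech, sexual_content, self_harm)
y_true = np.array([[1, 0, 0, 0], [0, 1, 0, 0], [1, 0, 0, 1], [0, 0, 1, 0]])
y_probs = np.array([[0.9, 0.1, 0.2, 0.1],
                    [0.3, 0.3, 0.1, 0.2],
                    [0.6, 0.2, 0.1, 0.7],
                    [0.2, 0.1, 0.8, 0.55]])
thresholds = [0.5, 0.35, 0.5, 0.5]

y_pred, f1, hamming, jaccard = multilabel_summary(y_true, y_probs, thresholds)
print(y_pred, f1, hamming, jaccard, sep="\n")
```

In this toy case there are two errors: one false negative in `hate_speech` and one false positive in `self_harm`. The two bad labels out of 16 give a Hamming loss of 0.125. Each error lowers only its own category's F1, which is why the script checks the F1 > 0.85 target per category instead of relying on the macro average alone.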
```python
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Script 4: Comprehensive Evaluation and Visualization
Performance evaluation and visual analysis for the multi-dimensional safety classifier.
Usage:
    python 04_evaluation.py \
        --model_path ./models/best_model \
        --test_data ./processed_data/test.json \
        --output_dir ./evaluation_results \
        --batch_size 32
"""
import os
import json
import argparse
import logging
import importlib
from typing import Dict, List, Tuple
from collections import defaultdict
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import (
    confusion_matrix, classification_report, roc_curve, auc,
    precision_recall_curve, average_precision_score, hamming_loss,
    jaccard_score, f1_score, precision_score, recall_score
)
from tqdm import tqdm
import torch
from torch.utils.data import DataLoader
from transformers import RobertaTokenizer

# Import local modules. "from 02_model_definition import ..." is a SyntaxError because
# a module name in an import statement cannot start with a digit, so use importlib.
HarmfulnessClassifier = importlib.import_module('02_model_definition').HarmfulnessClassifier
_training = importlib.import_module('03_training')
HarmfulnessDataset = _training.HarmfulnessDataset
collate_fn = _training.collate_fn

logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

# Category definitions
CATEGORIES = ['violence', 'hate_speech', 'sexual_content', 'self_harm']
COLORS = ['#e74c3c', '#e67e22', '#f39c12', '#8e44ad']


def load_model_and_config(model_path: str, device: torch.device):
    """Load the trained model, its tokenizer, and the per-category thresholds."""
    logger.info(f"[INFO] Loading model from {model_path}")
    # Load the config; fall back to 0.5 thresholds if it is missing
    config_path = os.path.join(model_path, 'classifier_config.json')
    if os.path.exists(config_path):
        with open(config_path, 'r') as f:
            config = json.load(f)
        thresholds = config.get('optimal_thresholds', [0.5, 0.5, 0.5, 0.5])
    else:
        thresholds = [0.5, 0.5, 0.5, 0.5]
    # Load the model
    tokenizer = RobertaTokenizer.from_pretrained(model_path)
    model = HarmfulnessClassifier.from_pretrained(model_path)
    model.to(device)
    model.eval()
    return model, tokenizer, thresholds


def evaluate_model(
    model: HarmfulnessClassifier,
    dataloader: DataLoader,
    device: torch.device,
    thresholds: List[float]
) -> Tuple[np.ndarray, np.ndarray, List[str]]:
    """
    Run inference and collect the predictions.
    Returns:
        all_probs: predicted probabilities [N, 4]
        all_labels: ground-truth labels [N, 4]
        all_texts: list of raw input texts
    Note: `thresholds` is not used here; the caller binarizes the probabilities.
    """
    all_probs = []
    all_labels = []
    all_texts = []
    with torch.no_grad():
        for batch in tqdm(dataloader, desc='Evaluating'):
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            labels = batch['labels']
            outputs = model(input_ids, attention_mask)
            probs = outputs['probs'].cpu().numpy()
            all_probs.append(probs)
            # Cast to int so the labels compare cleanly with the 0/1 predictions
            all_labels.append(labels.numpy().astype(int))
            all_texts.extend(batch['texts'])
    all_probs = np.concatenate(all_probs, axis=0)
    all_labels = np.concatenate(all_labels, axis=0)
    return all_probs, all_labels, all_texts


def plot_confusion_matrices(
    y_true: np.ndarray,
    y_pred: np.ndarray,
    output_dir: str
):
    """Plot a confusion matrix (raw counts) for each category."""
    fig, axes = plt.subplots(2, 2, figsize=(12, 10))
    axes = axes.flatten()
    for idx, (cat, color) in enumerate(zip(CATEGORIES, COLORS)):
        # labels=[0, 1] keeps the matrix 2x2 even if a category contains only one class
        cm = confusion_matrix(y_true[:, idx], y_pred[:, idx], labels=[0, 1])
        sns.heatmap(
            cm,
            annot=True,
            fmt='d',
            cmap='Blues',
            ax=axes[idx],
            cbar=True,
            square=True,
            xticklabels=['Safe', 'Harmful'],
            yticklabels=['Safe', 'Harmful']
        )
        axes[idx].set_title(f'{cat.replace("_", " ").title()}\nAccuracy: {np.trace(cm)/np.sum(cm):.3f}')
        axes[idx].set_ylabel('True Label')
        axes[idx].set_xlabel('Predicted Label')
    plt.tight_layout()
    plt.savefig(os.path.join(output_dir, 'confusion_matrices.png'), dpi=300, bbox_inches='tight')
    plt.close()
    logger.info("[SAVE] Confusion matrices saved")


def plot_roc_curves(
    y_true: np.ndarray,
    y_probs: np.ndarray,
    output_dir: str
):
    """Plot per-category ROC curves with their AUC."""
    fig, ax = plt.subplots(figsize=(10, 8))
    for idx, (cat, color) in enumerate(zip(CATEGORIES, COLORS)):
        # ROC is undefined when y_true contains only one class
        if len(np.unique(y_true[:, idx])) < 2:
            logger.warning(f"[WARN] Skipping ROC for {cat}: only one class present")
            continue
        fpr, tpr, _ = roc_curve(y_true[:, idx], y_probs[:, idx])
        roc_auc = auc(fpr, tpr)
        ax.plot(
            fpr, tpr,
            color=color,
            lw=2,
            label=f'{cat.replace("_", " ").title()} (AUC = {roc_auc:.3f})'
        )
    ax.plot([0, 1], [0, 1], 'k--', lw=1, label='Random Classifier')
    ax.set_xlim([0.0, 1.0])
    ax.set_ylim([0.0, 1.05])
    ax.set_xlabel('False Positive Rate')
    ax.set_ylabel('True Positive Rate')
    ax.set_title('ROC Curves for Multi-label Harmfulness Classification')
    ax.legend(loc='lower right')
    ax.grid(True, alpha=0.3)
    plt.tight_layout()
    plt.savefig(os.path.join(output_dir, 'roc_curves.png'), dpi=300, bbox_inches='tight')
    plt.close()


def plot_precision_recall_curves(
    y_true: np.ndarray,
    y_probs: np.ndarray,
    output_dir: str
):
    """Plot per-category precision-recall curves with average precision (AP)."""
    fig, ax = plt.subplots(figsize=(10, 8))
    for idx, (cat, color) in enumerate(zip(CATEGORIES, COLORS)):
        # A PR curve is meaningless without any positive samples
        if y_true[:, idx].sum() == 0:
            logger.warning(f"[WARN] Skipping PR curve for {cat}: no positive samples")
            continue
        precision, recall, _ = precision_recall_curve(y_true[:, idx], y_probs[:, idx])
        ap = average_precision_score(y_true[:, idx], y_probs[:, idx])
        ax.plot(
            recall, precision,
            color=color,
            lw=2,
            label=f'{cat.replace("_", " ").title()} (AP = {ap:.3f})'
        )
    ax.set_xlabel('Recall')
    ax.set_ylabel('Precision')
    ax.set_title('Precision-Recall Curves')
    ax.legend()
    ax.grid(True, alpha=0.3)
    plt.tight_layout()
    plt.savefig(os.path.join(output_dir, 'pr_curves.png'), dpi=300, bbox_inches='tight')
    plt.close()


def plot_threshold_sensitivity(
    y_true: np.ndarray,
    y_probs: np.ndarray,
    output_dir: str
):
    """Threshold sensitivity analysis: plot F1 as a function of the decision threshold."""
    fig, axes = plt.subplots(2, 2, figsize=(14, 10))
    axes = axes.flatten()
    thresholds_range = np.linspace(0.1, 0.9, 50)
    for idx, (cat, color) in enumerate(zip(CATEGORIES, COLORS)):
        f1_scores = []
        for thresh in thresholds_range:
            preds = (y_probs[:, idx] >= thresh).astype(int)
            f1 = f1_score(y_true[:, idx], preds, zero_division=0)
            f1_scores.append(f1)
        axes[idx].plot(thresholds_range, f1_scores, color=color, lw=2)
        axes[idx].set_xlabel('Threshold')
        axes[idx].set_ylabel('F1 Score')
        axes[idx].set_title(f'{cat.replace("_", " ").title()}')
        axes[idx].grid(True, alpha=0.3)
        # Mark the best threshold
        best_idx = np.argmax(f1_scores)
        best_thresh = thresholds_range[best_idx]
        best_f1 = f1_scores[best_idx]
        axes[idx].axvline(best_thresh, color='red', linestyle='--', alpha=0.7)
        axes[idx].scatter([best_thresh], [best_f1], color='red', s=100, zorder=5)
        axes[idx].text(best_thresh, best_f1, f' ({best_thresh:.2f}, {best_f1:.3f})',
                       va='bottom', ha='left')
    plt.tight_layout()
    plt.savefig(os.path.join(output_dir, 'threshold_analysis.png'), dpi=300, bbox_inches='tight')
    plt.close()


def analyze_errors(
    texts: List[str],
    y_true: np.ndarray,
    y_pred: np.ndarray,
    y_probs: np.ndarray,
    output_dir: str
):
    """Error analysis: collect false positives and false negatives per category."""
    error_analysis = defaultdict(lambda: {'false_positives': [], 'false_negatives': []})
    for i, text in enumerate(texts):
        for j, cat in enumerate(CATEGORIES):
            true_label = y_true[i, j]
            pred_label = y_pred[i, j]
            prob = y_probs[i, j]
            if true_label == 0 and pred_label == 1:
                error_analysis[cat]['false_positives'].append({
                    'text': text,
                    'confidence': float(prob)
                })
            elif true_label == 1 and pred_label == 0:
                error_analysis[cat]['false_negatives'].append({
                    'text': text,
                    'confidence': float(prob)
                })
    # Save up to 10 of the most confident mistakes per category:
    # for false positives that is the highest probability, for false negatives the lowest
    report_path = os.path.join(output_dir, 'error_analysis.json')
    filtered_errors = {}
    for cat, errors in error_analysis.items():
        filtered_errors[cat] = {
            'false_positives': sorted(errors['false_positives'],
                                      key=lambda x: x['confidence'], reverse=True)[:10],
            'false_negatives': sorted(errors['false_negatives'],
                                      key=lambda x: x['confidence'])[:10],
            'fp_count': len(errors['false_positives']),
            'fn_count': len(errors['false_negatives'])
        }
    with open(report_path, 'w', encoding='utf-8') as f:
        json.dump(filtered_errors, f, ensure_ascii=False, indent=2)
    logger.info(f"[SAVE] Error analysis saved to {report_path}")
    return filtered_errors


def plot_class_distribution_comparison(
    y_true: np.ndarray,
    y_pred: np.ndarray,
    output_dir: str
):
    """Compare the distribution of true and predicted labels."""
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 6))
    true_counts = y_true.sum(axis=0)
    pred_counts = y_pred.sum(axis=0)
    x = np.arange(len(CATEGORIES))
    width = 0.35
    bars1 = ax1.bar(x - width/2, true_counts, width, label='True', color='steelblue', alpha=0.8)
    bars2 = ax1.bar(x + width/2, pred_counts, width, label='Predicted', color='coral', alpha=0.8)
    ax1.set_ylabel('Count')
    ax1.set_title('Class Distribution: True vs Predicted')
    ax1.set_xticks(x)
    ax1.set_xticklabels([c.replace('_', '\n') for c in CATEGORIES], rotation=0)
    ax1.legend()
    # Add count labels above the bars
    for bars in [bars1, bars2]:
        for bar in bars:
            height = bar.get_height()
            ax1.text(bar.get_x() + bar.get_width()/2., height,
                     f'{int(height)}', ha='center', va='bottom', fontsize=9)
    # Pie chart of the predicted distribution
    total_preds = pred_counts.sum()
    if total_preds > 0:
        ax2.pie(pred_counts, labels=[c.replace('_', ' ').title() for c in CATEGORIES],
                colors=COLORS, autopct='%1.1f%%', startangle=90)
        ax2.set_title('Predicted Harmfulness Distribution')
    plt.tight_layout()
    plt.savefig(os.path.join(output_dir, 'distribution_comparison.png'), dpi=300, bbox_inches='tight')
    plt.close()


def generate_classification_report(
    y_true: np.ndarray,
    y_pred: np.ndarray,
    output_dir: str
):
    """Generate the classification report (JSON and plain text)."""
    report = {}
    # Detailed report for each category
    for i, cat in enumerate(CATEGORIES):
        report[cat] = {
            'precision': float(precision_score(y_true[:, i], y_pred[:, i], zero_division=0)),
            'recall': float(recall_score(y_true[:, i], y_pred[:, i], zero_division=0)),
            'f1_score': float(f1_score(y_true[:, i], y_pred[:, i], zero_division=0)),
            'support': int(y_true[:, i].sum())
        }
    # Macro averages
    macro_f1 = np.mean([report[cat]['f1_score'] for cat in CATEGORIES])
    macro_precision = np.mean([report[cat]['precision'] for cat in CATEGORIES])
    macro_recall = np.mean([report[cat]['recall'] for cat in CATEGORIES])
    report['macro_avg'] = {
        'precision': float(macro_precision),
        'recall': float(macro_recall),
        'f1_score': float(macro_f1)
    }
    # Multi-label-specific metrics
    report['hamming_loss'] = float(hamming_loss(y_true, y_pred))
    report['jaccard_score'] = float(jaccard_score(y_true, y_pred, average='macro'))
    # Save as JSON and as plain text
    with open(os.path.join(output_dir, 'classification_report.json'), 'w') as f:
        json.dump(report, f, indent=2)
    # Plain-text report
    with open(os.path.join(output_dir, 'classification_report.txt'), 'w') as f:
        f.write("HARMFULNESS CLASSIFIER EVALUATION REPORT\n")
        f.write("=" * 50 + "\n\n")
        for cat in CATEGORIES:
            f.write(f"{cat.upper().replace('_', ' ')}:\n")
            f.write(f"  Precision: {report[cat]['precision']:.4f}\n")
            f.write(f"  Recall:    {report[cat]['recall']:.4f}\n")
            f.write(f"  F1-Score:  {report[cat]['f1_score']:.4f}\n")
            f.write(f"  Support:   {report[cat]['support']}\n\n")
        f.write("-" * 50 + "\n")
        f.write("MACRO AVERAGE:\n")
        f.write(f"  Precision: {report['macro_avg']['precision']:.4f}\n")
        f.write(f"  Recall:    {report['macro_avg']['recall']:.4f}\n")
        f.write(f"  F1-Score:  {report['macro_avg']['f1_score']:.4f}\n\n")
        f.write(f"Hamming Loss: {report['hamming_loss']:.4f}\n")
        f.write(f"Jaccard Index: {report['jaccard_score']:.4f}\n")
        # Check whether every category meets the F1 > 0.85 target
        all_above_threshold = all(report[cat]['f1_score'] > 0.85 for cat in CATEGORIES)
        f.write(f"\nAll categories F1 > 0.85: {all_above_threshold}\n")
        if not all_above_threshold:
            f.write("WARNING: Some categories do not meet the F1 > 0.85 requirement!\n")
    logger.info(f"[INFO] Macro F1: {macro_f1:.4f}")
    return report


def main():
    parser = argparse.ArgumentParser(description='Evaluate harmfulness classifier')
    parser.add_argument('--model_path', type=str, required=True, help='Path to trained model')
    parser.add_argument('--test_data', type=str, required=True, help='Path to test JSON')
    parser.add_argument('--output_dir', type=str, default='./evaluation_results', help='Output directory')
    parser.add_argument('--batch_size', type=int, default=32, help='Batch size for inference')
    args = parser.parse_args()
    os.makedirs(args.output_dir, exist_ok=True)
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    # Load the model
    model, tokenizer, thresholds = load_model_and_config(args.model_path, device)
    logger.info(f"[INFO] Using thresholds: {thresholds}")
    # Load the test data
    with open(args.test_data, 'r') as f:
        test_data = json.load(f)
    logger.info(f"[INFO] Loaded {len(test_data)} test samples")
    # Build the data loader
    test_dataset = HarmfulnessDataset(test_data, tokenizer)
    test_loader = DataLoader(
        test_dataset,
        batch_size=args.batch_size,
        collate_fn=collate_fn,
        shuffle=False
    )
    # Inference, then binarize with the per-category thresholds
    y_probs, y_true, texts = evaluate_model(model, test_loader, device, thresholds)
    y_pred = (y_probs >= np.array(thresholds)).astype(int)
    # Classification report
    report = generate_classification_report(y_true, y_pred, args.output_dir)
    # Visualizations
    logger.info("[INFO] Generating visualizations...")
    plot_confusion_matrices(y_true, y_pred, args.output_dir)
    plot_roc_curves(y_true, y_probs, args.output_dir)
    plot_precision_recall_curves(y_true, y_probs, args.output_dir)
    plot_threshold_sensitivity(y_true, y_probs, args.output_dir)
    plot_class_distribution_comparison(y_true, y_pred, args.output_dir)
    # Error analysis
    logger.info("[INFO] Analyzing error cases...")
    analyze_errors(texts, y_true, y_pred, y_probs, args.output_dir)
    logger.info("[INFO] Evaluation completed!")
    # Print the key results
    print("\n" + "=" * 60)
    print("EVALUATION SUMMARY")
    print("=" * 60)
    for cat in CATEGORIES:
        f1 = report[cat]['f1_score']
        status = "✓" if f1 > 0.85 else "✗"
        print(f"{status} {cat:20s}: F1 = {f1:.4f}")
    print("-" * 60)
    print(f"Macro Average F1: {report['macro_avg']['f1_score']:.4f}")
    print("=" * 60)


if __name__ == '__main__':
    main()
```
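`plot_roc_curves` summarizes each curve with `auc(fpr, tpr)`. An equivalent way to read ROC AUC is as the probability that a randomly chosen harmful sample scores higher than a randomly chosen safe one, with ties counted as one half. The brute-force sketch below computes AUC that way. It is only meant for small arrays, because it materializes every positive/negative pair, and `pairwise_auc` is a hypothetical helper, not part of the scripts.

```python
import numpy as np


def pairwise_auc(y_true: np.ndarray, scores: np.ndarray) -> float:
    """ROC AUC as P(score_pos > score_neg) + 0.5 * P(score_pos == score_neg)."""
    pos = scores[y_true == 1]
    neg = scores[y_true == 0]
    if len(pos) == 0 or len(neg) == 0:
        raise ValueError("AUC is undefined when only one class is present")
    # Broadcast to a [P, N] matrix comparing every positive with every negative
    diff = pos[:, None] - neg[None, :]
    return float((diff > 0).mean() + 0.5 * (diff == 0).mean())


y = np.array([0, 0, 1, 1])
s = np.array([0.1, 0.4, 0.35, 0.8])
print(pairwise_auc(y, s))  # 0.75
```

Because this formulation depends only on how the scores are ordered, AUC does not change when the decision thresholds are re-optimized. In contrast, the per-category F1 values in the classification report do change.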
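Script 5: Real-time Inference and Deployment API
Description: wraps the trained model for real-time inference. It supports single-text and batch detection using the optimized per-category thresholds, returns per-category probabilities and attention-based explanations, and includes an example Flask API wrapper.
Usage: python 05_inference_api.py --model_path ./models/best_model --serve --port 5000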
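The decision logic in `HarmfulnessDetector.predict` is small enough to restate without PyTorch. A category fires when its probability reaches that category's threshold, and a text counts as harmful if any category fires. `confidence` is the maximum probability over all four categories, whether or not anything fired. Below is a dependency-free sketch using a hypothetical `decide` helper, with made-up probabilities and thresholds.

```python
from typing import Dict, Sequence

CATEGORIES = ['violence', 'hate_speech', 'sexual_content', 'self_harm']


def decide(probs: Sequence[float], thresholds: Sequence[float]) -> Dict[str, object]:
    """Apply per-category thresholds the way HarmfulnessDetector.predict does."""
    # A category fires when its probability reaches its own threshold
    categories = {cat: p >= t for cat, p, t in zip(CATEGORIES, probs, thresholds)}
    return {
        'is_harmful': any(categories.values()),  # harmful if ANY category fires
        'categories': categories,
        'confidence': max(probs),                # max probability, even if nothing fires
    }


print(decide([0.12, 0.40, 0.05, 0.61], [0.5, 0.35, 0.5, 0.7]))
```

Here `hate_speech` fires at 0.40 only because its example threshold (0.35) is below 0.5. Meanwhile, `self_harm` has the highest probability but does not fire, because its example threshold is 0.7. Both effects come from per-category threshold optimization.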
```python
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Script 5: Real-time Inference and Deployment API
Real-time inference and deployment for the harmfulness classifier, supporting
single-text and batch detection as well as an HTTP API service.
Usage:
    # Single-text test
    python 05_inference_api.py --model_path ./models/best_model --test "How to make a bomb"
    # Batch processing
    python 05_inference_api.py --model_path ./models/best_model --batch_file ./inputs.json --output ./results.json
    # Start the API server
    python 05_inference_api.py --model_path ./models/best_model --serve --port 5000
"""
import os
import json
import argparse
import logging
import importlib
from typing import Dict, List, Tuple
from dataclasses import dataclass
import torch
import numpy as np
from transformers import RobertaTokenizer

# Import local modules (a module name starting with a digit must be loaded via importlib)
HarmfulnessClassifier = importlib.import_module('02_model_definition').HarmfulnessClassifier

logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

CATEGORIES = ['violence', 'hate_speech', 'sexual_content', 'self_harm']


@dataclass
class DetectionResult:
    """Detection result data structure."""
    text: str
    is_harmful: bool
    categories: Dict[str, bool]
    probabilities: Dict[str, float]
    confidence: float  # highest probability across the four categories
    threshold_used: Dict[str, float]


class HarmfulnessDetector:
    """Wrapper around the trained harmfulness classifier."""

    def __init__(
        self,
        model_path: str,
        device: str = 'auto',
        max_length: int = 512
    ):
        if device == 'auto':
            self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        else:
            self.device = torch.device(device)
        logger.info(f"[INIT] Loading model from {model_path}")
        self.tokenizer = RobertaTokenizer.from_pretrained(model_path)
        self.model = HarmfulnessClassifier.from_pretrained(model_path)
        self.model.to(self.device)
        self.model.eval()
        self.max_length = max_length
        # Load the optimized thresholds (default 0.5 if no config is found)
        config_path = os.path.join(model_path, 'classifier_config.json')
        if os.path.exists(config_path):
            with open(config_path, 'r') as f:
                config = json.load(f)
            self.thresholds = config.get('optimal_thresholds', [0.5, 0.5, 0.5, 0.5])
        else:
            self.thresholds = [0.5, 0.5, 0.5, 0.5]
        logger.info(f"[INIT] Thresholds: {self.thresholds}")
        logger.info(f"[INIT] Device: {self.device}")

    def predict(self, text: str) -> DetectionResult:
        """Predict a single text."""
        encoding = self.tokenizer(
            text,
            max_length=self.max_length,
            padding='max_length',
            truncation=True,
            return_tensors='pt'
        )
        input_ids = encoding['input_ids'].to(self.device)
        attention_mask = encoding['attention_mask'].to(self.device)
        with torch.no_grad():
            outputs = self.model(input_ids, attention_mask)
            probs = outputs['probs'].cpu().numpy()[0]
        # Apply the per-category thresholds
        predictions = (probs >= np.array(self.thresholds)).astype(int)
        categories = {
            cat: bool(predictions[i])
            for i, cat in enumerate(CATEGORIES)
        }
        probabilities = {
            cat: float(probs[i])
            for i, cat in enumerate(CATEGORIES)
        }
        is_harmful = bool(predictions.any())
        confidence = float(probs.max())
        return DetectionResult(
            text=text,
            is_harmful=is_harmful,
            categories=categories,
            probabilities=probabilities,
            confidence=confidence,
            threshold_used={
                cat: self.thresholds[i]
                for i, cat in enumerate(CATEGORIES)
            }
        )

    def predict_batch(self, texts: List[str]) -> List[DetectionResult]:
        """Batch prediction: one padded forward pass over all texts."""
        if not texts:
            return []
        encodings = self.tokenizer(
            texts,
            max_length=self.max_length,
            padding=True,
            truncation=True,
            return_tensors='pt'
        )
        input_ids = encodings['input_ids'].to(self.device)
        attention_mask = encodings['attention_mask'].to(self.device)
        with torch.no_grad():
            outputs = self.model(input_ids, attention_mask)
            probs = outputs['probs'].cpu().numpy()
        results = []
        for i, text in enumerate(texts):
            predictions = (probs[i] >= np.array(self.thresholds)).astype(int)
            categories = {cat: bool(predictions[j]) for j, cat in enumerate(CATEGORIES)}
            probabilities = {cat: float(probs[i][j]) for j, cat in enumerate(CATEGORIES)}
            results.append(DetectionResult(
                text=text,
                is_harmful=bool(predictions.any()),
                categories=categories,
                probabilities=probabilities,
                confidence=float(probs[i].max()),
                threshold_used={cat: self.thresholds[j] for j, cat in enumerate(CATEGORIES)}
            ))
        return results

    def get_attention_weights(self, text: str) -> Tuple[np.ndarray, List[str]]:
        """
        Return last-layer attention for interpretability analysis, plus the tokens.
        Requires the model to expose its RoBERTa encoder as `self.model.roberta`.
        """
        encoding = self.tokenizer(
            text,
            max_length=self.max_length,
            padding='max_length',
            truncation=True,
            return_tensors='pt'
        )
        input_ids = encoding['input_ids'].to(self.device)
        attention_mask = encoding['attention_mask'].to(self.device)
        with torch.no_grad():
            outputs = self.model.roberta(
                input_ids,
                attention_mask=attention_mask,
                output_attentions=True
            )
        # Last-layer attention, shape [batch, heads, seq, seq]
        attentions = outputs.attentions[-1].cpu().numpy()
        # Average over heads; take attention from <s> (RoBERTa's [CLS]) to every token
        cls_attention = attentions[0, :, 0, :].mean(axis=0)
        return cls_attention, self.tokenizer.convert_ids_to_tokens(input_ids[0].tolist())


def format_result(result: DetectionResult, verbose: bool = False) -> str:
    """Format a detection result for console output."""
    lines = [
        f"Text: {result.text[:100]}{'...' if len(result.text) > 100 else ''}",
        f"Harmful: {'YES' if result.is_harmful else 'NO'} (confidence: {result.confidence:.4f})",
        "Categories detected:"
    ]
    for cat in CATEGORIES:
        status = "✓" if result.categories[cat] else "✗"
        prob = result.probabilities[cat]
        thresh = result.threshold_used[cat]
        lines.append(f"  {status} {cat:20s}: {prob:.4f} (threshold: {thresh:.2f})")
    if verbose:
        lines.append(f"\nThresholds: {result.threshold_used}")
    return "\n".join(lines)


def create_app(detector: HarmfulnessDetector):
    """Create the Flask application."""
    try:
        from flask import Flask, request, jsonify
    except ImportError:
        logger.error("[ERROR] Flask not installed. Run: pip install flask")
        raise
    app = Flask(__name__)

    @app.route('/health', methods=['GET'])
    def health():
        return jsonify({'status': 'healthy', 'model': 'harmfulness_classifier'})

    @app.route('/predict', methods=['POST'])
    def predict():
        try:
            # silent=True returns None instead of raising on a missing or invalid JSON body
            data = request.get_json(silent=True) or {}
            if 'text' not in data:
                return jsonify({'error': 'Missing text field'}), 400
            result = detector.predict(data['text'])
            return jsonify({
                'is_harmful': result.is_harmful,
                'categories': result.categories,
                'probabilities': result.probabilities,
                'confidence': result.confidence
            })
        except Exception as e:
            logger.error(f"[API ERROR] {str(e)}")
            return jsonify({'error': str(e)}), 500

    @app.route('/predict_batch', methods=['POST'])
    def predict_batch():
        try:
            data = request.get_json(silent=True) or {}
            if 'texts' not in data or not isinstance(data['texts'], list):
                return jsonify({'error': 'Missing texts field (should be list)'}), 400
            texts = data['texts'][:100]  # cap the batch size at 100 texts
            results = detector.predict_batch(texts)
            return jsonify({
                'results': [
                    {
                        'is_harmful': r.is_harmful,
                        'categories': r.categories,
                        'probabilities': r.probabilities,
                        'confidence': r.confidence
                    }
                    for r in results
                ]
            })
        except Exception as e:
            logger.error(f"[API ERROR] {str(e)}")
            return jsonify({'error': str(e)}), 500

    @app.route('/explain', methods=['POST'])
    def explain():
        try:
            data = request.get_json(silent=True) or {}
            text = data.get('text', '')
            # Get the prediction
            result = detector.predict(text)
            # Get attention weights for the explanation
            try:
                attention, tokens = detector.get_attention_weights(text)
                # Rank tokens by attention (highest first), skip special tokens, keep the top 5
                special_tokens = set(detector.tokenizer.all_special_tokens)
                ranked = np.argsort(attention)[::-1]
                explanation = [
                    {'token': tokens[i], 'attention': float(attention[i])}
                    for i in ranked if tokens[i] not in special_tokens
                ][:5]
            except Exception as e:
                explanation = []
                logger.warning(f"[EXPLAIN] Could not get attention: {e}")
            return jsonify({
                'prediction': {
                    'is_harmful': result.is_harmful,
                    'categories': result.categories
                },
                'important_tokens': explanation
            })
        except Exception as e:
            return jsonify({'error': str(e)}), 500

    return app


def main():
    parser = argparse.ArgumentParser(description='Harmfulness Detector Inference')
    parser.add_argument('--model_path', type=str, required=True, help='Path to trained model')
    parser.add_argument('--test', type=str, help='Single text to test')
    parser.add_argument('--batch_file', type=str, help='JSON file with list of texts')
    parser.add_argument('--output', type=str, help='Output file for batch results')
    parser.add_argument('--serve', action='store_true', help='Start API server')
    parser.add_argument('--port', type=int, default=5000, help='API server port')
    parser.add_argument('--device', type=str, default='auto', help='Device to use')
    args = parser.parse_args()
    # Initialize the detector
    detector = HarmfulnessDetector(args.model_path, device=args.device)
    if args.serve:
        logger.info(f"[SERVE] Starting API server on port {args.port}")
        app = create_app(detector)
        app.run(host='0.0.0.0', port=args.port, debug=False)
    elif args.test:
        logger.info("[TEST] Running single prediction")
        result = detector.predict(args.test)
        print("\n" + "=" * 60)
        print(format_result(result, verbose=True))
        print("=" * 60)
    elif args.batch_file:
        logger.info(f"[BATCH] Processing {args.batch_file}")
        with open(args.batch_file, 'r') as f:
            texts = json.load(f)
        if not isinstance(texts, list):
            texts = [texts]
        results = detector.predict_batch(texts)
        output_data = [
            {
                'text': r.text,
                'is_harmful': r.is_harmful,
                'categories': r.categories,
                'probabilities': r.probabilities
            }
            for r in results
        ]
        if args.output:
            with open(args.output, 'w') as f:
                json.dump(output_data, f, indent=2)
            logger.info(f"[SAVE] Results saved to {args.output}")
        # Print a summary (guard against an empty input file)
        harmful_count = sum(1 for r in results if r.is_harmful)
        print(f"\nProcessed {len(results)} texts")
        if results:
            print(f"Harmful detected: {harmful_count} ({100 * harmful_count / len(results):.2f}%)")
        # Per-category counts
        for cat in CATEGORIES:
            count = sum(1 for r in results if r.categories[cat])
            print(f"  - {cat}: {count}")
    else:
        print("Please specify --test, --batch_file, or --serve")


if __name__ == '__main__':
    main()
```
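Once you start the server with `--serve`, `/predict` accepts a JSON body of the form `{"text": ...}` and returns `is_harmful`, `categories`, `probabilities`, and `confidence`. Below is a minimal client sketch using only the standard library. It assumes the server listens on `http://localhost:5000`, and the helper names are hypothetical.

```python
import json
import urllib.request


def build_predict_request(text: str, base_url: str = "http://localhost:5000") -> urllib.request.Request:
    """Build a POST request for the /predict endpoint with a JSON body."""
    payload = json.dumps({"text": text}).encode("utf-8")
    return urllib.request.Request(
        f"{base_url}/predict",
        data=payload,
        headers={"Content-Type": "application/json"},
        method="POST",
    )


def predict_remote(text: str, base_url: str = "http://localhost:5000", timeout: float = 10.0) -> dict:
    """Send the request and decode the JSON response (the server must be running)."""
    with urllib.request.urlopen(build_predict_request(text, base_url), timeout=timeout) as resp:
        return json.loads(resp.read().decode("utf-8"))
```

Calling `predict_remote("some text")` requires a running server. `build_predict_request` is kept separate so you can inspect the request or reuse it with another HTTP client. The same pattern extends to `/predict_batch` by sending `{"texts": [...]}` instead.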