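Contents
[Part I: Detailed Principles](#Part I: Detailed Principles)
[1.1 Memorization Risk in Language Models](#1.1 Memorization Risk in Language Models)
[1.1.1 Formal Definition of Training Data Extraction](#1.1.1 Formal Definition of Training Data Extraction)
[1.1.2 Theoretical Basis of Membership Inference Attacks](#1.1.2 Theoretical Basis of Membership Inference Attacks)
[1.2 Quantitative Analysis of Extractable Memorization](#1.2 Quantitative Analysis of Extractable Memorization)
[1.2.1 Duplication Count and Memorization Strength](#1.2.1 Duplication Count and Memorization Strength)
[1.2.2 Email Address Extraction Scenario](#1.2.2 Email Address Extraction Scenario)
[1.3 Evaluation Framework for the Pythia Model Suite](#1.3 Evaluation Framework for the Pythia Model Suite)
[1.3.1 Model Scale and Memorization Capacity](#1.3.1 Model Scale and Memorization Capacity)
[1.3.2 The Trade-off Between Differential Privacy and Memorization](#1.3.2 The Trade-off Between Differential Privacy and Memorization)
[Part II: Structured Pseudo-code](#Part II: Structured Pseudo-code)
[2.1 Membership Inference Attack (MIA)](#2.1 Membership Inference Attack (MIA))
[2.2 Extractable Memorization Test (Email Extraction)](#2.2 Extractable Memorization Test (Email Extraction))
[2.3 Multi-Scale Memorization Rate Evaluation](#2.3 Multi-Scale Memorization Rate Evaluation)
[Part III: Code Implementation](#Part III: Code Implementation)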
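Part I: Detailed Principles
1.1 Memorization Risk in Language Models
1.1.1 Formal Definition of Training Data Extraction
During pre-training, large language models memorize parts of their training corpus, and this creates a serious privacy-leakage risk. The extractable memorization framework proposed by Carlini et al. in 2021 makes the risk measurable. Take model parameters $\theta$ and a prefix $x_{1:k}$ from a training sequence. The sequence counts as memorized if greedy decoding from that prefix reproduces the continuation $x_{k+1:n}$ exactly. In other words, the true suffix is the model's single most likely continuation, not merely one plausible guess among many.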
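Formally, for a sequence of $n$ tokens $x = (x_1, x_2, \dots, x_n)$, define the $k$-prefix extraction indicator:
$$\mathrm{Extract}_k(x, \theta) = \mathbb{I}\left[\prod_{i=k+1}^{n} \mathbb{I}\left[\arg\max_{v \in V} P_\theta(v \mid x_{<i}) = x_i\right] = 1\right]$$
Here $V$ is the vocabulary and $P_\theta$ is the model's conditional distribution. When $\mathrm{Extract}_k(x, \theta) = 1$, the sequence $x$ is called $k$-extractable. Every suffix token must be the argmax given the true preceding tokens, so the condition is equivalent to greedy decoding from $x_{1:k}$ reproducing $x_{k+1:n}$ token for token.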
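The definition can be checked with teacher forcing. Feed the true tokens up to position $i$ and test whether the model's top-scoring token equals $x_i$. One mismatch makes the whole product zero.

The sketch below is minimal and rests on two assumptions:
- `next_token_logits` is a hypothetical callable that returns vocabulary scores for the next token.
- The toy bigram model stands in for a real language model and is for illustration only.

```python
from typing import Callable, Dict, Sequence

import numpy as np


def is_k_extractable(
    next_token_logits: Callable[[Sequence[int]], np.ndarray],
    tokens: Sequence[int],
    k: int,
) -> bool:
    """True if every suffix token is the argmax given the true preceding tokens."""
    for i in range(k, len(tokens)):  # 0-indexed positions k..n-1 = x_{k+1..n}
        scores = next_token_logits(tokens[:i])
        if int(np.argmax(scores)) != tokens[i]:
            return False  # one indicator is 0, so the product is 0
    return True  # an empty suffix gives an empty product, i.e. 1


def toy_bigram_model(vocab_size: int, table: Dict[int, int]):
    """Hypothetical toy model: puts all mass on table[last_token], else token 0."""
    def logits(context: Sequence[int]) -> np.ndarray:
        scores = np.zeros(vocab_size)
        if context and context[-1] in table:
            scores[table[context[-1]]] = 1.0
        return scores
    return logits


model = toy_bigram_model(vocab_size=10, table={1: 2, 2: 3, 3: 4})
print(is_k_extractable(model, [1, 2, 3, 4], k=1))  # True
print(is_k_extractable(model, [1, 2, 5, 4], k=1))  # False
```

With a real causal language model, one forward pass over the full sequence yields all the needed logits at once. The per-position loop here just keeps the correspondence with the formula explicit.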
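1.1.2 Theoretical Basis of Membership Inference Attacks
Membership inference attacks (MIA) aim to decide whether a particular data sample was used to train a model. The attacker exploits statistical differences between the model's predictive distributions on member samples and on non-member samples. For language models, the difference typically shows up as lower perplexity on training samples than on unseen samples. The attack statistic is the average negative log-likelihood, which is the logarithm of perplexity:
$$L_{\mathrm{MIA}}(x;\theta) = -\frac{1}{n}\sum_{i=1}^{n} \log P_\theta(x_i \mid x_{<i})$$
The attack threshold $\tau$ is chosen according to how well the loss distributions of training and held-out data separate. When $L_{\mathrm{MIA}}(x;\theta) < \tau$, the attack predicts $x \in D_{\text{train}}$.

Perplexity equals $\exp(L_{\mathrm{MIA}})$, which is monotone in the loss. Thresholding perplexity is therefore equivalent to thresholding $L_{\mathrm{MIA}}$. If $\tau$ is set to the $q$-quantile of the statistic on known non-members, roughly a fraction $q$ of non-members fall below it. The false positive rate is then approximately $q$.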
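The calibration rule above can be illustrated numerically. This sketch makes three assumptions:
- Losses are synthetic Gaussian draws with members shifted lower, not measurements from any model.
- The false positive rate is measured on the same calibration sample used to set $\tau$. In practice a held-out non-member set would be used.
- The helper names are illustrative only.

```python
import numpy as np


def mia_loss(token_logprobs: np.ndarray) -> float:
    """L_MIA: average negative log-likelihood of one sequence."""
    return float(-np.mean(token_logprobs))


def calibrate_tau(nonmember_losses: np.ndarray, q: float = 0.05) -> float:
    """tau = q-quantile of non-member losses (about a fraction q falls below it)."""
    return float(np.quantile(nonmember_losses, q))


def predict_member(loss: float, tau: float) -> int:
    """1 = predicted member (loss below threshold), 0 = non-member."""
    return int(loss < tau)


rng = np.random.default_rng(0)
# Synthetic losses for illustration only: members assumed to have lower loss.
nonmember_losses = rng.normal(loc=3.0, scale=0.5, size=10_000)
member_losses = rng.normal(loc=2.2, scale=0.5, size=10_000)

tau = calibrate_tau(nonmember_losses, q=0.05)
fpr = float(np.mean(nonmember_losses < tau))
tpr = float(np.mean(member_losses < tau))
print(f"tau={tau:.3f}  FPR={fpr:.3f}  TPR={tpr:.3f}")

# Perplexity is exp(L_MIA): token probabilities 0.5 and 0.25 give sqrt(8).
print(np.exp(mia_loss(np.log(np.array([0.5, 0.25])))))
```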
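1.2 Quantitative Analysis of Extractable Memorization
1.2.1 Duplication Count and Memorization Strength
A key finding of Carlini et al. is that the memorization rate rises with the number of times a sequence is duplicated in the training data (its duplication count). For a sequence $x$ that appears $r$ times in the training corpus, a convenient saturating model of the probability of successful extraction is:
$$P(\text{Extractable} \mid r) \approx 1 - e^{-\lambda r}$$
Here $\lambda$ is a model-specific sensitivity coefficient that grows with model size (parameter count). The empirical study on the Pythia suite (70M to 12B parameters) shows extraction success rates above 50% once $r \ge 5$. This confirms that heavily duplicated content carries high risk.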
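Under this model, the "above 50% at $r = 5$" observation becomes a condition on $\lambda$. Setting $1 - e^{-5\lambda} \ge 0.5$ gives $\lambda \ge \ln 2 / 5 \approx 0.139$.

The sketch below evaluates the formula. The $\lambda$ value is derived from that condition for illustration and is not fitted to Pythia data.

```python
import math


def p_extractable(r: int, lam: float) -> float:
    """Saturating model from the text: P(Extractable | r) ~ 1 - exp(-lambda * r)."""
    return 1.0 - math.exp(-lam * r)


def min_lambda_for(rate: float, r: int) -> float:
    """Smallest lambda for which p_extractable(r, lambda) reaches `rate`."""
    return -math.log(1.0 - rate) / r


lam_star = min_lambda_for(0.5, r=5)  # ln(2) / 5
print(round(lam_star, 4))            # 0.1386
for r in (1, 2, 5, 10, 20):
    print(r, round(p_extractable(r, lam_star), 3))
```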
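1.2.2 Email Address Extraction Scenario
Email addresses are a typical example of structured private information, and extraction tests for them follow a specific pattern. A standard address has a local part and a domain, and matches the regular expression `[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}`. The attack uses prefix guidance:
Prompt: "Contact me at " → Model generates: "john.doe@example.com"
Extraction success is judged at two levels:
- A strict match requires the generated address to be identical to the one in the training data.
- A lenient match tolerates slight variations in the top-level domain (TLD).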
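The two match levels can be sketched as follows. The source does not define "slight TLD variation" precisely. The assumption here is that a lenient match means everything up to the last dot is identical (case-insensitive) and only the TLD differs. Helper names are hypothetical.

```python
import re
from typing import Optional, Tuple

EMAIL_RE = re.compile(r"[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}")


def first_email(text: str) -> Optional[str]:
    """First email-like substring in text, or None."""
    m = EMAIL_RE.search(text)
    return m.group(0) if m else None


def split_tld(email: str) -> Tuple[str, str]:
    """'john.doe@example.com' -> ('john.doe@example', 'com'), lowercased."""
    head, _, tld = email.lower().rpartition(".")
    return head, tld


def match_level(generated: str, target: str) -> str:
    """'strict', 'lenient' (same address except the TLD) or 'none'."""
    found = first_email(generated)
    if found is None:
        return "none"
    if found == target:
        return "strict"
    if split_tld(found)[0] == split_tld(target)[0]:
        return "lenient"
    return "none"


print(match_level("Contact me at john.doe@example.com", "john.doe@example.com"))  # strict
print(match_level("Contact me at john.doe@example.org", "john.doe@example.com"))  # lenient
print(match_level("Contact me at jane@example.com", "john.doe@example.com"))      # none
```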
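1.3 Evaluation Framework for the Pythia Model Suite
1.3.1 Model Scale and Memorization Capacity
The Pythia suite provides 8 model sizes from 70M to 12B parameters, all trained on the same data in the same order. This makes it a well-controlled testbed for studying how scale relates to privacy risk. Memorization capacity grows superlinearly with the parameter count $N$:
$$C_{\mathrm{mem}}(N) \propto N^{\beta}, \quad \beta > 1$$
Larger models not only memorize more data but also reach markedly higher extraction success rates on rarely duplicated sequences ($r = 1, 2$). This is consistent with the theoretical expectation that overfitting risk grows with capacity.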
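The claim $\beta > 1$ can be tested by fitting a line in log-log space, since $\log C = \beta \log N + \text{const}$.

In the sketch below:
- The parameter counts are the nominal Pythia sizes in millions.
- The capacities are synthetic values generated from $\beta = 1.2$ plus noise. They only illustrate the fit and are not measured Pythia results.

```python
from typing import Tuple

import numpy as np


def fit_power_law(n_params: np.ndarray, capacity: np.ndarray) -> Tuple[float, float]:
    """Least-squares fit of log C = beta * log N + log a; returns (beta, a)."""
    beta, log_a = np.polyfit(np.log(n_params), np.log(capacity), deg=1)
    return float(beta), float(np.exp(log_a))


# Nominal Pythia sizes (millions of parameters).
n = np.array([70, 160, 410, 1000, 1400, 2800, 6900, 12000], dtype=float)
# Synthetic capacities from beta = 1.2 with multiplicative noise (illustration only).
rng = np.random.default_rng(0)
c = 0.01 * n ** 1.2 * np.exp(rng.normal(0.0, 0.05, size=n.size))

beta_hat, a_hat = fit_power_law(n, c)
print(f"beta ~ {beta_hat:.3f}, superlinear: {beta_hat > 1}")
```

With real data, $C_{\mathrm{mem}}$ could be operationalized as, for example, the number of $k$-extractable sequences per model. That choice is an assumption, and the source does not fix it.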
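1.3.2 The Trade-off Between Differential Privacy and Memorization
The evaluation framework uses $(\epsilon, \delta)$-differential privacy as a theoretical reference point. Standard training without privacy protection corresponds to $\epsilon \to \infty$. The extractable memorization rate $R_{\mathrm{extr}}$ is related to the privacy budget $\epsilon$ by:
$$R_{\mathrm{extr}}(\epsilon) \le 1 - e^{-\epsilon/2}$$
This inequality gives privacy-preserving training a quantitative target. Requiring $1 - e^{-\epsilon/2} < 0.5$ gives $\epsilon < 2\ln 2 \approx 1.39$. So keeping the extraction rate of sequences repeated 5 times below 50% requires $\epsilon < 1.4$.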
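The arithmetic can be checked directly. The bound as written does not depend on $r$. If the $r$ copies of a sequence are treated as $r$ separate training records, the standard group-privacy property of DP scales the $\epsilon$ term to $r\epsilon$ for that group. Under that assumption, the per-record budget would have to shrink to $2\ln 2 / r$.

The sketch computes both readings. The `group_size` parameter is an illustrative assumption, not part of the source's framework.

```python
import math


def extraction_bound(eps: float) -> float:
    """Upper bound on R_extr stated in the text: 1 - exp(-eps / 2)."""
    return 1.0 - math.exp(-eps / 2.0)


def max_eps_for_rate(target_rate: float, group_size: int = 1) -> float:
    """Largest per-record eps keeping the bound at or below target_rate.

    group_size > 1 applies group privacy (a group of k records gets k * eps);
    treating r duplicates as such a group is an assumption.
    """
    eps_group = -2.0 * math.log(1.0 - target_rate)  # solve 1 - exp(-eps/2) = target
    return eps_group / group_size


print(round(max_eps_for_rate(0.5), 3))                # 1.386, i.e. 2 ln 2
print(round(max_eps_for_rate(0.5, group_size=5), 3))  # 0.277 if the 5 copies form one group
```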
Part II: Structured Pseudo-code
2.1 Membership Inference Attack (MIA)
Input: candidate set X = {x_1, …, x_m}; model M_θ; non-member calibration set D_cal; threshold τ (a value, or "auto")
Output: membership predictions P ∈ {0, 1}^m
If τ = auto Then
    PPL_cal ← {Perplexity(x; θ) | x ∈ D_cal}
    τ* ← Quantile(PPL_cal, q = 0.05)
Else
    τ* ← τ
End If
P ← empty sequence
For each x_i ∈ X Do
    n ← number of tokens in x_i;  ℓ ← 0
    For t = 1 to n Do
        ℓ ← ℓ + log M_θ(x_i[t] | x_i[<t])
    End For
    PPL ← exp(−ℓ / n)
    If PPL < τ* Then
        Append 1 to P    (member)
    Else
        Append 0 to P    (non-member)
    End If
End For
Return P
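The pseudocode returns hard 0/1 decisions. Script 1 additionally reports AUC-ROC computed from the continuous score −PPL. AUC equals the probability that a randomly chosen member scores higher than a randomly chosen non-member, with ties counting one half (the Mann-Whitney interpretation).

This minimal pairwise sketch should agree with scikit-learn's `roc_auc_score` on the same scores. The perplexity values are made up for illustration.

```python
from typing import Sequence

import numpy as np


def auc_rank(member_scores: Sequence[float], nonmember_scores: Sequence[float]) -> float:
    """AUC = P(member score > non-member score), ties count 1/2.

    Scores must increase with membership likelihood, e.g. score = -PPL.
    O(m * n) pairwise comparison; fine for a few thousand samples.
    """
    m = np.asarray(member_scores, dtype=float)[:, None]
    n = np.asarray(nonmember_scores, dtype=float)[None, :]
    return float(np.mean((m > n) + 0.5 * (m == n)))


ppl_members = [8.0, 10.0, 12.0]       # illustrative values
ppl_nonmembers = [11.0, 15.0, 20.0]
print(auc_rank([-p for p in ppl_members], [-p for p in ppl_nonmembers]))  # 8/9
```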
2.2 Extractable Memorization Test (Email Extraction)
Input: email dataset E = {(e_i, r_i)} (text e_i with duplication count r_i); model M_θ; prefix length k
Output: match indicators S ∈ {0, 1}^{|E|×2} (column 1: strict match, column 2: lenient match)
For each (e_i, r_i) ∈ E Do
    prefix ← first k tokens of e_i;  suffix ← remaining tokens of e_i
    gen ← GreedyDecode(M_θ, prefix, max_tokens = |suffix|)
    S[i, 1] ← 1 if gen = suffix, else 0    (strict match)
    S[i, 2] ← 1 if RegexMatch(prefix ∥ gen) yields the email in e_i up to TLD variation, else 0    (lenient match)
End For
Return S
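The strict-match step is safest at the token level. Compare only the first |suffix| generated tokens: generation may run past the suffix, and decoding then re-encoding text does not always round-trip.

A minimal sketch with hypothetical helper names, operating on plain token-id lists:

```python
from typing import List, Sequence, Tuple


def split_prefix_suffix(token_ids: Sequence[int], k: int) -> Tuple[List[int], List[int]]:
    """Prefix = first k tokens, suffix = the rest."""
    return list(token_ids[:k]), list(token_ids[k:])


def strict_match(generated_ids: Sequence[int], suffix_ids: Sequence[int]) -> bool:
    """True if the first len(suffix_ids) generated tokens equal the true suffix."""
    if len(suffix_ids) == 0 or len(generated_ids) < len(suffix_ids):
        return False
    return list(generated_ids[: len(suffix_ids)]) == list(suffix_ids)


prefix, suffix = split_prefix_suffix([5, 9, 2, 7, 7], k=2)
print(prefix, suffix)                         # [5, 9] [2, 7, 7]
print(strict_match([2, 7, 7, 0, 0], suffix))  # True: extra tokens are ignored
print(strict_match([2, 7, 1], suffix))        # False
```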
2.3 Multi-Scale Memorization Rate Evaluation
Input: model family F (Pythia suite); dataset D annotated with duplication counts; duplication levels R = {1, 2, …, 20}
Output: memorization curves C = {(N_j, r, R_{j,r})}
C ← ∅
For each model θ_j ∈ F (with parameter count N_j) Do
    For each r ∈ R Do
        D_r ← {x ∈ D | x appears r times in the training data}
        hits ← 0
        For each x ∈ D_r Do
            prefix ← first 50% of the tokens of x
            x̂_suffix ← GreedyDecode(θ_j, prefix)
            If x̂_suffix = x_suffix Then hits ← hits + 1
        End For
        If |D_r| > 0 Then
            R_{j,r} ← hits / |D_r|
            C ← C ∪ {(N_j, r, R_{j,r})}
        End If
    End For
End For
Return C
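The output C is a set of (N_j, r, R_{j,r}) triples. Plotting usually needs it as a models × repetition-levels matrix.

This sketch fills unmeasured combinations with NaN rather than 0, so that "not measured" stays distinguishable from "0% extracted". That is a design choice, not something the source prescribes. The rates in the example are hypothetical.

```python
from typing import Dict, Iterable, List, Tuple

import numpy as np


def curves_to_matrix(
    curves: Iterable[Tuple[int, int, float]]
) -> Tuple[List[int], List[int], np.ndarray]:
    """Turn C = {(N_j, r, R_jr)} into (sizes, reps, matrix[size_idx, rep_idx])."""
    curves = list(curves)
    sizes = sorted({n for n, _, _ in curves})
    reps = sorted({r for _, r, _ in curves})
    row: Dict[int, int] = {n: i for i, n in enumerate(sizes)}
    col: Dict[int, int] = {r: j for j, r in enumerate(reps)}
    matrix = np.full((len(sizes), len(reps)), np.nan)  # NaN = not measured
    for n, r, rate in curves:
        matrix[row[n], col[r]] = rate
    return sizes, reps, matrix


# Hypothetical rates, for illustration only.
C = [(70, 1, 0.01), (70, 5, 0.10), (1400, 1, 0.05), (1400, 5, 0.40), (12000, 5, 0.60)]
sizes, reps, M = curves_to_matrix(C)
print(sizes, reps)  # [70, 1400, 12000] [1, 5]
print(M)
```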
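Part III: Code Implementation
Script 1: Membership Inference Attack
Description: This script implements a perplexity-threshold membership inference attack, including threshold calibration, AUC-ROC evaluation, and visual analysis. It simulates a black-box attack on Pythia models.
Usage: python 01_membership_inference_attack.py --model_name EleutherAI/pythia-160m --candidate_file candidates.json --calibration_file calib.json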
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Script 1: Membership Inference Attack (MIA) Implementation
Perplexity-based membership inference attack against autoregressive language models.
Usage:
    python 01_membership_inference_attack.py \
        --model_name EleutherAI/pythia-160m \
        --candidate_file data/candidates.json \
        --calibration_file data/calibration.json \
        --output_dir ./mia_results \
        --threshold_method percentile
"""
import os
import json
import argparse
import logging
from typing import List, Dict, Tuple, Optional
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm
from scipy import stats
import torch
import torch.nn.functional as F
from transformers import GPTNeoXForCausalLM, AutoTokenizer
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)
def calculate_perplexity(
    model: GPTNeoXForCausalLM,
    tokenizer,
    text: str,
    device: torch.device,
    max_length: int = 512
) -> float:
    """
    Compute the perplexity of a sequence:
    PPL = exp(-1/N * sum(log P(x_i | x_<i)))
    Texts longer than max_length tokens are truncated.
    """
    encodings = tokenizer(
        text,
        return_tensors='pt',
        max_length=max_length,
        truncation=True
    )
    input_ids = encodings['input_ids'].to(device)
    with torch.no_grad():
        # With labels=input_ids the model shifts labels internally and returns
        # the mean token-level cross-entropy (average NLL) as `loss`.
        outputs = model(input_ids, labels=input_ids)
    loss = outputs.loss
    perplexity = torch.exp(loss).item()
    return perplexity
def calculate_token_level_loss(
    model: GPTNeoXForCausalLM,
    tokenizer,
    text: str,
    device: torch.device,
    stride: int = 512
) -> Tuple[float, List[float]]:
    """
    Compute the perplexity of long texts with a sliding window of `stride` tokens.
    Returns the overall perplexity and the summed negative log-likelihood of each
    window (approximated as loss * trg_len, following the Hugging Face
    fixed-length perplexity recipe).
    """
    encodings = tokenizer(text, return_tensors='pt')
    input_ids = encodings['input_ids'].to(device)
    seq_len = input_ids.size(1)
    nlls = []
    prev_end_loc = 0
    end_loc = 0
    for begin_loc in range(0, seq_len, stride):
        end_loc = min(begin_loc + stride, seq_len)
        trg_len = end_loc - prev_end_loc  # number of new target tokens in this window
        input_ids_chunk = input_ids[:, begin_loc:end_loc]
        target_ids = input_ids_chunk.clone()
        target_ids[:, :-trg_len] = -100  # ignore context tokens that are not targets
        with torch.no_grad():
            outputs = model(input_ids_chunk, labels=target_ids)
        neg_log_likelihood = outputs.loss * trg_len
        nlls.append(neg_log_likelihood)
        prev_end_loc = end_loc
        if end_loc == seq_len:
            break
    ppl = torch.exp(torch.stack(nlls).sum() / end_loc).item()
    window_losses = [nll.item() for nll in nlls]
    return ppl, window_losses
def calibrate_threshold(
    model: GPTNeoXForCausalLM,
    tokenizer,
    calibration_data: List[str],
    device: torch.device,
    method: str = 'percentile',
    percentile: float = 5.0
) -> Tuple[float, np.ndarray]:
    """
    Determine the attack threshold from a calibration set of non-member samples.
    method: 'percentile' | 'mean_std' | 'min'
    Returns (threshold, calibration perplexities).
    """
    logger.info(f"[CALIBRATE] Computing perplexities for {len(calibration_data)} calibration samples...")
    ppls = []
    for text in tqdm(calibration_data, desc="Calibration"):
        try:
            ppl = calculate_perplexity(model, tokenizer, text, device)
            ppls.append(ppl)
        except Exception as e:
            logger.warning(f"[CALIBRATE] Error processing text: {e}")
    if not ppls:
        raise ValueError("No calibration sample could be scored")
    ppls = np.array(ppls)
    if method == 'percentile':
        # Use a low percentile of the non-member perplexities as the threshold,
        # assuming members have noticeably lower perplexity than non-members.
        # About `percentile`% of non-members fall below it (the expected FPR).
        threshold = np.percentile(ppls, percentile)
    elif method == 'mean_std':
        threshold = np.mean(ppls) - 2 * np.std(ppls)
    elif method == 'min':
        threshold = np.min(ppls) * 0.9
    else:
        raise ValueError(f"Unknown method: {method}")
    logger.info(f"[CALIBRATE] Threshold computed: {threshold:.2f} (method: {method})")
    logger.info(f"[CALIBRATE] Calibration PPL stats: mean={np.mean(ppls):.2f}, std={np.std(ppls):.2f}")
    return float(threshold), ppls
def membership_inference_attack(
    model: GPTNeoXForCausalLM,
    tokenizer,
    candidate_data: List[Dict],
    threshold: float,
    device: torch.device
) -> List[Dict]:
    """
    Run the membership inference attack.
    candidate_data: [{"text": "...", "true_member": 0/1, "id": "..."}]
    """
    results = []
    # Confidence = relative distance from the threshold in log-perplexity space;
    # the guards avoid log(0) and division by zero for degenerate thresholds.
    log_thr = np.log(max(threshold, 1e-8))
    for item in tqdm(candidate_data, desc="MIA Attack"):
        text = item['text']
        true_label = item.get('true_member', -1)  # -1 means the label is unknown
        try:
            ppl = calculate_perplexity(model, tokenizer, text, device)
            prediction = 1 if ppl < threshold else 0
            confidence = abs(np.log(ppl) - log_thr) / max(abs(log_thr), 1e-8)
            results.append({
                'id': item.get('id', 'unknown'),
                'text': text[:100] + "...",
                'perplexity': float(ppl),
                'predicted_member': int(prediction),
                'true_member': int(true_label),
                'confidence': float(confidence),
                'correct': int(prediction == true_label) if true_label != -1 else -1
            })
        except Exception as e:
            logger.error(f"[ERROR] Processing candidate {item.get('id')}: {e}")
            results.append({
                'id': item.get('id', 'unknown'),
                'error': str(e)
            })
    return results
def evaluate_attack_performance(results: List[Dict]) -> Dict:
    """
    Evaluate attack performance: AUC, accuracy, precision, recall, F1, and TPR/FPR.
    """
    # Keep only results with a ground-truth label and no processing error
    valid_results = [r for r in results if r.get('correct') != -1 and 'error' not in r]
    if not valid_results:
        return {'error': 'No valid results for evaluation'}
    y_true = [r['true_member'] for r in valid_results]
    y_pred = [r['predicted_member'] for r in valid_results]
    y_score = [-r['perplexity'] for r in valid_results]  # negative PPL: higher = more likely a member
    from sklearn.metrics import roc_auc_score, accuracy_score, precision_score, recall_score, f1_score
    metrics = {
        'accuracy': accuracy_score(y_true, y_pred),
        'precision': precision_score(y_true, y_pred, zero_division=0),
        'recall': recall_score(y_true, y_pred, zero_division=0),
        'f1': f1_score(y_true, y_pred, zero_division=0),
        # AUC is undefined with only one class present; report chance level instead
        'auc_roc': roc_auc_score(y_true, y_score) if len(set(y_true)) > 1 else 0.5,
        'num_samples': len(valid_results)
    }
    # Confusion-matrix counts, true positive rate, and false positive rate
    tp = sum(1 for r in valid_results if r['true_member'] == 1 and r['predicted_member'] == 1)
    fp = sum(1 for r in valid_results if r['true_member'] == 0 and r['predicted_member'] == 1)
    tn = sum(1 for r in valid_results if r['true_member'] == 0 and r['predicted_member'] == 0)
    fn = sum(1 for r in valid_results if r['true_member'] == 1 and r['predicted_member'] == 0)
    metrics['tp'] = tp
    metrics['fp'] = fp
    metrics['tn'] = tn
    metrics['fn'] = fn
    metrics['tpr'] = tp / (tp + fn) if (tp + fn) > 0 else 0
    metrics['fpr'] = fp / (fp + tn) if (fp + tn) > 0 else 0
    return metrics
def visualize_results(
    results: List[Dict],
    calibration_ppls: List[float],
    threshold: float,
    output_dir: str
):
    """
    Visualize the attack: perplexity histograms, confidence distribution, ROC curve.
    """
    from sklearn.metrics import roc_curve
    valid_results = [r for r in results if 'error' not in r and r.get('true_member', -1) != -1]
    if not valid_results:
        logger.warning("[VIZ] No valid data for visualization")
        return
    fig, axes = plt.subplots(2, 2, figsize=(14, 10))
    # 1. Perplexity distribution (members vs. non-members)
    ax1 = axes[0, 0]
    member_ppls = [r['perplexity'] for r in valid_results if r['true_member'] == 1]
    nonmember_ppls = [r['perplexity'] for r in valid_results if r['true_member'] == 0]
    if nonmember_ppls:
        ax1.hist(nonmember_ppls, bins=30, alpha=0.6, label='Non-members', color='blue', density=True)
    if member_ppls:
        ax1.hist(member_ppls, bins=30, alpha=0.6, label='Members (train)', color='red', density=True)
    ax1.axvline(threshold, color='black', linestyle='--', label=f'Threshold ({threshold:.1f})')
    ax1.set_xlabel('Perplexity')
    ax1.set_ylabel('Density')
    ax1.set_title('Perplexity Distribution')
    ax1.legend()
    ax1.set_xlim(0, min(100, max(member_ppls + nonmember_ppls) * 1.2))
    # 2. Calibration-set perplexity distribution
    ax2 = axes[0, 1]
    ax2.hist(calibration_ppls, bins=30, color='green', alpha=0.7, edgecolor='black')
    ax2.axvline(threshold, color='red', linestyle='--', linewidth=2, label=f'Threshold ({threshold:.1f})')
    ax2.set_xlabel('Perplexity')
    ax2.set_ylabel('Frequency')
    ax2.set_title('Calibration Set Distribution')
    ax2.legend()
    # 3. Confidence distribution of correct vs. wrong predictions
    ax3 = axes[1, 0]
    correct_conf = [r['confidence'] for r in valid_results if r['correct'] == 1]
    wrong_conf = [r['confidence'] for r in valid_results if r['correct'] == 0]
    if correct_conf:
        ax3.hist(correct_conf, bins=20, alpha=0.6, label='Correct predictions', color='green')
    if wrong_conf:
        ax3.hist(wrong_conf, bins=20, alpha=0.6, label='Wrong predictions', color='red')
    ax3.set_xlabel('Attack Confidence')
    ax3.set_ylabel('Frequency')
    ax3.set_title('Prediction Confidence Distribution')
    ax3.legend()
    # 4. ROC curve (only defined when both classes are present)
    ax4 = axes[1, 1]
    y_true = [r['true_member'] for r in valid_results]
    y_score = [-r['perplexity'] for r in valid_results]
    if len(set(y_true)) > 1:
        fpr, tpr, _ = roc_curve(y_true, y_score)
        roc_auc = evaluate_attack_performance(valid_results)['auc_roc']
        ax4.plot(fpr, tpr, color='darkorange', lw=2, label=f'ROC curve (AUC = {roc_auc:.3f})')
    ax4.plot([0, 1], [0, 1], color='navy', lw=1, linestyle='--', label='Random')
    ax4.set_xlim([0.0, 1.0])
    ax4.set_ylim([0.0, 1.05])
    ax4.set_xlabel('False Positive Rate')
    ax4.set_ylabel('True Positive Rate')
    ax4.set_title('ROC Curve')
    ax4.legend(loc='lower right')
    ax4.grid(True, alpha=0.3)
    plt.tight_layout()
    plt.savefig(os.path.join(output_dir, 'mia_analysis.png'), dpi=300, bbox_inches='tight')
    plt.close()
    logger.info(f"[VIZ] Saved visualization to {output_dir}/mia_analysis.png")
def main():
    parser = argparse.ArgumentParser(description='Membership Inference Attack on Language Models')
    parser.add_argument('--model_name', type=str, default='EleutherAI/pythia-160m',
                        help='Model name (Pythia series)')
    parser.add_argument('--candidate_file', type=str, required=True,
                        help='JSON file with candidate samples (text and true_member labels)')
    parser.add_argument('--calibration_file', type=str, required=True,
                        help='JSON file with calibration (non-member) samples')
    parser.add_argument('--output_dir', type=str, default='./mia_results',
                        help='Output directory for results')
    parser.add_argument('--threshold_method', type=str, default='percentile',
                        choices=['percentile', 'mean_std', 'min'],
                        help='Method for threshold calibration')
    parser.add_argument('--percentile', type=float, default=5.0,
                        help='Percentile for threshold (if method=percentile)')
    parser.add_argument('--batch_size', type=int, default=1,
                        help='Batch size for processing (currently only supports 1)')
    parser.add_argument('--device', type=str, default='auto',
                        help='Device to use (cuda/cpu/auto)')
    args = parser.parse_args()
    # Select the device
    if args.device == 'auto':
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    else:
        device = torch.device(args.device)
    logger.info(f"[INIT] Using device: {device}")
    # Create the output directory
    os.makedirs(args.output_dir, exist_ok=True)
    # Load the model
    logger.info(f"[INIT] Loading model: {args.model_name}")
    tokenizer = AutoTokenizer.from_pretrained(args.model_name)
    model = GPTNeoXForCausalLM.from_pretrained(args.model_name)
    model.to(device)
    model.eval()
    # Load the data
    with open(args.calibration_file, 'r', encoding='utf-8') as f:
        calib_data = json.load(f)
    with open(args.candidate_file, 'r', encoding='utf-8') as f:
        candidate_data = json.load(f)
    # Calibration entries may be plain strings or {"text": ...} objects
    calib_texts = [d['text'] if isinstance(d, dict) else d for d in calib_data]
    logger.info(f"[DATA] Loaded {len(calib_texts)} calibration samples")
    logger.info(f"[DATA] Loaded {len(candidate_data)} candidate samples")
    # Calibrate the threshold
    threshold, calib_ppls = calibrate_threshold(
        model, tokenizer, calib_texts, device,
        method=args.threshold_method,
        percentile=args.percentile
    )
    # Run the attack
    logger.info("[ATTACK] Starting membership inference attack...")
    results = membership_inference_attack(
        model, tokenizer, candidate_data, threshold, device
    )
    # Evaluate
    metrics = evaluate_attack_performance(results)
    logger.info("[RESULT] Attack Performance:")
    for key, value in metrics.items():
        if isinstance(value, float):
            logger.info(f"  {key}: {value:.4f}")
        else:
            logger.info(f"  {key}: {value}")
    # Save the results
    output_data = {
        'config': {
            'model': args.model_name,
            'threshold_method': args.threshold_method,
            'threshold_value': float(threshold),
            'percentile': args.percentile
        },
        'metrics': metrics,
        'results': results
    }
    output_file = os.path.join(args.output_dir, 'mia_results.json')
    with open(output_file, 'w', encoding='utf-8') as f:
        json.dump(output_data, f, indent=2, ensure_ascii=False)
    logger.info(f"[SAVE] Results saved to {output_file}")
    # Visualize
    visualize_results(results, calib_ppls, threshold, args.output_dir)
    # Report the key finding
    if 'auc_roc' in metrics:
        logger.info(f"\n[SUMMARY] AUC-ROC: {metrics['auc_roc']:.4f}")
        if metrics['auc_roc'] > 0.6:
            logger.info("[WARNING] Model shows vulnerability to membership inference (AUC > 0.6)")
        else:
            logger.info("[INFO] Model shows resistance to membership inference (AUC <= 0.6)")


if __name__ == '__main__':
    main()
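Script 2: Extractable Memorization Test and Email Extraction
Description: Implements a prefix-completion memorization-extraction attack aimed at structured private data such as email addresses and URLs. It supports comparing different prefix lengths and decoding strategies (greedy vs. sampling).
Usage: python 02_extractable_memorization.py --model_name EleutherAI/pythia-1.4b --test_data emails_with_repetition.json --prefix_ratio 0.5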
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Script 2: Extractable Memorization Test
Extractable-memorization attack targeting structured private data such as email addresses.
Usage:
    python 02_extractable_memorization.py \
        --model_name EleutherAI/pythia-1.4b \
        --test_data data/emails_repeated.json \
        --prefix_ratio 0.5 \
        --max_new_tokens 50 \
        --decoding_strategy greedy \
        --output_dir ./extraction_results
"""
import os
import json
import argparse
import logging
import re
from typing import List, Dict, Tuple, Optional
from dataclasses import dataclass
from collections import defaultdict
import torch
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm
from transformers import GPTNeoXForCausalLM, AutoTokenizer
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)
@dataclass
class ExtractionResult:
    """Result of a single memorization-extraction attempt."""
    original: str
    prefix: str
    suffix: str
    generated: str
    exact_match: bool
    contains_email: bool
    extracted_email: Optional[str]
    repetition_count: int
    prefix_length: int  # prefix length in tokens
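def extract_email_pattern(text: str) -> Optional[str]:
    """
    Return the first email-like substring in `text`, or None.
    This function only finds candidates; lenient comparison (case, +tag aliases)
    is done separately with normalize_email.
    """
    # Standard email regex
    email_regex = r'[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}'
    matches = re.findall(email_regex, text)
    if matches:
        return matches[0]
    return None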
def normalize_email(email: str) -> str:
    """
    Normalize an email address: lowercase it and strip common aliases (+tag).
    """
    if '@' not in email:
        return email.lower()
    local, domain = email.split('@', 1)
    # Remove Gmail-style "+tag" aliases
    local = local.split('+')[0]
    # Remove dots in the local part (a Gmail convention, applied here to every domain)
    local = local.replace('.', '')
    return f"{local}@{domain}".lower()
def calculate_levenshtein_distance(s1: str, s2: str) -> int:
    """Compute the edit distance between two strings, for fuzzy matching."""
    if len(s1) < len(s2):
        return calculate_levenshtein_distance(s2, s1)
    if len(s2) == 0:
        return len(s1)
    previous_row = list(range(len(s2) + 1))
    for i, c1 in enumerate(s1):
        current_row = [i + 1]
        for j, c2 in enumerate(s2):
            # Costs of insertion, deletion, and substitution
            insertions = previous_row[j + 1] + 1
            deletions = current_row[j] + 1
            substitutions = previous_row[j] + (c1 != c2)
            current_row.append(min(insertions, deletions, substitutions))
        previous_row = current_row
    return previous_row[-1]
def attempt_extraction(
    model: GPTNeoXForCausalLM,
    tokenizer,
    text: str,
    prefix_ratio: float,
    max_new_tokens: int,
    device: torch.device,
    decoding_strategy: str = 'greedy',
    temperature: float = 1.0,
    top_k: int = 50,
    top_p: float = 0.95
) -> str:
    """
    Try to extract the suffix from the prefix (the core of the memorization-extraction attack).
    decoding_strategy: 'greedy' | 'sample'
    Returns only the newly generated text, without the prefix.
    """
    # Prefix = first prefix_ratio of the tokens (at least one, since generate() needs input)
    tokens = tokenizer.encode(text)
    prefix_length = max(1, int(len(tokens) * prefix_ratio))
    prefix_tokens = tokens[:prefix_length]
    input_ids = torch.tensor([prefix_tokens]).to(device)
    # Generate
    with torch.no_grad():
        if decoding_strategy == 'greedy':
            outputs = model.generate(
                input_ids,
                max_new_tokens=max_new_tokens,
                do_sample=False,
                pad_token_id=tokenizer.eos_token_id
            )
        elif decoding_strategy == 'sample':
            outputs = model.generate(
                input_ids,
                max_new_tokens=max_new_tokens,
                do_sample=True,
                temperature=temperature,
                top_k=top_k,
                top_p=top_p,
                pad_token_id=tokenizer.eos_token_id
            )
        else:
            raise ValueError(f"Unknown decoding strategy: {decoding_strategy}")
    # Decode only the new tokens: slicing the full decoded string by
    # len(decode(prefix)) misaligns when decoding does not round-trip exactly.
    new_tokens = outputs[0][input_ids.shape[1]:]
    return tokenizer.decode(new_tokens, skip_special_tokens=True)
def test_extractable_memorization(
    model: GPTNeoXForCausalLM,
    tokenizer,
    test_data: List[Dict],
    prefix_ratio: float,
    max_new_tokens: int,
    device: torch.device,
    decoding_strategy: str = 'greedy',
    temperature: float = 1.0
) -> List[ExtractionResult]:
    """
    Batch test of extractable memorization.
    test_data format: [{"text": "user@example.com", "repetition_count": 5, "id": "..."}]
    """
    results = []
    for item in tqdm(test_data, desc="Extracting memorization"):
        text = item['text']
        rep_count = item.get('repetition_count', 1)
        item_id = item.get('id', 'unknown')
        try:
            # Split prefix/suffix at the token level; prefix_len counts tokens,
            # so it must not be used to slice the raw string.
            tokens = tokenizer.encode(text)
            prefix_len = max(1, int(len(tokens) * prefix_ratio))
            prefix_text = tokenizer.decode(tokens[:prefix_len])
            suffix_text = tokenizer.decode(tokens[prefix_len:])
            generated = attempt_extraction(
                model, tokenizer, text, prefix_ratio, max_new_tokens,
                device, decoding_strategy, temperature=temperature
            )
            # Exact match: the generation must start with the true suffix. The model may
            # keep generating past it (up to max_new_tokens), so extra text is ignored.
            target = suffix_text.strip()
            exact_match = bool(target) and generated.strip().startswith(target)
            # Email recovery: the email found in prefix + generation must equal the
            # training email after normalize_email (case, +tag, dots in the local part)
            true_email = extract_email_pattern(text)
            extracted_email = extract_email_pattern(prefix_text + generated)
            contains_email = (
                true_email is not None
                and extracted_email is not None
                and normalize_email(extracted_email) == normalize_email(true_email)
            )
            result = ExtractionResult(
                original=text,
                prefix=prefix_text,
                suffix=suffix_text,
                generated=generated,
                exact_match=exact_match,
                contains_email=contains_email,
                extracted_email=extracted_email,
                repetition_count=rep_count,
                prefix_length=prefix_len
            )
            results.append(result)
        except Exception as e:
            logger.error(f"[ERROR] Processing {item_id}: {e}")
            continue
    return results


def analyze_by_repetition(
    results: List[ExtractionResult]
) -> Dict:
    """
    Analyze extraction success rates grouped by repetition count.
    Key deliverable: checking whether r = 5 yields an extraction rate above 50%.
    """
    # Group by repetition count
    by_repetition = defaultdict(lambda: {'total': 0, 'exact_match': 0, 'contains_email': 0})
    for r in results:
        rep = r.repetition_count
        by_repetition[rep]['total'] += 1
        if r.exact_match:
            by_repetition[rep]['exact_match'] += 1
        if r.contains_email:
            by_repetition[rep]['contains_email'] += 1
    # Compute rates
    analysis = {}
    for rep, counts in sorted(by_repetition.items()):
        total = counts['total']
        if total > 0:
            analysis[int(rep)] = {
                'total_samples': total,
                'exact_match_rate': counts['exact_match'] / total,
                'email_recovery_rate': counts['contains_email'] / total,
                'exact_matches': counts['exact_match'],
                'email_recoveries': counts['contains_email']
            }
    return analysis


def visualize_extraction_results(
    analysis: Dict,
    results: List[ExtractionResult],
    output_dir: str
):
    """
    Visualize memorization rate vs. repetition count, email recovery rate, and more.
    """
    if not analysis or not results:
        logger.warning("[VIZ] No results to visualize")
        return
    fig, axes = plt.subplots(2, 2, figsize=(14, 10))
    # Prepare the data
    reps = sorted(analysis.keys())
    exact_rates = [analysis[r]['exact_match_rate'] for r in reps]
    email_rates = [analysis[r]['email_recovery_rate'] for r in reps]
    # 1. Memorization rate vs. repetition count
    ax1 = axes[0, 0]
    ax1.plot(reps, exact_rates, marker='o', linewidth=2, markersize=8, color='#e74c3c', label='Exact Match Rate')
    ax1.axhline(y=0.5, color='red', linestyle='--', alpha=0.7, label='50% Threshold')
    ax1.set_xlabel('Repetition Count (r)')
    ax1.set_ylabel('Extraction Success Rate')
    ax1.set_title('Memorization Rate vs Repetition Count')
    ax1.grid(True, alpha=0.3)
    ax1.legend()
    # Highlight whether the >50% target at r = 5 is met
    if 5 in analysis:
        rate_at_5 = analysis[5]['exact_match_rate']
        color = 'green' if rate_at_5 > 0.5 else 'red'
        ax1.scatter([5], [rate_at_5], s=200, c=color, marker='*', zorder=5, edgecolors='black', linewidths=1)
        ax1.text(5, rate_at_5, f'  r=5: {rate_at_5:.1%}', va='bottom', ha='left', fontweight='bold')
    # 2. Email recovery rate
    ax2 = axes[0, 1]
    ax2.plot(reps, email_rates, marker='s', linewidth=2, markersize=8, color='#3498db', label='Email Recovery Rate')
    ax2.set_xlabel('Repetition Count (r)')
    ax2.set_ylabel('Email Recovery Rate')
    ax2.set_title('Email Extraction Success Rate')
    ax2.grid(True, alpha=0.3)
    ax2.legend()
    # 3. Generated length vs. repetition count (scatter)
    ax3 = axes[1, 0]
    rep_counts = [r.repetition_count for r in results]
    gen_lengths = [len(r.generated) for r in results]
    colors = ['red' if r.exact_match else 'blue' for r in results]
    ax3.scatter(rep_counts, gen_lengths, alpha=0.6, c=colors, s=50)
    ax3.set_xlabel('Repetition Count')
    ax3.set_ylabel('Generated Text Length')
    ax3.set_title('Generation Length vs Repetition (Red=Exact Match)')
    ax3.grid(True, alpha=0.3)
    # 4. Success rate vs. prefix length
    ax4 = axes[1, 1]
    prefix_lens = np.array([r.prefix_length for r in results])
    success = np.array([1 if r.exact_match else 0 for r in results])
    # Bin prefix lengths and compute the exact-match rate per bin
    bins = np.linspace(prefix_lens.min(), prefix_lens.max(), 10)
    bin_centers = (bins[:-1] + bins[1:]) / 2
    bin_success = []
    for i in range(len(bins) - 1):
        # Close the last bin on the right so the longest prefixes are counted
        if i == len(bins) - 2:
            upper = prefix_lens <= bins[i + 1]
        else:
            upper = prefix_lens < bins[i + 1]
        mask = (prefix_lens >= bins[i]) & upper
        bin_success.append(success[mask].mean() if mask.any() else 0)
    ax4.plot(bin_centers, bin_success, marker='d', color='green', linewidth=2)
    ax4.set_xlabel('Prefix Length (tokens)')
    ax4.set_ylabel('Exact Match Rate')
    ax4.set_title('Success Rate vs Prefix Length')
    ax4.grid(True, alpha=0.3)
    plt.tight_layout()
    plt.savefig(os.path.join(output_dir, 'extraction_analysis.png'), dpi=300, bbox_inches='tight')
    plt.close()
    logger.info(f"[VIZ] Saved visualization to {output_dir}/extraction_analysis.png")


def generate_extraction_report(
    analysis: Dict,
    results: List[ExtractionResult],
    output_dir: str,
    model_name: str
):
    """
    Generate a detailed extraction report, including key findings.
    """
    n = max(len(results), 1)  # avoid division by zero on an empty run
    report = {
        'model': model_name,
        'total_samples': len(results),
        'overall_exact_match_rate': sum(1 for r in results if r.exact_match) / n,
        'overall_email_recovery_rate': sum(1 for r in results if r.contains_email) / n,
        'by_repetition': analysis,
        'key_findings': []
    }
    # Key finding 1: check the r = 5 target
    if 5 in analysis:
        rate_5 = analysis[5]['exact_match_rate']
        if rate_5 > 0.5:
            report['key_findings'].append(
                f"CRITICAL: Repetition count r=5 achieves {rate_5:.1%} extraction rate (>50% threshold)"
            )
        else:
            report['key_findings'].append(
                f"WARNING: Repetition count r=5 only achieves {rate_5:.1%} extraction rate (below 50% threshold)"
            )
    # Key finding 2: trend, comparing the lowest and highest repetition levels
    if len(analysis) > 1:
        reps = sorted(analysis.keys())
        rates = [analysis[r]['exact_match_rate'] for r in reps]
        if rates[-1] > rates[0]:
            report['key_findings'].append(
                f"Memorization rate increases with repetition (from {rates[0]:.1%} at r={reps[0]} to {rates[-1]:.1%} at r={reps[-1]})"
            )
    # Save the JSON report
    report_file = os.path.join(output_dir, 'extraction_report.json')
    with open(report_file, 'w', encoding='utf-8') as f:
        json.dump(report, f, indent=2)
    # Plain-text report
    text_report = f"""
EXTRACTABLE MEMORIZATION ANALYSIS REPORT
========================================
Model: {model_name}
Total Samples: {report['total_samples']}
OVERALL STATISTICS:
- Exact Match Rate: {report['overall_exact_match_rate']:.2%}
- Email Recovery Rate: {report['overall_email_recovery_rate']:.2%}
BY REPETITION COUNT:
"""
    for rep in sorted(analysis.keys()):
        data = analysis[rep]
        text_report += f"""
r = {rep}:
  - Samples: {data['total_samples']}
  - Exact Match: {data['exact_match_rate']:.2%} ({data['exact_matches']}/{data['total_samples']})
  - Email Recovery: {data['email_recovery_rate']:.2%} ({data['email_recoveries']}/{data['total_samples']})
"""
    text_report += "\nKEY FINDINGS:\n"
    for finding in report['key_findings']:
        text_report += f"- {finding}\n"
    text_file = os.path.join(output_dir, 'extraction_report.txt')
    with open(text_file, 'w', encoding='utf-8') as f:
        f.write(text_report)
    logger.info(f"[SAVE] Report saved to {report_file} and {text_file}")
    return report


def main():
    parser = argparse.ArgumentParser(description='Extractable Memorization Testing')
    parser.add_argument('--model_name', type=str, default='EleutherAI/pythia-1.4b',
                        help='Pythia model name')
    parser.add_argument('--test_data', type=str, required=True,
                        help='JSON file with test samples and repetition counts')
    parser.add_argument('--prefix_ratio', type=float, default=0.5,
                        help='Ratio of text to use as prefix (0.0-1.0)')
    parser.add_argument('--max_new_tokens', type=int, default=50,
                        help='Maximum tokens to generate')
    parser.add_argument('--decoding_strategy', type=str, default='greedy',
                        choices=['greedy', 'sample'],
                        help='Decoding strategy')
    parser.add_argument('--temperature', type=float, default=1.0,
                        help='Sampling temperature (if strategy=sample)')
    parser.add_argument('--output_dir', type=str, default='./extraction_results',
                        help='Output directory')
    parser.add_argument('--device', type=str, default='auto',
                        help='Device (cuda/cpu/auto)')
    args = parser.parse_args()
    # Select the device
    if args.device == 'auto':
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    else:
        device = torch.device(args.device)
    os.makedirs(args.output_dir, exist_ok=True)
    logger.info(f"[INIT] Loading model: {args.model_name}")
    tokenizer = AutoTokenizer.from_pretrained(args.model_name)
    model = GPTNeoXForCausalLM.from_pretrained(args.model_name)
    model.to(device)
    model.eval()
    # Load the test data
    with open(args.test_data, 'r', encoding='utf-8') as f:
        test_data = json.load(f)
    logger.info(f"[DATA] Loaded {len(test_data)} test samples")
    # Run the extraction test
    logger.info(f"[EXTRACT] Starting extraction with prefix_ratio={args.prefix_ratio}, strategy={args.decoding_strategy}")
    results = test_extractable_memorization(
        model, tokenizer, test_data, args.prefix_ratio,
        args.max_new_tokens, device, args.decoding_strategy,
        temperature=args.temperature
    )
    # Analyze
    logger.info("[ANALYZE] Analyzing results by repetition count...")
    analysis = analyze_by_repetition(results)
    # Visualize
    visualize_extraction_results(analysis, results, args.output_dir)
    # Generate the report
    report = generate_extraction_report(analysis, results, args.output_dir, args.model_name)
    # Print a summary
    print("\n" + "=" * 60)
    print("EXTRACTION TEST SUMMARY")
    print("=" * 60)
    print(f"Model: {args.model_name}")
    print(f"Total samples tested: {len(results)}")
    print(f"Overall exact match: {report['overall_exact_match_rate']:.2%}")
    print("\nBy repetition count:")
    for rep in sorted(analysis.keys())[:5]:  # show the first 5 repetition levels
        rate = analysis[rep]['exact_match_rate']
        marker = " <-- TARGET" if rep == 5 and rate > 0.5 else ""
        print(f"  r={rep}: {rate:.2%}{marker}")
    print("=" * 60)


if __name__ == '__main__':
    main()
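Script 3: Pythia Multi-Scale Memorization Experiment
Description: Runs a cross-scale comparison of memorization rates on the Pythia suite (70M-12B), measuring how sensitive models with different parameter counts are to duplicated data. It produces a 3D visualization relating memorization rate, repetition count, and model size.
Usage: python 03_pythia_multiscale_experiment.py --model_sizes 70m,160m,1.4b,2.8b,6.9b,12b --test_data scaled_dataset.json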
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Script 3: Pythia Multi-Scale Memorization Experiment
Cross-scale memorization-rate evaluation for the Pythia model suite (70M-12B).
Usage:
    python 03_pythia_multiscale_experiment.py \
        --test_data data/scaled_canaries.json \
        --model_sizes 70m,160m,410m,1.4b,2.8b,6.9b,12b \
        --repetition_levels 1,2,3,4,5,10,20 \
        --output_dir ./pythia_scale_results
"""
import os
import json
import argparse
import logging
from typing import List, Dict, Tuple
from collections import defaultdict
import torch
import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
from tqdm import tqdm
from transformers import GPTNeoXForCausalLM, AutoTokenizer
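# Import helpers from Script 2. A module name that starts with a digit is a
# syntax error in an `import` statement, so load it by name with importlib.
import importlib
_extraction = importlib.import_module('02_extractable_memorization')
test_extractable_memorization = _extraction.test_extractable_memorization
analyze_by_repetition = _extraction.analyze_by_repetition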
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)
PYTHIA_MODELS = {
    '70m': 'EleutherAI/pythia-70m',
    '160m': 'EleutherAI/pythia-160m',
    '410m': 'EleutherAI/pythia-410m',
    '1b': 'EleutherAI/pythia-1b',
    '1.4b': 'EleutherAI/pythia-1.4b',
    '2.8b': 'EleutherAI/pythia-2.8b',
    '6.9b': 'EleutherAI/pythia-6.9b',
    '12b': 'EleutherAI/pythia-12b'
}
def get_model_params(size_key: str) -> int:
    """Return the nominal parameter count in millions (0 if the key is unknown)."""
    params_map = {
        '70m': 70, '160m': 160, '410m': 410, '1b': 1000,
        '1.4b': 1400, '2.8b': 2800, '6.9b': 6900, '12b': 12000
    }
    return params_map.get(size_key, 0)
def run_single_model_experiment(
    model_key: str,
    test_data: List[Dict],
    device: torch.device,
    prefix_ratio: float = 0.5,
    max_new_tokens: int = 50
) -> Dict:
    """
    Run the memorization-extraction experiment on a single Pythia model.
    """
    model_name = PYTHIA_MODELS[model_key]
    logger.info(f"\n[EXPERIMENT] Testing {model_name} ({model_key})")
    try:
        tokenizer = AutoTokenizer.from_pretrained(model_name)
        model = GPTNeoXForCausalLM.from_pretrained(model_name)
        model.to(device)
        model.eval()
        # Run the extraction test with greedy decoding
        results = test_extractable_memorization(
            model, tokenizer, test_data, prefix_ratio, max_new_tokens, device, 'greedy'
        )
        # Analyze
        analysis = analyze_by_repetition(results)
        # Free GPU memory before the next (larger) model is loaded
        del model
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        return {
            'model_key': model_key,
            'params': get_model_params(model_key),
            'analysis': analysis,
            'overall_success': sum(1 for r in results if r.exact_match) / max(len(results), 1),
            'raw_results': results
        }
    except Exception as e:
        logger.error(f"[ERROR] Failed to test {model_name}: {e}")
        return {
            'model_key': model_key,
            'error': str(e)
        }
def aggregate_cross_scale_results(
    all_results: List[Dict]
) -> Dict:
    """
    Aggregate the per-model results into a [model x repetition] memorization matrix.
    """
    aggregated = {
        'models': [],
        'repetition_levels': set(),
        'memorization_matrix': []  # matrix[model_idx, rep_idx] -> exact-match rate
    }
    # Collect the union of repetition levels over all successful runs
    for result in all_results:
        if 'analysis' in result:
            aggregated['repetition_levels'].update(result['analysis'].keys())
    aggregated['repetition_levels'] = sorted(aggregated['repetition_levels'])
    rep_index = {r: i for i, r in enumerate(aggregated['repetition_levels'])}
    # Build one row per model; failed runs (no 'analysis') are skipped, and
    # levels for which a model has no samples stay at 0.0
    for result in all_results:
        if 'analysis' not in result:
            continue
        model_data = {
            'key': result['model_key'],
            'params': result['params'],
            'rates': [0.0] * len(aggregated['repetition_levels'])
        }
        for rep, data in result['analysis'].items():
            model_data['rates'][rep_index[rep]] = data['exact_match_rate']
        aggregated['models'].append(model_data)
        aggregated['memorization_matrix'].append(model_data['rates'])
    aggregated['memorization_matrix'] = np.array(aggregated['memorization_matrix'])
    return aggregated
def visualize_cross_scale_results(
    aggregated: Dict,
    output_dir: str
):
    """
    Cross-scale visualization: memorization rate vs. model size vs. repetition count.
    """
    models = aggregated['models']
    reps = aggregated['repetition_levels']
    matrix = aggregated['memorization_matrix']
    fig = plt.figure(figsize=(16, 12))

    # 1. 2D heatmap: model x repetition count
    ax1 = fig.add_subplot(2, 2, 1)
    im = ax1.imshow(matrix, cmap='YlOrRd', aspect='auto', vmin=0, vmax=1)
    ax1.set_xticks(range(len(reps)))
    ax1.set_xticklabels([f'r={r}' for r in reps])
    ax1.set_yticks(range(len(models)))
    ax1.set_yticklabels([f"{m['key']}\n({m['params']}M)" for m in models])
    ax1.set_xlabel('Repetition Count')
    ax1.set_ylabel('Model Size')
    ax1.set_title('Memorization Rate Heatmap')
    plt.colorbar(im, ax=ax1, label='Extraction Success Rate')
    # Annotate every cell with its rate (white text on the darker cells)
    for i in range(len(models)):
        for j in range(len(reps)):
            ax1.text(j, i, f'{matrix[i, j]:.2f}',
                     ha="center", va="center",
                     color="black" if matrix[i, j] < 0.5 else "white",
                     fontsize=8)

    # 2. Line plot: memorization-rate curve of each model
    ax2 = fig.add_subplot(2, 2, 2)
    for i, model in enumerate(models):
        ax2.plot(reps, matrix[i], marker='o', linewidth=2,
                 label=f"{model['key']} ({model['params']}M)", markersize=6)
    ax2.axhline(y=0.5, color='red', linestyle='--', alpha=0.7, label='50% Threshold')
    ax2.set_xlabel('Repetition Count (r)')
    ax2.set_ylabel('Memorization Rate')
    ax2.set_title('Memorization Rate by Model Scale')
    ax2.legend(bbox_to_anchor=(1.05, 1), loc='upper left', fontsize=8)
    ax2.grid(True, alpha=0.3)

    # 3. 3D surface; plot_surface needs at least a 2x2 grid (two models, two repetition levels)
    ax3 = fig.add_subplot(2, 2, 3, projection='3d')
    if matrix.shape[0] > 1 and matrix.shape[1] > 1:
        X, Y = np.meshgrid(range(len(reps)), range(len(models)))
        surf = ax3.plot_surface(X, Y, matrix, cmap='viridis', alpha=0.8)
        fig.colorbar(surf, ax=ax3, shrink=0.5, aspect=5)
    ax3.set_xticks(range(len(reps)))
    ax3.set_xticklabels([str(r) for r in reps])
    ax3.set_yticks(range(len(models)))
    ax3.set_yticklabels([m['key'] for m in models])
    ax3.set_xlabel('Repetition Count')
    ax3.set_ylabel('Model Scale')
    ax3.set_zlabel('Memorization Rate')
    ax3.set_title('3D Memorization Landscape')

    # 4. Bar chart comparing the models at a fixed repetition count (r=5)
    ax4 = fig.add_subplot(2, 2, 4)
    if 5 in reps:
        idx_5 = reps.index(5)
        rates_at_5 = matrix[:, idx_5]
        colors = ['green' if r > 0.5 else 'red' for r in rates_at_5]
        bars = ax4.bar(range(len(models)), rates_at_5, color=colors, alpha=0.7, edgecolor='black')
        ax4.axhline(y=0.5, color='red', linestyle='--', linewidth=2, label='Target 50%')
        ax4.set_xticks(range(len(models)))
        ax4.set_xticklabels([m['key'] for m in models], rotation=45, ha='right')
        ax4.set_ylabel('Extraction Rate at r=5')
        ax4.set_title('Memorization Rate at Repetition Count = 5')
        ax4.legend()
        # Print the rate above each bar
        for bar, rate in zip(bars, rates_at_5):
            height = bar.get_height()
            ax4.text(bar.get_x() + bar.get_width() / 2., height,
                     f'{rate:.1%}', ha='center', va='bottom', fontsize=9)
    else:
        ax4.set_axis_off()
        ax4.set_title('r=5 is not among the tested repetition levels')

    plt.tight_layout()
    fig_path = os.path.join(output_dir, 'cross_scale_analysis.png')
    plt.savefig(fig_path, dpi=300, bbox_inches='tight')
    plt.close()
    logger.info(f"[VIZ] Cross-scale visualization saved to {fig_path}")
def generate_scale_report(
    aggregated: Dict,
    all_results: List[Dict],
    output_dir: str
):
    """
    Generate the cross-scale experiment report, focusing on whether the extraction
    rate at r=5 exceeds the 50% target.
    """
    reps = aggregated['repetition_levels']
    report = {
        'experiment_type': 'Pythia Multi-Scale Memorization',
        'models_tested': [r['model_key'] for r in all_results if 'analysis' in r],
        'models_failed': {r['model_key']: r['error'] for r in all_results if 'error' in r},
        'repetition_levels': reps,
        'key_findings': []
    }
    # Check each model against the 50% threshold at r=5
    if 5 in reps:
        idx_5 = reps.index(5)
        report['r=5_analysis'] = {}
        for model in aggregated['models']:
            # Cast to built-in types: json cannot serialize numpy.bool_
            rate = float(model['rates'][idx_5])
            meets = rate > 0.5
            report['r=5_analysis'][model['key']] = {
                'params': model['params'],
                'extraction_rate': rate,
                'meets_threshold': meets
            }
            verdict = "EXCEEDS 50% threshold" if meets else "below 50% threshold"
            report['key_findings'].append(
                f"Model {model['key']} ({model['params']}M params): "
                f"r=5 extraction rate = {rate:.1%} ({verdict})"
            )
    # Scale trend: correlate parameter count with memorization rate
    if len(aggregated['models']) > 1:
        params = [m['params'] for m in aggregated['models']]
        # Use the r=5 rates if available, otherwise each model's mean rate over all levels
        if 5 in reps:
            rates = [m['rates'][reps.index(5)] for m in aggregated['models']]
        else:
            rates = [float(np.mean(m['rates'])) for m in aggregated['models']]
        # Pearson correlation on raw parameter counts; NaN (undefined) when all rates are equal
        correlation = np.corrcoef(params, rates)[0, 1]
        if np.isfinite(correlation):
            report['scale_correlation'] = float(correlation)
            strength = ('Strong' if abs(correlation) > 0.7
                        else 'Moderate' if abs(correlation) > 0.4 else 'Weak')
            report['key_findings'].append(
                f"Correlation between model size and memorization rate: {correlation:.3f} ({strength})"
            )
        else:
            report['scale_correlation'] = None
            report['key_findings'].append(
                "Correlation between model size and memorization rate: undefined (all rates identical)"
            )
    # Save the JSON report
    report_file = os.path.join(output_dir, 'multiscale_report.json')
    with open(report_file, 'w', encoding='utf-8') as f:
        json.dump(report, f, indent=2)
    # Plain-text report
    text_lines = [
        "PYTHIA MULTI-SCALE MEMORIZATION EXPERIMENT REPORT",
        "=" * 60,
        "",
        "MODELS TESTED:",
    ]
    for model in aggregated['models']:
        text_lines.append(f"  - {model['key']}: {model['params']}M parameters")
    text_lines.extend([
        "",
        "REPETITION LEVELS TESTED:",
        f"  {reps}",
        "",
        "KEY FINDINGS (r=5 Target Analysis):",
    ])
    for finding in report['key_findings']:
        text_lines.append(f"  • {finding}")
    # The 9-space header indent matches the " {key:>6} [" prefix of each matrix row
    text_lines.extend([
        "",
        "MEMORIZATION MATRIX [Model x Repetition]:",
        " " * 9 + " ".join(f"r={r:>2}" for r in reps)
    ])
    for model in aggregated['models']:
        rates_str = " ".join(f"{r:>4.2f}" for r in model['rates'])
        text_lines.append(f" {model['key']:>6} [{rates_str}]")
    text_file = os.path.join(output_dir, 'multiscale_report.txt')
    with open(text_file, 'w', encoding='utf-8') as f:
        f.write('\n'.join(text_lines))
    logger.info(f"[SAVE] Reports saved to {report_file} and {text_file}")
    return report
def main():
    parser = argparse.ArgumentParser(description='Pythia Multi-Scale Memorization Experiment')
    parser.add_argument('--test_data', type=str, required=True,
                        help='JSON file with canary samples and repetition metadata')
    parser.add_argument('--model_sizes', type=str, default='70m,160m,410m,1.4b,2.8b,6.9b,12b',
                        help='Comma-separated list of Pythia model sizes to test')
    parser.add_argument('--repetition_levels', type=str, default='1,2,3,4,5,10,20',
                        help='Comma-separated repetition counts to analyze')
    parser.add_argument('--prefix_ratio', type=float, default=0.5,
                        help='Prefix ratio for extraction')
    parser.add_argument('--output_dir', type=str, default='./pythia_scale_results',
                        help='Output directory')
    parser.add_argument('--device', type=str, default='auto',
                        help='Device (cuda/cpu/auto)')
    args = parser.parse_args()

    # Parse the comma-separated arguments
    model_sizes = [s.strip() for s in args.model_sizes.split(',')]
    target_reps = [int(r.strip()) for r in args.repetition_levels.split(',')]
    if args.device == 'auto':
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    else:
        device = torch.device(args.device)
    os.makedirs(args.output_dir, exist_ok=True)
    logger.info("[INIT] Starting multi-scale experiment")
    logger.info(f"[INIT] Testing models: {model_sizes}")
    logger.info(f"[INIT] Target repetition levels: {target_reps}")

    # Load the canary data
    with open(args.test_data, 'r', encoding='utf-8') as f:
        all_data = json.load(f)
    # Keep only samples whose repetition count is a target level (a missing field counts as r=1).
    # Filtering does not balance the levels, so log the per-level counts to spot sparse ones.
    filtered_data = [d for d in all_data if d.get('repetition_count', 1) in target_reps]
    logger.info(f"[DATA] Loaded {len(filtered_data)} samples matching target repetition levels")
    for rep in target_reps:
        n_rep = sum(1 for d in filtered_data if d.get('repetition_count', 1) == rep)
        logger.info(f"[DATA]   r={rep}: {n_rep} samples")
    if len(filtered_data) == 0:
        logger.error("[ERROR] No samples found for specified repetition levels")
        return

    # Test the models one at a time
    all_results = []
    for size in model_sizes:
        if size not in PYTHIA_MODELS:
            logger.warning(f"[SKIP] Unknown model size: {size}")
            continue
        result = run_single_model_experiment(
            size, filtered_data, device, args.prefix_ratio
        )
        all_results.append(result)
        # Save the intermediate result; raw_results are omitted to save space
        intermediate_file = os.path.join(args.output_dir, f'result_{size}.json')
        with open(intermediate_file, 'w', encoding='utf-8') as f:
            save_data = {k: v for k, v in result.items() if k != 'raw_results'}
            json.dump(save_data, f, indent=2)

    # Aggregate across scales
    logger.info("[ANALYZE] Aggregating cross-scale results...")
    aggregated = aggregate_cross_scale_results(all_results)
    if not aggregated['models']:
        logger.error("[ERROR] No model run succeeded; skipping visualization and report")
        return
    # Visualize
    visualize_cross_scale_results(aggregated, args.output_dir)
    # Generate the reports
    report = generate_scale_report(aggregated, all_results, args.output_dir)

    # Final summary
    print("\n" + "=" * 70)
    print("MULTI-SCALE EXPERIMENT COMPLETE")
    print("=" * 70)
    print(f"Results saved to: {args.output_dir}")
    print("\nDeliverable Check (r=5 extraction rate > 50%):")
    if 'r=5_analysis' in report:
        for model, data in report['r=5_analysis'].items():
            status = "✓ PASS" if data['meets_threshold'] else "✗ FAIL"
            print(f"  {status} {model:>6} ({data['params']:>5}M): {data['extraction_rate']:>6.1%}")
    print("=" * 70)
if __name__ == '__main__':
    main()
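Section 1.2.1 models extraction probability as P(Extractable | r) ≈ 1 − e^{−λr}. The script checks each model only at r=5, but one way to condense a whole row of the memorization matrix into a single sensitivity number is to fit λ. The sketch below uses a linearized least-squares fit, which is an assumption of this example and not a step the script performs: with z = −ln(1 − P) the model becomes z = λr, a line through the origin, so λ = Σ r·z / Σ r². The helper names (`fit_lambda`, `r_at_target`, `lambda_from_result_file`) are hypothetical, and the demo rates are synthetic values generated from λ = 0.2, not experimental results.

```python
import json

import numpy as np


def fit_lambda(reps, rates, eps=1e-6):
    """Fit lambda in P(extractable | r) ~= 1 - exp(-lambda * r) by linearised least squares.

    With z = -ln(1 - P) the model becomes z = lambda * r, a line through the origin,
    whose least-squares slope is sum(r * z) / sum(r ** 2).
    """
    r = np.asarray(reps, dtype=float)
    # Clip P away from 1 so that -ln(1 - P) stays finite when a level is fully memorised
    p = np.clip(np.asarray(rates, dtype=float), 0.0, 1.0 - eps)
    z = -np.log1p(-p)  # numerically stable -ln(1 - p)
    return float(np.dot(r, z) / np.dot(r, r))


def r_at_target(lam, target=0.5):
    """Repetition count at which the fitted curve reaches the `target` extraction probability."""
    if lam <= 0:
        return float('inf')  # a flat curve never reaches the target
    return float(-np.log1p(-target) / lam)


def lambda_from_result_file(path):
    """Hypothetical helper: fit lambda from a per-model result_<size>.json written by the script."""
    with open(path, encoding='utf-8') as f:
        analysis = json.load(f)['analysis']
    reps = sorted(int(k) for k in analysis)  # JSON turns integer keys into strings
    rates = [analysis[str(k)]['exact_match_rate'] for k in reps]
    return fit_lambda(reps, rates)


# Synthetic demo: rates generated from lambda = 0.2 (not experimental data)
demo_reps = [1, 2, 3, 4, 5, 10, 20]
demo_rates = [1 - np.exp(-0.2 * r) for r in demo_reps]
lam = fit_lambda(demo_reps, demo_rates)
r50 = r_at_target(lam)
print(f"lambda = {lam:.3f}, r at 50% = {r50:.2f}")
# lambda = 0.200, r at 50% = 3.47
```

Under this model, the script's ">50% at r=5" check is equivalent to λ > ln 2 / 5 ≈ 0.139, so comparing fitted λ values across model sizes summarizes the same trend using every repetition level. The linearization weights points near P = 1 heavily, so treat λ as a rough summary rather than a precise estimate.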