# dayy43

# @z浙大疏锦行

# 复制代码
"""
实验5: 预训练BERT-mini + Linear微调 (exp5)
==========================================
参考Rank1方案, 使用自定义预训练的BERT-mini,
搭配 Linear 分类头进行文本分类。

使用分段策略: 长文本分段输入, logits叠加

用法:
  python src/train_bert_cls.py
"""
import os, json, time, random, warnings, math
warnings.filterwarnings('ignore')

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from transformers import BertConfig, BertModel
from sklearn.metrics import f1_score, accuracy_score, classification_report

# Project layout: this script sits one directory below BASE_DIR, and data/,
# logs/, models/ and submissions/ live directly under BASE_DIR.
BASE_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
DATA_DIR = os.path.join(BASE_DIR, 'data')
LOGS_DIR = os.path.join(BASE_DIR, 'logs')
MODELS_DIR = os.path.join(BASE_DIR, 'models')
SUBMISSIONS_DIR = os.path.join(BASE_DIR, 'submissions')
# Output directories are created at import time; DATA_DIR is not created here.
for d in [LOGS_DIR, MODELS_DIR, SUBMISSIONS_DIR]:
    os.makedirs(d, exist_ok=True)

# Class id -> category name; used for the classification report and the
# per-class F1 keys written to the results JSON.
LABEL_MAP = {0: '科技', 1: '股票', 2: '体育', 3: '娱乐', 4: '时政',
             5: '社会', 6: '教育', 7: '财经', 8: '家居', 9: '游戏',
             10: '房产', 11: '时尚', 12: '彩票', 13: '星座'}

SEED = 42
# Special token ids: the two highest ids of the VOCAB_SIZE-id vocabulary.
# Attention masks are built by comparing ids to PAD_TOKEN.
# NOTE(review): assumes PAD_TOKEN never appears in the raw text; confirm
# against the pretraining vocabulary.
CLS_TOKEN = 7999
PAD_TOKEN = 7998
# Not read by the code in this script.
VOCAB_SIZE = 8000
# Sequence length per model input *including* the leading CLS token, so each
# segment carries at most MAX_LEN - 1 text tokens.
MAX_LEN = 4096
NUM_CLASSES = 14
BATCH_SIZE = 4
# The optimizer steps once every GRAD_ACCUM micro-batches, so the effective
# batch size is BATCH_SIZE * GRAD_ACCUM.
GRAD_ACCUM = 8
# Separate learning rates for the pretrained encoder and the newly created
# linear head.
LR_BERT = 2e-4
LR_HEAD = 1e-3
EPOCHS = 3
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
def set_seed(seed):
    """Seed Python's `random`, NumPy, torch (CPU) and every CUDA device with `seed`."""
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for seeder in seeders:
        seeder(seed)

# Seed every RNG once at import time, before any model or DataLoader is built.
set_seed(SEED)

# ========================= Dataset =========================

class ClassificationDataset(Dataset):
    """Token-id dataset used for training and validation.

    Each text is a whitespace-separated string of integer token ids. Every
    sample is prefixed with CLS_TOKEN and right-padded with PAD_TOKEN to
    ``max_len``, so a sample holds at most ``max_len - 1`` text tokens.

    segment=True (training): a text longer than ``max_len - 1`` tokens is cut
        into consecutive windows of ``max_len - 1`` tokens. A leftover partial
        window is replaced by the *last* ``max_len - 1`` tokens of the text, so
        it may overlap the previous window. Every window inherits the text's
        label.
    segment=False (validation): each text is truncated to its first
        ``max_len - 1`` tokens, so there is exactly one sample per text.

    Items are ``(input_ids, attention_mask)`` or, when labels were given,
    ``(input_ids, attention_mask, label)``.
    """

    def __init__(self, texts, labels=None, max_len=MAX_LEN, segment=True):
        self.max_len = max_len
        window = max_len - 1
        self.segments = []
        self.seg_labels = []
        has_labels = labels is not None

        for idx, text in enumerate(texts):
            pieces = self._split_tokens(text.split(), window, segment)
            self.segments.extend(pieces)
            label = labels[idx] if has_labels else None
            if label is not None:
                self.seg_labels.extend([label] * len(pieces))

        if has_labels and segment:
            print(f"  分段扩充: {len(texts)} -> {len(self.segments)} 条训练样本")

    @staticmethod
    def _split_tokens(toks, window, segment):
        """Split one token list into the windows described in the class docstring."""
        if not segment or len(toks) <= window:
            return [toks[:window]]
        pieces = []
        for start in range(0, len(toks), window):
            chunk = toks[start:start + window]
            if len(chunk) < window:
                # Partial leftover: take a full window from the end instead.
                pieces.append(toks[-window:])
                break
            pieces.append(chunk)
        return pieces

    def __len__(self):
        return len(self.segments)

    def __getitem__(self, idx):
        ids = [CLS_TOKEN]
        ids.extend(int(t) for t in self.segments[idx])
        # Truncate to max_len, then right-pad (padding is empty if already full).
        ids = ids[:self.max_len] + [PAD_TOKEN] * max(0, self.max_len - len(ids))
        x = torch.tensor(ids, dtype=torch.long)
        attn_mask = (x != PAD_TOKEN).long()
        if not self.seg_labels:
            return x, attn_mask
        return x, attn_mask, torch.tensor(self.seg_labels[idx], dtype=torch.long)


class PredictionDataset(Dataset):
    """Inference dataset that returns each text as a fixed-size stack of segments.

    Each item is ``(segments, actual)``. ``segments`` is a LongTensor whose rows
    are ``[CLS_TOKEN] + tokens`` right-padded with PAD_TOKEN to ``max_len``, and
    ``actual`` is the number of leading rows that hold real text. The remaining
    rows are filler (CLS followed by PAD) added so that items stack into
    ``max_segments`` rows.

    Segmentation:
    - A text of at most ``max_len - 1`` tokens becomes a single row.
    - A longer text is cut into consecutive ``max_len - 1``-token windows,
      stopping after ``max_segments`` full windows. Any tokens past that point
      are dropped.
    - If a partial leftover window is reached before that limit, it is replaced
      by the text's last ``max_len - 1`` tokens.
    """

    def __init__(self, texts, max_len=MAX_LEN, max_segments=10):
        self.texts = texts
        self.max_len = max_len
        self.max_segments = max_segments

    def __len__(self):
        return len(self.texts)

    def _pad_row(self, body):
        """Prefix ``body`` with CLS and right-pad it with PAD to ``max_len``."""
        row = [CLS_TOKEN] + body
        return row + [PAD_TOKEN] * (self.max_len - len(row))

    def __getitem__(self, idx):
        ids = [int(t) for t in self.texts[idx].split()]
        width = self.max_len - 1

        rows = []
        if len(ids) <= width:
            rows.append(self._pad_row(ids))
        else:
            for start in range(0, len(ids), width):
                piece = ids[start:start + width]
                if len(piece) < width:
                    rows.append(self._pad_row(ids[-width:]))
                    break
                rows.append(self._pad_row(piece))
                if len(rows) >= self.max_segments:
                    break

        actual = len(rows)
        rows.extend(self._pad_row([]) for _ in range(self.max_segments - actual))
        return torch.tensor(rows, dtype=torch.long), actual

# ========================= Model =========================

class BertLinearClassifier(nn.Module):
    """Pretrained BERT encoder with a dropout + linear head on the first token.

    Args:
        bert_path: directory passed to ``BertModel.from_pretrained``.
    """

    def __init__(self, bert_path):
        super().__init__()
        encoder = BertModel.from_pretrained(bert_path)
        hidden_size = encoder.config.hidden_size
        self.bert = encoder
        self.dropout = nn.Dropout(0.1)
        self.fc = nn.Linear(hidden_size, NUM_CLASSES)

    def forward(self, input_ids, attention_mask=None):
        """Return class logits of shape (batch, NUM_CLASSES) from position 0 (CLS)."""
        hidden_states = self.bert(input_ids=input_ids, attention_mask=attention_mask)[0]
        pooled = hidden_states[:, 0]
        pooled = self.dropout(pooled)
        return self.fc(pooled)

# ========================= Training =========================

def get_class_weights(labels, num_classes=None):
    """Return inverse-frequency class weights normalised to sum to ``num_classes``.

    Args:
        labels: iterable of non-negative integer class ids.
        num_classes: number of classes; defaults to the module's NUM_CLASSES.

    Returns:
        float32 tensor of shape (num_classes,). A class with count ``c > 0``
        gets weight proportional to ``1 / c``. A class absent from ``labels``
        gets weight 0, since it contributes no targets. If ``labels`` is
        empty, uniform weights (all ones) are returned.

    Raises:
        ValueError: if any label id is >= num_classes or negative.
    """
    if num_classes is None:
        num_classes = NUM_CLASSES
    # np.bincount itself raises ValueError for negative ids.
    counts = np.bincount(np.asarray(labels, dtype=np.int64), minlength=num_classes)
    if counts.size > num_classes:
        # Otherwise the weight vector would be longer than the logits and
        # CrossEntropyLoss would fail later with a far less obvious error.
        raise ValueError(
            f"labels contain class id {counts.size - 1} >= num_classes={num_classes}")

    present = counts > 0
    if not present.any():
        return torch.ones(num_classes, dtype=torch.float32)

    # Only present classes get 1/count. A 1/(0 + eps) term for an absent class
    # would dominate the normalisation and push every real weight toward 0.
    weights = np.zeros(num_classes, dtype=np.float64)
    weights[present] = 1.0 / counts[present]
    weights = weights / weights.sum() * num_classes
    return torch.tensor(weights, dtype=torch.float32)


def train_epoch(model, loader, optimizer, criterion, device, scaler, scheduler,
                grad_accum=GRAD_ACCUM):
    """Train ``model`` for one pass over ``loader`` with AMP and gradient accumulation.

    The optimizer and ``scheduler`` step once per window of ``grad_accum``
    micro-batches. They also step once for a final partial window when
    ``len(loader)`` is not a multiple of ``grad_accum``. Before every optimizer
    step, gradients are unscaled and clipped to a global L2 norm of 1.0.

    Args:
        model: classifier called as ``model(x, attention_mask=mask)``.
        loader: yields ``(input_ids, attention_mask, labels)`` batches.
        optimizer: optimizer over ``model``'s parameters.
        criterion: loss function ``criterion(logits, labels)``.
        device: device the batches are moved to.
        scaler: ``GradScaler`` used for the mixed-precision backward and step.
        scheduler: LR scheduler, stepped once per optimizer step.
        grad_accum: number of micro-batches per optimizer step.

    Returns:
        ``(mean_loss, macro_f1)``. ``mean_loss`` is the batch loss (before the
        accumulation division) weighted by batch size and averaged over
        ``loader.dataset``. ``macro_f1`` is the macro-F1 of the train-mode
        predictions collected during the epoch.
    """
    model.train()
    total_loss = 0
    all_preds, all_labels = [], []
    optimizer.zero_grad()

    n_batches = len(loader)
    # Size of the last accumulation window. When n_batches is not a multiple
    # of grad_accum, that window holds fewer micro-batches. Dividing its losses
    # by grad_accum would shrink that update, so they are divided by the real
    # window size instead.
    tail = n_batches % grad_accum
    tail_start = n_batches - tail if tail else n_batches

    log_interval = grad_accum * 25
    interval_loss = 0
    interval_samples = 0

    for step, (x, mask, y) in enumerate(loader):
        x, mask, y = x.to(device), mask.to(device), y.to(device)
        accum_div = tail if step >= tail_start else grad_accum
        with torch.cuda.amp.autocast():
            logits = model(x, attention_mask=mask)
            loss = criterion(logits, y) / accum_div

        scaler.scale(loss).backward()

        if (step + 1) % grad_accum == 0 or (step + 1) == n_batches:
            # Unscale first so clipping sees the true gradient magnitudes.
            scaler.unscale_(optimizer)
            nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            scaler.step(optimizer)
            scaler.update()
            scheduler.step()
            optimizer.zero_grad()

        # Undo the accumulation division so logged losses are per-batch values.
        batch_loss = loss.item() * accum_div
        total_loss += batch_loss * x.size(0)
        interval_loss += batch_loss * x.size(0)
        interval_samples += x.size(0)
        all_preds.extend(logits.argmax(1).cpu().numpy())
        all_labels.extend(y.cpu().numpy())

        if (step + 1) % log_interval == 0:
            recent_loss = interval_loss / interval_samples
            lr_now = scheduler.get_last_lr()[0]
            print(f"  Step {step+1}/{len(loader)} | loss={recent_loss:.4f} | lr={lr_now:.6f}")
            interval_loss = 0
            interval_samples = 0

    f1 = f1_score(all_labels, all_preds, average='macro')
    return total_loss / len(loader.dataset), f1

@torch.no_grad()
def eval_epoch(model, loader, criterion, device):
    """Evaluate ``model`` on a labelled loader without gradient tracking.

    Each batch's loss is multiplied by its batch size, and the total is divided
    by ``len(loader.dataset)``. NOTE(review): this is an exact mean only when
    ``criterion`` averages uniformly over the batch; a class-weighted
    CrossEntropyLoss normalises by the summed weights instead.

    Returns:
        ``(mean_loss, macro_f1, accuracy, preds, labels)``. ``preds`` and
        ``labels`` are lists in loader order.
    """
    model.eval()
    loss_sum = 0
    preds, golds = [], []
    for batch in loader:
        x, mask, y = (t.to(device) for t in batch)
        with torch.cuda.amp.autocast():
            logits = model(x, attention_mask=mask)
            loss = criterion(logits, y)
        loss_sum += loss.item() * x.size(0)
        preds.extend(logits.argmax(1).cpu().numpy())
        golds.extend(y.cpu().numpy())
    macro_f1 = f1_score(golds, preds, average='macro')
    acc = accuracy_score(golds, preds)
    return loss_sum / len(loader.dataset), macro_f1, acc, preds, golds

@torch.no_grad()
def predict_with_segments(model, loader, device):
    """Predict one class per document by summing logits over its real segments.

    ``loader`` yields ``(segments_batch, actual_counts)`` as produced by
    PredictionDataset. Only the first ``actual_counts[i]`` rows of each
    document's segment stack are fed to the model, and the argmax of their
    summed logits is the prediction.

    Returns:
        list of int class ids, one per document, in loader order.
    """
    model.eval()
    predictions = []
    for segments_batch, actual_counts in loader:
        for sample_segments, count in zip(segments_batch, actual_counts):
            segs = sample_segments[:count.item()].to(device)
            mask = (segs != PAD_TOKEN).long().to(device)
            with torch.cuda.amp.autocast():
                logits = model(segs, attention_mask=mask)
            predictions.append(logits.sum(dim=0).argmax().cpu().item())
    return predictions

# ========================= Main =========================

def _collect_val_logits(model, loader, device):
    """Run ``model`` over a labelled loader and return ``(logits, labels)`` numpy arrays.

    Logits keep whatever dtype the model emits under autocast, and both arrays
    follow loader order.
    """
    model.eval()
    logits_list, labels_list = [], []
    with torch.no_grad():
        for x, mask, y in loader:
            x, mask = x.to(device), mask.to(device)
            with torch.cuda.amp.autocast():
                logits = model(x, attention_mask=mask)
            logits_list.append(logits.cpu())
            labels_list.append(y.cpu())
    return torch.cat(logits_list, dim=0).numpy(), torch.cat(labels_list, dim=0).numpy()


def _collect_test_logits(model, loader, device):
    """Return per-document logits summed over each document's real segments.

    ``loader`` yields ``(segments_batch, actual_counts)`` from PredictionDataset.
    The result has one row per document, in loader order.
    """
    model.eval()
    summed = []
    with torch.no_grad():
        for segments_batch, actual_counts in loader:
            for sample_segs, n_seg in zip(segments_batch, actual_counts.tolist()):
                segs = sample_segs[:n_seg].to(device)
                mask = (segs != PAD_TOKEN).long()
                with torch.cuda.amp.autocast():
                    logits = model(segs, attention_mask=mask)
                summed.append(logits.sum(dim=0).cpu().numpy())
    if not summed:
        return np.zeros((0, NUM_CLASSES), dtype=np.float32)
    return np.stack(summed)


def main():
    """Fine-tune the pretrained BERT-mini classifier and export all artifacts.

    Steps:
    1. Train for EPOCHS epochs, keeping the checkpoint with the best
       validation macro-F1.
    2. Reload that checkpoint, then report and save the validation metrics and
       logits.
    3. Run segmented inference on the test set once, and write the submission
       CSV and the test logits from that single pass.
    4. Write the results JSON and append a row to the experiment log.
    """
    exp_name = 'bert_linear'

    pretrained_path = os.path.join(MODELS_DIR, 'bert_pretrained', 'final')
    if not os.path.exists(os.path.join(pretrained_path, 'config.json')):
        print(f"[错误] 预训练模型不存在: {pretrained_path}")
        print("请先运行 python src/pretrain_bert.py")
        return

    print("=" * 60)
    print("实验5: BERT-mini + Linear 微调")
    print(f"预训练模型: {pretrained_path}")
    print(f"max_len={MAX_LEN}, device={DEVICE}")
    print("=" * 60)

    train_df = pd.read_csv(os.path.join(DATA_DIR, 'train_split.csv'), sep='\t')
    val_df = pd.read_csv(os.path.join(DATA_DIR, 'val_split.csv'), sep='\t')
    test_df = pd.read_csv(os.path.join(DATA_DIR, 'test_a.csv'), sep='\t')

    train_ds = ClassificationDataset(train_df['text'].tolist(), train_df['label'].tolist(), segment=True)
    val_ds = ClassificationDataset(val_df['text'].tolist(), val_df['label'].tolist(), segment=False)
    test_ds = PredictionDataset(test_df['text'].tolist())

    train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True,
                              num_workers=2, pin_memory=True, persistent_workers=True)
    val_loader = DataLoader(val_ds, batch_size=BATCH_SIZE, shuffle=False,
                            num_workers=2, pin_memory=True, persistent_workers=True)
    test_loader = DataLoader(test_ds, batch_size=1, shuffle=False, num_workers=0)

    model = BertLinearClassifier(pretrained_path).to(DEVICE)
    param_count = sum(p.numel() for p in model.parameters())
    print(f"模型参数量: {param_count/1e6:.2f}M")

    bert_params = list(model.bert.parameters())
    head_params = [p for n, p in model.named_parameters() if not n.startswith('bert')]
    optimizer = torch.optim.AdamW([
        {'params': bert_params, 'lr': LR_BERT},
        {'params': head_params, 'lr': LR_HEAD}
    ], weight_decay=0.01)

    # NOTE(review): weights come from document-level label counts, but the loss
    # is computed over segments. Long documents contribute several segments, so
    # the segment-level class distribution may differ; confirm this is intended.
    class_weights = get_class_weights(train_df['label'].values).to(DEVICE)
    criterion = nn.CrossEntropyLoss(weight=class_weights, label_smoothing=0.1)
    scaler = torch.cuda.amp.GradScaler()

    # The scheduler is stepped once per optimizer step (see train_epoch).
    opt_steps_per_epoch = math.ceil(len(train_loader) / GRAD_ACCUM)
    total_opt_steps = opt_steps_per_epoch * EPOCHS
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=total_opt_steps)

    best_path = os.path.join(MODELS_DIR, f'{exp_name}_best.pt')
    # Start below any reachable F1 so the first epoch always writes a checkpoint.
    # Otherwise a run stuck at F1 == 0 would crash at the reload below.
    best_f1 = -1.0
    history = []
    train_time = 0.0  # seconds spent in the train + validation loop

    for epoch in range(1, EPOCHS + 1):
        t0 = time.time()
        train_loss, train_f1 = train_epoch(model, train_loader, optimizer, criterion,
                                            DEVICE, scaler, scheduler)
        val_loss, val_f1, val_acc, _, _ = eval_epoch(model, val_loader, criterion, DEVICE)
        elapsed = time.time() - t0
        train_time += elapsed

        print(f"Epoch {epoch:2d}/{EPOCHS} | "
              f"train_loss={train_loss:.4f} train_f1={train_f1:.4f} | "
              f"val_loss={val_loss:.4f} val_f1={val_f1:.4f} val_acc={val_acc:.4f} | "
              f"time={elapsed:.1f}s")

        history.append({
            'epoch': epoch, 'train_loss': train_loss, 'train_f1': train_f1,
            'val_loss': val_loss, 'val_f1': val_f1, 'val_acc': val_acc
        })

        if val_f1 > best_f1:
            best_f1 = val_f1
            torch.save(model.state_dict(), best_path)
            print(f"  -> 保存最佳模型 (val_f1={best_f1:.4f})")

    model.load_state_dict(torch.load(best_path, map_location=DEVICE))

    # A single validation pass yields both the final metrics and the saved logits.
    val_logits, val_labels = _collect_val_logits(model, val_loader, DEVICE)
    val_preds = val_logits.argmax(axis=1)
    final_f1 = f1_score(val_labels, val_preds, average='macro')
    final_acc = accuracy_score(val_labels, val_preds)
    # Pass labels= explicitly so per-class output stays aligned with LABEL_MAP
    # even when some class is missing from both val labels and predictions.
    all_classes = list(range(NUM_CLASSES))
    per_class = f1_score(val_labels, val_preds, labels=all_classes, average=None)
    per_class_dict = {LABEL_MAP[i]: float(f'{v:.4f}') for i, v in enumerate(per_class)}

    print(f"\n最终验证集: macro-F1={final_f1:.4f}, accuracy={final_acc:.4f}")
    print(classification_report(val_labels, val_preds, labels=all_classes,
          target_names=[LABEL_MAP[i] for i in all_classes], digits=4))
    np.save(os.path.join(MODELS_DIR, f'{exp_name}_val_logits.npy'), val_logits)

    print("测试集分段推理中...")
    # One segmented pass: predictions are the argmax of the summed logits that
    # are also saved, so the expensive test inference is not repeated.
    test_logits = _collect_test_logits(model, test_loader, DEVICE)
    test_preds = test_logits.argmax(axis=1)
    sub_df = pd.DataFrame({'label': test_preds})
    sub_df.to_csv(os.path.join(SUBMISSIONS_DIR, f'submission_{exp_name}.csv'), index=False)
    np.save(os.path.join(MODELS_DIR, f'{exp_name}_test_logits.npy'), test_logits)

    result = {
        'model': 'BERT-mini + Linear',
        'max_len': MAX_LEN, 'batch_size': BATCH_SIZE,
        'lr_bert': LR_BERT, 'lr_head': LR_HEAD, 'epochs': EPOCHS,
        'best_val_f1': float(f'{best_f1:.4f}'),
        'best_val_acc': float(f'{final_acc:.4f}'),
        'per_class_f1': per_class_dict, 'history': history
    }
    with open(os.path.join(LOGS_DIR, 'exp5_bert_linear_results.json'), 'w') as f:
        json.dump(result, f, ensure_ascii=False, indent=2)

    log_csv = os.path.join(LOGS_DIR, 'experiment_log.csv')
    entry = pd.DataFrame([{
        '实验ID': 'exp5', '模型名称': 'BERT-mini+Linear',
        'max_len': MAX_LEN, 'batch_size': BATCH_SIZE,
        'lr': f'{LR_BERT}/{LR_HEAD}', 'epochs': EPOCHS,
        'val_macro_f1': best_f1, 'val_accuracy': final_acc,
        '训练时间(min)': round(train_time / 60, 2), '备注': 'pretrained bert-mini + segment inference'
    }])
    entry.to_csv(log_csv, mode='a', header=not os.path.exists(log_csv), index=False)

    print(f"\n[完成] 实验5 BERT-mini+Linear 完毕!")

# Run the experiment only when executed as a script, not when imported.
if __name__ == '__main__':
    main()
# 相关推荐
# 睡个好觉(努力提升自己版)2 分钟前
# 2026_TIP_image_Restoration(最新方法)
# 人工智能·深度学习·机器学习
# 郝学胜-神的一滴5 分钟前
# 系统设计 014:缓存深度实战:如何用 Cache 优雅优化数据库读写?
# java·数据库·python·缓存·oracle·php·软件构建
# Cloud_Shy61810 分钟前
# 解读《Effective Python 3rd Edition》:从练气到老魔(第三章 Item 17 - 20)
# 开发语言·笔记·python
# ZHW_AI课题组21 分钟前
# 使用Stable Diffusion v1.5文本引导与无分类器引导(CFG)算法实现条件生成图片
# 人工智能·python·算法·机器学习·stable diffusion
# 盼小辉丶22 分钟前
# OpenCV-Python实战(25)——基于深度传感器与凸性分析打造实时手势识别系统
# 人工智能·python·opencv·计算机视觉
# 金融大 k26 分钟前
# 行情数据接入 MCP:Claude Code / Cursor 工具描述怎么写才不踩坑
# 人工智能·python·websocket·行情 api
# code_pgf29 分钟前
# CRNN + CTC OCR 原理详解
# 深度学习·ocr
# 张彦峰ZYF35 分钟前
# 深入 LangGraph State:Reducer 是如何让状态“自动合并”的
# 人工智能·python·大模型·langgraph
# 夜空繁星vv37 分钟前
# widows环境 下使用python开发的仿照Linux的grep的能力
# linux·开发语言·python
# 数学建模导师37 分钟前
# 【AI生成内容的质量评估】2026中青杯B题26页成品论文重磅更新
# 人工智能·深度学习·机器学习