【文本分类】bert二分类

python
import os
import torch
from torch.utils.data import DataLoader, Dataset
from transformers import BertTokenizer, BertForSequenceClassification, AdamW
from sklearn.metrics import accuracy_score, classification_report
from tqdm import tqdm

# 自定义数据集
class CustomDataset(Dataset):
    """Map-style dataset that tokenizes one text per item on demand.

    Args:
        texts: Sequence of raw input texts; its length defines the dataset size.
        labels: Sequence of integer class ids, indexed in parallel with ``texts``.
        tokenizer: Tokenizer called with ``max_length``/``padding``/
            ``truncation``/``return_tensors="pt"`` keyword arguments.
        max_length: Length every item is padded or truncated to.
    """

    def __init__(self, texts, labels, tokenizer, max_length=128):
        self.texts, self.labels = texts, labels
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, idx):
        """Return a dict with 1-D ``input_ids``/``attention_mask`` and a long ``label``."""
        target = self.labels[idx]
        enc = self.tokenizer(
            self.texts[idx],
            max_length=self.max_length,
            padding="max_length",
            truncation=True,
            return_tensors="pt",
        )
        # return_tensors="pt" adds a leading batch dimension; drop it so the
        # DataLoader can stack items itself.
        item = {key: enc[key].squeeze(0) for key in ("input_ids", "attention_mask")}
        item["label"] = torch.tensor(target, dtype=torch.long)
        return item


# 训练函数
def train_model(model, train_loader, optimizer, device, num_epochs=3):
    """Fine-tune *model* on *train_loader* for *num_epochs* epochs.

    Each batch must be a mapping with ``input_ids``, ``attention_mask`` and
    ``label`` tensors. The model is called with ``labels=`` so that it computes
    the loss itself, which is read from ``outputs.loss``.

    Args:
        model: Model whose forward output exposes a ``.loss`` attribute.
        train_loader: Iterable of batches; it does not need to support ``len()``.
        optimizer: Optimizer over the model's parameters.
        device: Device the batch tensors are moved to.
        num_epochs: Number of passes over ``train_loader``.

    Returns:
        A list with the mean per-batch training loss of each epoch
        (``nan`` for an epoch that yielded no batches).
    """
    model.train()
    epoch_losses = []
    for epoch in range(num_epochs):
        total_loss = 0.0
        num_batches = 0
        for batch in tqdm(train_loader, desc=f"Training Epoch {epoch + 1}/{num_epochs}"):
            input_ids = batch["input_ids"].to(device)
            attention_mask = batch["attention_mask"].to(device)
            labels = batch["label"].to(device)

            optimizer.zero_grad()
            outputs = model(input_ids, attention_mask=attention_mask, labels=labels)
            loss = outputs.loss
            loss.backward()
            optimizer.step()

            total_loss += loss.item()
            num_batches += 1

        # Count batches while iterating instead of dividing by len(train_loader):
        # this avoids ZeroDivisionError on an empty loader and also works for
        # iterables that do not define __len__.
        avg_loss = total_loss / num_batches if num_batches else float("nan")
        epoch_losses.append(avg_loss)
        print(f"Epoch {epoch + 1} Loss: {avg_loss}")
    return epoch_losses


# 评估函数
def evaluate_model(model, val_loader, device):
    """Evaluate *model* on *val_loader*; print accuracy and a per-class report.

    Args:
        model: Classification model whose forward output exposes ``.logits``.
        val_loader: Iterable of batches with ``input_ids``, ``attention_mask``
            and ``label`` tensors.
        device: Device the model inputs are moved to.

    Returns:
        The validation accuracy as a float, or ``None`` if *val_loader*
        yielded no examples.
    """
    model.eval()
    predictions, true_labels = [], []
    with torch.no_grad():
        for batch in val_loader:
            input_ids = batch["input_ids"].to(device)
            attention_mask = batch["attention_mask"].to(device)

            logits = model(input_ids, attention_mask=attention_mask).logits
            predictions.extend(torch.argmax(logits, dim=1).cpu().tolist())
            # Labels are only compared on the host, so there is no need to
            # send them to the model's device and copy them back.
            true_labels.extend(batch["label"].tolist())

    if not true_labels:
        # sklearn metrics cannot be computed on an empty set.
        print("Validation set is empty; skipping evaluation.")
        return None

    accuracy = accuracy_score(true_labels, predictions)
    # zero_division=0 reports 0.0 for classes that are never predicted (the
    # value sklearn substitutes anyway), without an UndefinedMetricWarning.
    report = classification_report(true_labels, predictions, zero_division=0)
    print(f"Validation Accuracy: {accuracy}")
    print("Classification Report:")
    print(report)
    return accuracy


# 模型保存函数
def save_model(model, tokenizer, output_dir):
    """Write *model* then *tokenizer* into *output_dir* via ``save_pretrained``.

    The directory (and any missing parents) is created first; an existing
    directory is reused.
    """
    os.makedirs(output_dir, exist_ok=True)
    for artifact in (model, tokenizer):
        artifact.save_pretrained(output_dir)
    print(f"Model saved to {output_dir}")


# 模型加载函数
def load_model(output_dir, device):
    """Reload the tokenizer/model pair stored in *output_dir*.

    Args:
        output_dir: Directory holding files previously written with
            ``save_pretrained``.
        device: Device the loaded model is moved to.

    Returns:
        A ``(model, tokenizer)`` tuple.
    """
    tokenizer = BertTokenizer.from_pretrained(output_dir)
    # nn.Module.to() returns the module itself, so the call can be chained.
    model = BertForSequenceClassification.from_pretrained(output_dir).to(device)
    print(f"Model loaded from {output_dir}")
    return model, tokenizer


# 推理预测函数
def predict(texts, model, tokenizer, device, max_length=128):
    """Classify *texts* and return ``(predictions, probabilities)`` as NumPy arrays.

    Args:
        texts: Input accepted by *tokenizer* (``main`` passes a list of strings).
        model: Classification model whose forward output exposes ``.logits``.
        tokenizer: Tokenizer matching *model*.
        device: Device the encoded inputs are moved to; should match the model's.
        max_length: Pad/truncate length; keep it equal to the training setting.

    Returns:
        ``predictions``: argmax class id per input row.
        ``probabilities``: softmax over the class dimension per input row.
    """
    model.eval()
    enc = tokenizer(
        texts,
        max_length=max_length,
        padding="max_length",
        truncation=True,
        return_tensors="pt",
    )
    ids = enc["input_ids"].to(device)
    mask = enc["attention_mask"].to(device)

    with torch.no_grad():
        scores = model(ids, attention_mask=mask).logits
        probs = scores.softmax(dim=1)
        # argmax on the raw logits (not on probs) so ties resolve identically.
        labels = scores.argmax(dim=1)

    return labels.cpu().numpy(), probs.cpu().numpy()


# 主函数
def main():
    """Run the end-to-end demo: fine-tune, evaluate, save, reload, and predict.

    All tunables live in the ``config`` dict below.
    """
    # Configuration
    config = {
        "train_batch_size": 16,
        "val_batch_size": 16,
        "learning_rate": 5e-5,
        "num_epochs": 5,
        "max_length": 128,
        "device_id": 7,  # preferred GPU id; falls back to cuda:0 if this GPU does not exist
        "model_dir": "model",
        "local_model_path": "roberta_tiny_model",  # local model directory; if None, pretrained_model_name is used
        "pretrained_model_name": "uer/chinese_roberta_L-12_H-128",  # pretrained model name
    }

    # Pick the device. A hard-coded GPU id that does not exist on this machine
    # would make model.to(device) fail, so fall back to the first GPU.
    if torch.cuda.is_available():
        gpu_id = config["device_id"]
        if not 0 <= gpu_id < torch.cuda.device_count():
            gpu_id = 0
        device = torch.device(f"cuda:{gpu_id}")
    else:
        device = torch.device("cpu")
    print(f"Using device: {device}")

    # Load tokenizer and model. As the config comment promises, a None local
    # path means the named pretrained checkpoint is loaded instead.
    model_source = config["local_model_path"]
    if model_source is None:
        model_source = config["pretrained_model_name"]
    tokenizer = BertTokenizer.from_pretrained(model_source)
    model = BertForSequenceClassification.from_pretrained(model_source, num_labels=2)
    model.to(device)

    # Toy example data
    train_texts = ["This is a great product!", "I hate this service."]
    train_labels = [1, 0]
    val_texts = ["Awesome experience.", "Terrible product."]
    val_labels = [1, 0]

    # Datasets and data loaders
    train_dataset = CustomDataset(train_texts, train_labels, tokenizer, config["max_length"])
    val_dataset = CustomDataset(val_texts, val_labels, tokenizer, config["max_length"])
    train_loader = DataLoader(train_dataset, batch_size=config["train_batch_size"], shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=config["val_batch_size"])

    # Optimizer. transformers.AdamW is deprecated (and removed in recent
    # transformers releases); torch.optim.AdamW replaces it. eps and
    # weight_decay are pinned to transformers.AdamW's former defaults
    # (torch defaults to eps=1e-8, weight_decay=0.01) so training is unchanged.
    optimizer = torch.optim.AdamW(
        model.parameters(), lr=config["learning_rate"], eps=1e-6, weight_decay=0.0
    )

    # Train, evaluate, and save
    train_model(model, train_loader, optimizer, device, num_epochs=config["num_epochs"])
    evaluate_model(model, val_loader, device)
    save_model(model, tokenizer, config["model_dir"])

    # Reload the saved model on CPU and run inference with it, using the same
    # max_length as training.
    loaded_model, loaded_tokenizer = load_model(config["model_dir"], "cpu")
    new_texts = ["I love this!", "It's the worst."]
    predictions, probabilities = predict(
        new_texts, loaded_model, loaded_tokenizer, "cpu", max_length=config["max_length"]
    )
    for text, pred, prob in zip(new_texts, predictions, probabilities):
        print(f"Text: {text}")
        print(f"Predicted Label: {pred} (Probability: {prob})")


# Run the demo pipeline only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
相关推荐
天行健,君子而铎2 天前
自适应分类·高准确率·可视化易用——运营商数据分类分级解决方案
大数据·分类
装不满的克莱因瓶2 天前
了解多标签图像分类方法——从Sigmoid输出到真实世界复杂视觉理解
人工智能·pytorch·python·深度学习·机器学习·分类·数据挖掘
装不满的克莱因瓶2 天前
掌握语义分割经典模型 FCN——从像素分类到端到端分割的奠基之作
人工智能·python·深度学习·算法·机器学习·分类·数据挖掘
雷工笔记2 天前
MES系列51-人防门行业 MES 质检分类体系
人工智能·分类·数据挖掘
2401_885665193 天前
从零搭建CNN到迁移学习:以食物分类为例深入理解PyTorch图像分类实战
人工智能·pytorch·深度学习·分类·cnn·迁移学习
百胜软件@百胜软件3 天前
货品“精”营:ABC-XYZ分类如何驱动鞋服全渠道库存效率革命?
人工智能·分类·数据挖掘·零售数字化·数智中台·珠宝行业
zcg19423 天前
分类中的样本不平衡问题——Asymmetric Loss
人工智能·分类·数据挖掘
daly5203 天前
人工智能专业有哪些?2026高考报考指南(专业分类 + 课程 + 就业全解析)
人工智能·分类·高考
DXM05214 天前
第11期| 遥感图像分类模型:ResNet_DenseNet原理+实战训练
人工智能·python·深度学习·机器学习·分类·数据挖掘·ageo
酉鬼女又兒4 天前
零基础入门IPv4地址:从基本概念、分类编址、子网划分到无分类编址与应用规划全解
网络·网络协议·计算机网络·考研·职场和发展·分类·智能路由器