[Text Classification] Binary Classification with BERT

A minimal end-to-end example with Hugging Face transformers: a custom Dataset, a training loop, evaluation, model save/load, and batch inference.

python
import os
import torch
from torch.utils.data import DataLoader, Dataset
from torch.optim import AdamW  # transformers' own AdamW is deprecated/removed in recent versions
from transformers import BertTokenizer, BertForSequenceClassification
from sklearn.metrics import accuracy_score, classification_report
from tqdm import tqdm

# Custom Dataset: tokenizes each text on the fly and returns tensors for the DataLoader
class CustomDataset(Dataset):
    def __init__(self, texts, labels, tokenizer, max_length=128):
        self.texts = texts
        self.labels = labels
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, idx):
        text = self.texts[idx]
        label = self.labels[idx]
        encoding = self.tokenizer(
            text,
            max_length=self.max_length,
            padding="max_length",
            truncation=True,
            return_tensors="pt"
        )
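        # return_tensors="pt" gives each tensor a leading batch dimension of 1;
        # squeeze(0) below drops it so the DataLoader can do the batching itself.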
        return {
            "input_ids": encoding["input_ids"].squeeze(0),
            "attention_mask": encoding["attention_mask"].squeeze(0),
            "label": torch.tensor(label, dtype=torch.long)
        }
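
# Note: padding="max_length" pads every sample to max_length tokens. For speed, dynamic
# per-batch padding (e.g. transformers' DataCollatorWithPadding) is a common alternative.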


# Training loop: standard forward / backward / optimizer step per batch
def train_model(model, train_loader, optimizer, device, num_epochs=3):
    model.train()
    for epoch in range(num_epochs):
        total_loss = 0
        for batch in tqdm(train_loader, desc=f"Training Epoch {epoch + 1}/{num_epochs}"):
            input_ids = batch["input_ids"].to(device)
            attention_mask = batch["attention_mask"].to(device)
            labels = batch["label"].to(device)

            outputs = model(input_ids, attention_mask=attention_mask, labels=labels)
            loss = outputs.loss
            total_loss += loss.item()

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        print(f"Epoch {epoch + 1} Loss: {total_loss / len(train_loader)}")


# Evaluation: accuracy plus a per-class precision/recall/F1 report
def evaluate_model(model, val_loader, device):
    model.eval()
    predictions, true_labels = [], []
    with torch.no_grad():
        for batch in val_loader:
            input_ids = batch["input_ids"].to(device)
            attention_mask = batch["attention_mask"].to(device)
            labels = batch["label"].to(device)

            outputs = model(input_ids, attention_mask=attention_mask)
            logits = outputs.logits
            preds = torch.argmax(logits, dim=1).cpu().numpy()

            predictions.extend(preds)
            true_labels.extend(labels.cpu().numpy())

    accuracy = accuracy_score(true_labels, predictions)
    report = classification_report(true_labels, predictions)
    print(f"Validation Accuracy: {accuracy}")
    print("Classification Report:")
    print(report)
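
# Note: on tiny validation sets, classification_report may warn about classes with no
# predicted samples; its zero_division argument controls that behavior.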


# Save the model weights and tokenizer files in Hugging Face format
def save_model(model, tokenizer, output_dir):
    os.makedirs(output_dir, exist_ok=True)
    model.save_pretrained(output_dir)
    tokenizer.save_pretrained(output_dir)
    print(f"Model saved to {output_dir}")


# Reload a saved model and tokenizer from disk
def load_model(output_dir, device):
    tokenizer = BertTokenizer.from_pretrained(output_dir)
    model = BertForSequenceClassification.from_pretrained(output_dir)
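    # num_labels is restored from the saved config.json, so it need not be passed again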
    model.to(device)
    print(f"Model loaded from {output_dir}")
    return model, tokenizer


# Batch inference: returns predicted label ids and softmax probabilities
def predict(texts, model, tokenizer, device, max_length=128):
    model.eval()
    encodings = tokenizer(
        texts,
        max_length=max_length,
        padding="max_length",
        truncation=True,
        return_tensors="pt"
    )
    input_ids = encodings["input_ids"].to(device)
    attention_mask = encodings["attention_mask"].to(device)

    with torch.no_grad():
        outputs = model(input_ids, attention_mask=attention_mask)
        logits = outputs.logits
        probabilities = torch.softmax(logits, dim=1).cpu().numpy()
        predictions = torch.argmax(logits, dim=1).cpu().numpy()

    return predictions, probabilities
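
# To turn label ids into readable names, a plain dict works (e.g. {0: "negative",
# 1: "positive"}, a hypothetical mapping for this demo); the model config's id2label
# can also be set through from_pretrained for the same effect.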


# Main entry point
def main():
    # Configuration
    config = {
        "train_batch_size": 16,
        "val_batch_size": 16,
        "learning_rate": 5e-5,
        "num_epochs": 5,
        "max_length": 128,
        "device_id": 7,  # 指定 GPU ID
        "model_dir": "model",
        "local_model_path": "roberta_tiny_model",  # 指定本地模型路径,如果为 None 则使用预训练模型
        "pretrained_model_name": "uer/chinese_roberta_L-12_H-128",  # 预训练模型名称
    }

    # Select device
    device = torch.device(f"cuda:{config['device_id']}" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    # Load the tokenizer and model, falling back to the Hub model if no local path is given
    model_path = config["local_model_path"] or config["pretrained_model_name"]
    tokenizer = BertTokenizer.from_pretrained(model_path)
    model = BertForSequenceClassification.from_pretrained(model_path, num_labels=2)
    model.to(device)

    # Toy data for demonstration. Note that these samples are English while the
    # configured model is a Chinese RoBERTa; substitute matching data or model for
    # meaningful results.
    train_texts = ["This is a great product!", "I hate this service."]
    train_labels = [1, 0]
    val_texts = ["Awesome experience.", "Terrible product."]
    val_labels = [1, 0]

    # Build datasets and data loaders
    train_dataset = CustomDataset(train_texts, train_labels, tokenizer, config["max_length"])
    val_dataset = CustomDataset(val_texts, val_labels, tokenizer, config["max_length"])
    train_loader = DataLoader(train_dataset, batch_size=config["train_batch_size"], shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=config["val_batch_size"])

    # Optimizer
    optimizer = AdamW(model.parameters(), lr=config["learning_rate"])

    # Train
    train_model(model, train_loader, optimizer, device, num_epochs=config["num_epochs"])

    # Evaluate
    evaluate_model(model, val_loader, device)

    # Save
    save_model(model, tokenizer, config["model_dir"])

    # Reload the saved model (on CPU here) to verify the save/load round trip
    loaded_model, loaded_tokenizer = load_model(config["model_dir"], "cpu")

    # Inference on new texts
    new_texts = ["I love this!", "It's the worst."]
    predictions, probabilities = predict(new_texts, loaded_model, loaded_tokenizer, "cpu")
    for text, pred, prob in zip(new_texts, predictions, probabilities):
        print(f"Text: {text}")
        print(f"Predicted Label: {pred} (Probability: {prob})")


if __name__ == "__main__":
    main()
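
The demo above trains on two hardcoded sentences, which is only enough to verify that the pipeline runs end to end. In practice the texts and labels would come from a real dataset. A minimal loading sketch, assuming a UTF-8 CSV with header columns text and label (file names and column names here are hypothetical):

python
import csv

def load_csv_dataset(path):
    # Read (text, label) pairs from a CSV with "text" and "label" columns.
    texts, labels = [], []
    with open(path, encoding="utf-8") as f:
        for row in csv.DictReader(f):
            texts.append(row["text"])
            labels.append(int(row["label"]))
    return texts, labels

# Usage (hypothetical paths):
# train_texts, train_labels = load_csv_dataset("data/train.csv")
# val_texts, val_labels = load_csv_dataset("data/val.csv")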