让我们一起来训练一个自动检测违规言论的大模型吧!

前言

请将你找到的资料传到我的github上!拜托了!

在这个互联网时代,违规言论似乎已经成为了我们日常上网的一部分。作为一名开发者,尤其是一名经济有限的开发者,我并不希望在自己的产品中看到这些内容。于是,我开始寻找解决方案。虽然市场上有人工审核或接入第三方服务,但考虑到成本问题,我决定亲自尝试训练一个AI模型,来实现自动检测并屏蔽违规言论。

于是,我选择了一个较小的预训练模型开始训练。然而,随着项目的推进,我遇到了一个问题:由于数据集的规模过小,模型的效果远未达到预期,准确率仅为71%。这一结果让我意识到,只有拥有更多高质量的训练数据,才能提升模型的性能。

在此,我呼吁大家贡献一份力量,帮助我寻找更多的训练数据集。通过共同努力,我们可以打造一个免费且高效的自动违规言论检测工具,让更多人受益。

AI模型训练

1、选择预训练模型

作为个人开发者,我目前仅有一台配备 24GB 显存的 4090 显卡可供使用。在资源有限的情况下,选择合适的预训练模型至关重要。经过权衡,我决定使用 distilbert-base-multilingual-cased 作为基础模型。这个模型不仅支持中文和英文,还具有较小的体积,非常适合我的需求,尤其是在为未来实现实时检测功能做好准备时,能更有效地节省计算资源。

2、数据处理

对数据格式进行了调整,并将数据集划分为训练集和验证集。

ini 复制代码
import pandas as pd
from sklearn.model_selection import train_test_split

# 处理数据格式
comments = pd.read_csv('data/hateComment.csv')
comments = comments.dropna(subset=['comment_text', 'label'])
comments['label'] = comments['label'].astype(int)

train_df, val_df = train_test_split(comments, test_size=0.2, random_state=42, stratify=comments['label'])
train_df.to_csv('data/train.csv', index=False)
val_df.to_csv('data/val.csv', index=False)

3、编写训练代码

ini 复制代码
import pandas as pd
from datasets import Dataset
from transformers import DistilBertTokenizerFast, DistilBertForSequenceClassification, Trainer, TrainingArguments
from sklearn.metrics import accuracy_score, precision_recall_fscore_support
import torch

def load_data(train_path, val_path):
    train_df = pd.read_csv(train_path)
    val_df = pd.read_csv(val_path)
    return Dataset.from_pandas(train_df), Dataset.from_pandas(val_df)

def preprocess_data(tokenizer, examples):
    return tokenizer(examples['comment_text'], padding='max_length', truncation=True, max_length=128)

def compute_metrics(pred):
    labels = pred.label_ids
    preds = pred.predictions.argmax(-1)
    acc = accuracy_score(labels, preds)
    precision, recall, f1, _ = precision_recall_fscore_support(labels, preds, average='binary')
    return {'accuracy': acc, 'precision': precision, 'recall': recall, 'f1': f1}

def main():
    # 路径
    train_path = '/home/mmc/delock/data/train.csv'
    val_path = '/home/mmc/delock/data/val.csv'
    model_name = '/home/mmc/delock/model/distilbert-base-multilingual-cased'
    output_dir = '/home/mmc/delock/model/distilbert-swsr'

    # 加载数据
    train_dataset, val_dataset = load_data(train_path, val_path)

    # 加载分词器
    tokenizer = DistilBertTokenizerFast.from_pretrained(model_name)

    # 分词
    train_dataset = train_dataset.map(lambda x: preprocess_data(tokenizer, x), batched=True)
    val_dataset = val_dataset.map(lambda x: preprocess_data(tokenizer, x), batched=True)

    # 设置格式为 PyTorch
    train_dataset.set_format(type='torch', columns=['input_ids', 'attention_mask', 'label'])
    val_dataset.set_format(type='torch', columns=['input_ids', 'attention_mask', 'label'])

    # 加载模型
    model = DistilBertForSequenceClassification.from_pretrained(model_name, num_labels=2)

    # 训练参数
    training_args = TrainingArguments(
        output_dir=output_dir,
        num_train_epochs=15,
        per_device_train_batch_size=32,
        per_device_eval_batch_size=64,
        evaluation_strategy="epoch",
        save_strategy="epoch",
        logging_dir='../logs',
        logging_steps=10,
        load_best_model_at_end=True,
        metric_for_best_model='f1',
        greater_is_better=True
    )

    # 初始化 Trainer
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=val_dataset,
        compute_metrics=compute_metrics,
    )

    # 训练
    trainer.train()

    # 评估
    eval_result = trainer.evaluate()
    print(f"评估结果: {eval_result}")

    # 保存最终模型
    trainer.save_model(output_dir)

if __name__ == "__main__":
    main()

4、测试模型效果

ini 复制代码
from transformers import DistilBertTokenizerFast, DistilBertForSequenceClassification
import torch

# 加载微调后的模型
tokenizer_path = '/home/mmc/delock/model/distilbert-base-multilingual-cased'
model_path = '/home/mmc/delock/model/distilbert-swsr'
tokenizer = DistilBertTokenizerFast.from_pretrained(tokenizer_path)
model = DistilBertForSequenceClassification.from_pretrained(model_path)

def predict(text):
    inputs = tokenizer(text, return_tensors='pt', padding=True, truncation=True, max_length=128)
    with torch.no_grad():
        outputs = model(**inputs)
    logits = outputs.logits
    prediction = torch.argmax(logits, dim=1).item()
    return prediction  # 1 表示性别歧视,0 表示非性别歧视

# 示例
comment = "你的智商真低!"
label = predict(comment)
print(f"标签: {label}")

附录

distilbert-base-multilingual-cased模型地址 huggingface.co/distilbert/...

违规言论数据仓库github地址 github.com/delcok/Ille...

相关推荐
张较瘦_1 分钟前
[论文阅读] 人工智能 | 5C提示词框架的研究
论文阅读·人工智能
超龄超能程序猿15 分钟前
使用 Python 对本地图片进行图像分类
开发语言·人工智能·python·机器学习·分类·数据挖掘·scipy
大千AI助手18 分钟前
RLHF:人类反馈强化学习 | 对齐AI与人类价值观的核心引擎
人工智能·深度学习·算法·机器学习·强化学习·rlhf·人类反馈强化学习
我爱一条柴ya29 分钟前
【AI大模型】RAG系统组件:向量数据库(ChromaDB)
数据库·人工智能·pytorch·python·ai·ai编程
MARS_AI_34 分钟前
云蝠智能VoiceAgent重构企业电话客服体系
人工智能·自然语言处理·人机交互·交互·信息与通信
在猴站学算法4 小时前
机器学习(西瓜书) 第二章 模型评估与选择
人工智能·机器学习
科技宅说5 小时前
36氪专访丨乐橙CEO谢运:AI科技下的业务创新与长期主义下的品牌坚守
人工智能·科技
学术小八6 小时前
2025年人工智能、虚拟现实与交互设计国际学术会议
人工智能·交互·vr
仗剑_走天涯7 小时前
基于pytorch.nn模块实现线性模型
人工智能·pytorch·python·深度学习
cnbestec8 小时前
协作机器人UR7e与UR12e:轻量化设计与高负载能力助力“小而美”智造升级
人工智能·机器人·协作机器人·ur协作机器人·ur7e·ur12e