于BERT的中文问答系统11

该代码实现了一个基于BERT的聊天机器人,具有以下主要功能:

数据加载:

从JSONL或JSON文件中加载训练数据。

数据集包含问题、人类回答和ChatGPT回答。
模型定义:

使用预训练的BERT模型作为基础模型。

添加了一个分类器层,用于区分人类回答和ChatGPT回答。
训练:

定义了训练函数,使用BCEWithLogitsLoss作为损失函数。

训练模型以区分人类回答和ChatGPT回答。
GUI界面:

提供了一个简单的图形用户界面,用户可以输入问题并获取回答。

提供了训练模型和重新训练模型的功能。

python 复制代码
import os
import json
import jsonlines
import torch
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from transformers import BertModel, BertTokenizer
import tkinter as tk
from tkinter import filedialog
import logging

# 配置日志
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')

# 获取项目根目录
PROJECT_ROOT = os.path.dirname(os.path.abspath(__file__))

# 数据集类
class XihuaDataset(Dataset):
    def __init__(self, file_path, tokenizer, max_length=128):
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.data = self.load_data(file_path)

    def load_data(self, file_path):
        data = []
        if file_path.endswith('.jsonl'):
            with jsonlines.open(file_path) as reader:
                for item in reader:
                    data.append(item)
        elif file_path.endswith('.json'):
            with open(file_path, 'r') as f:
                data = json.load(f)
        return data

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        item = self.data[idx]
        question = item['question']
        human_answer = item['human_answers'][0]
        chatgpt_answer = item['chatgpt_answers'][0]

        inputs = self.tokenizer(question, return_tensors='pt', padding='max_length', truncation=True, max_length=self.max_length)
        human_inputs = self.tokenizer(human_answer, return_tensors='pt', padding='max_length', truncation=True, max_length=self.max_length)
        chatgpt_inputs = self.tokenizer(chatgpt_answer, return_tensors='pt', padding='max_length', truncation=True, max_length=self.max_length)

        return {
            'input_ids': inputs['input_ids'].squeeze(),
            'attention_mask': inputs['attention_mask'].squeeze(),
            'human_input_ids': human_inputs['input_ids'].squeeze(),
            'human_attention_mask': human_inputs['attention_mask'].squeeze(),
            'chatgpt_input_ids': chatgpt_inputs['input_ids'].squeeze(),
            'chatgpt_attention_mask': chatgpt_inputs['attention_mask'].squeeze(),
            'human_answer': human_answer,
            'chatgpt_answer': chatgpt_answer
        }

# 获取数据加载器
def get_data_loader(file_path, tokenizer, batch_size=8, max_length=128):
    dataset = XihuaDataset(file_path, tokenizer, max_length)
    return DataLoader(dataset, batch_size=batch_size, shuffle=True)

# 模型定义
class XihuaModel(torch.nn.Module):
    def __init__(self, pretrained_model_name='F:/models/bert-base-chinese'):
        super(XihuaModel, self).__init__()
        self.bert = BertModel.from_pretrained(pretrained_model_name)
        self.classifier = torch.nn.Linear(self.bert.config.hidden_size, 1)

    def forward(self, input_ids, attention_mask):
        outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        pooled_output = outputs.pooler_output
        logits = self.classifier(pooled_output)
        return logits

# 训练函数
def train(model, data_loader, optimizer, criterion, device):
    model.train()
    total_loss = 0.0
    for batch in data_loader:
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        human_input_ids = batch['human_input_ids'].to(device)
        human_attention_mask = batch['human_attention_mask'].to(device)
        chatgpt_input_ids = batch['chatgpt_input_ids'].to(device)
        chatgpt_attention_mask = batch['chatgpt_attention_mask'].to(device)

        optimizer.zero_grad()
        human_logits = model(human_input_ids, human_attention_mask)
        chatgpt_logits = model(chatgpt_input_ids, chatgpt_attention_mask)

        human_labels = torch.ones(human_logits.size(0), 1).to(device)
        chatgpt_labels = torch.zeros(chatgpt_logits.size(0), 1).to(device)

        loss = criterion(human_logits, human_labels) + criterion(chatgpt_logits, chatgpt_labels)
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

    return total_loss / len(data_loader)

# 主训练函数
def main_train():
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    logging.info(f'Using device: {device}')

    tokenizer = BertTokenizer.from_pretrained('F:/models/bert-base-chinese')
    model = XihuaModel(pretrained_model_name='F:/models/bert-base-chinese').to(device)
    optimizer = optim.Adam(model.parameters(), lr=1e-5)
    criterion = torch.nn.BCEWithLogitsLoss()

    train_data_loader = get_data_loader(os.path.join(PROJECT_ROOT, 'data/train_data.jsonl'), tokenizer, batch_size=8, max_length=128)

    num_epochs = 5
    for epoch in range(num_epochs):
        train_loss = train(model, train_data_loader, optimizer, criterion, device)
        logging.info(f'Epoch [{epoch+1}/{num_epochs}], Loss: {train_loss:.4f}')

    torch.save(model.state_dict(), os.path.join(PROJECT_ROOT, 'models/xihua_model.pth'))
    logging.info("模型训练完成并保存")

# GUI界面
class XihuaChatbotGUI:
    def __init__(self, root):
        self.root = root
        self.root.title("羲和聊天机器人")

        self.tokenizer = BertTokenizer.from_pretrained('F:/models/bert-base-chinese')
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.model = XihuaModel(pretrained_model_name='F:/models/bert-base-chinese').to(self.device)
        self.model.load_state_dict(torch.load(os.path.join(PROJECT_ROOT, 'models/xihua_model.pth'), map_location=self.device, weights_only=True))
        self.model.eval()

        # 加载训练数据集以便在获取答案时使用
        self.data = self.load_data(os.path.join(PROJECT_ROOT, 'data/train_data.jsonl'))

        self.create_widgets()

    def create_widgets(self):
        self.question_label = tk.Label(self.root, text="问题:")
        self.question_label.pack()

        self.question_entry = tk.Entry(self.root, width=50)
        self.question_entry.pack()

        self.answer_button = tk.Button(self.root, text="获取回答", command=self.get_answer)
        self.answer_button.pack()

        self.answer_label = tk.Label(self.root, text="回答:")
        self.answer_label.pack()

        self.answer_text = tk.Text(self.root, height=10, width=50)
        self.answer_text.pack()

        self.train_button = tk.Button(self.root, text="训练模型", command=self.train_model)
        self.train_button.pack()

    def get_answer(self):
        question = self.question_entry.get()
        inputs = self.tokenizer(question, return_tensors='pt', padding='max_length', truncation=True, max_length=128)
        with torch.no_grad():
            input_ids = inputs['input_ids'].to(self.device)
            attention_mask = inputs['attention_mask'].to(self.device)
            logits = self.model(input_ids, attention_mask)
        
        if logits.item() > 0:
            answer_type = "人类回答"
        else:
            answer_type = "ChatGPT回答"

        specific_answer = self.get_specific_answer(question, answer_type)

        self.answer_text.delete(1.0, tk.END)
        self.answer_text.insert(tk.END, f"{answer_type}\n{specific_answer}")

    def get_specific_answer(self, question, answer_type):
        # 从数据集中查找具体的答案
        for item in self.data:
            if item['question'] == question:
                if answer_type == "人类回答":
                    return item['human_answers'][0]
                else:
                    return item['chatgpt_answers'][0]
        return "未找到具体答案"

    def load_data(self, file_path):
        data = []
        if file_path.endswith('.jsonl'):
            with jsonlines.open(file_path) as reader:
                for item in reader:
                    data.append(item)
        elif file_path.endswith('.json'):
            with open(file_path, 'r') as f:
                data = json.load(f)
        return data

    def train_model(self):
        file_path = filedialog.askopenfilename(filetypes=[("JSONL files", "*.jsonl"), ("JSON files", "*.json")])
        if file_path:
            dataset = XihuaDataset(file_path, self.tokenizer)
            data_loader = DataLoader(dataset, batch_size=8, shuffle=True)
            
            # 加载已训练的模型权重
            self.model.load_state_dict(torch.load(os.path.join(PROJECT_ROOT, 'models/xihua_model.pth'), map_location=self.device, weights_only=True))
            self.model.to(self.device)
            self.model.train()

            optimizer = torch.optim.Adam(self.model.parameters(), lr=1e-5)
            criterion = torch.nn.BCEWithLogitsLoss()
            num_epochs = 5
            for epoch in range(num_epochs):
                train_loss = train(self.model, data_loader, optimizer, criterion, self.device)
                logging.info(f'Epoch [{epoch+1}/{num_epochs}], Loss: {train_loss:.4f}')
            torch.save(self.model.state_dict(), os.path.join(PROJECT_ROOT, 'models/xihua_model.pth'))
            logging.info("模型训练完成并保存")

# 主函数
if __name__ == "__main__":
    # 训练模型
    main_train()

    # 启动GUI
    root = tk.Tk()
    app = XihuaChatbotGUI(root)
    root.mainloop()

注意事项

数据格式:

确保训练数据文件(JSONL或JSON)的格式正确,每条记录应包含question、human_answers和chatgpt_answers字段。
模型路径:

确保预训练的BERT模型路径(F:/models/bert-base-chinese)和模型保存路径(F:/models/xihua_model.pth)正确无误。
设备选择:

代码会自动选择CUDA设备(如果有),否则使用CPU。确保系统中有可用的CUDA设备或足够的CPU资源。
日志记录:

日志记录配置为INFO级别,记录训练过程中的信息。确保日志文件路径可写。
错误处理:

目前代码中没有详细的错误处理机制。

相关推荐
Keano Reurink6 分钟前
SEO数据管道:用Airflow搭建自动化工作流
运维·人工智能·爬虫·搜索引擎·自动化·ai编程·seo
生成论实验室6 分钟前
用事件关系网络重新理解AI(二):损失函数、优化器与深度学习的动力学
数据结构·人工智能·深度学习·算法·语言模型
韦胖漫谈IT10 分钟前
提示词注入- 大语言模型 OWASP TOP 10系列
网络·人工智能·语言模型·大模型安全·owasp
HIT_Weston17 分钟前
93、【Agent】【OpenCode】edit 工具提示词(二)
人工智能·agent·opencode
xingyuzhisuan20 分钟前
2026年GPU租用平台JupyterHub多用户环境配置
服务器·人工智能·jupyter·gpu算力
生成论实验室24 分钟前
事件、信息荷与六维态势空间——每一个事件都是一次空间的弯曲
人工智能·算法·语言模型·可信计算技术·安全架构
徐安安ye28 分钟前
FlashAttention长程依赖建模:局部+全局的Hybrid Spiral结构设计
python·深度学习·机器学习
Zevalin爱灰灰36 分钟前
智能控制 第五章——神经网络控制论
人工智能·神经网络
韦胖漫谈IT37 分钟前
供应链 - 大语言模型 OWASP TOP 10系列
人工智能·语言模型·自然语言处理
KaMeidebaby41 分钟前
卡梅德生物技术快报|真核蛋白表达信号肽筛选实验全流程复盘
服务器·前端·数据库·人工智能·算法