于BERT的中文问答系统11

该代码实现了一个基于BERT的聊天机器人,具有以下主要功能:

数据加载:

从JSONL或JSON文件中加载训练数据。

数据集包含问题、人类回答和ChatGPT回答。
模型定义:

使用预训练的BERT模型作为基础模型。

添加了一个分类器层,用于区分人类回答和ChatGPT回答。
训练:

定义了训练函数,使用BCEWithLogitsLoss作为损失函数。

训练模型以区分人类回答和ChatGPT回答。
GUI界面:

提供了一个简单的图形用户界面,用户可以输入问题并获取回答。

提供了训练模型和重新训练模型的功能。

python 复制代码
import os
import json
import jsonlines
import torch
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from transformers import BertModel, BertTokenizer
import tkinter as tk
from tkinter import filedialog
import logging

# 配置日志
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')

# 获取项目根目录
PROJECT_ROOT = os.path.dirname(os.path.abspath(__file__))

# 数据集类
class XihuaDataset(Dataset):
    def __init__(self, file_path, tokenizer, max_length=128):
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.data = self.load_data(file_path)

    def load_data(self, file_path):
        data = []
        if file_path.endswith('.jsonl'):
            with jsonlines.open(file_path) as reader:
                for item in reader:
                    data.append(item)
        elif file_path.endswith('.json'):
            with open(file_path, 'r') as f:
                data = json.load(f)
        return data

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        item = self.data[idx]
        question = item['question']
        human_answer = item['human_answers'][0]
        chatgpt_answer = item['chatgpt_answers'][0]

        inputs = self.tokenizer(question, return_tensors='pt', padding='max_length', truncation=True, max_length=self.max_length)
        human_inputs = self.tokenizer(human_answer, return_tensors='pt', padding='max_length', truncation=True, max_length=self.max_length)
        chatgpt_inputs = self.tokenizer(chatgpt_answer, return_tensors='pt', padding='max_length', truncation=True, max_length=self.max_length)

        return {
            'input_ids': inputs['input_ids'].squeeze(),
            'attention_mask': inputs['attention_mask'].squeeze(),
            'human_input_ids': human_inputs['input_ids'].squeeze(),
            'human_attention_mask': human_inputs['attention_mask'].squeeze(),
            'chatgpt_input_ids': chatgpt_inputs['input_ids'].squeeze(),
            'chatgpt_attention_mask': chatgpt_inputs['attention_mask'].squeeze(),
            'human_answer': human_answer,
            'chatgpt_answer': chatgpt_answer
        }

# 获取数据加载器
def get_data_loader(file_path, tokenizer, batch_size=8, max_length=128):
    dataset = XihuaDataset(file_path, tokenizer, max_length)
    return DataLoader(dataset, batch_size=batch_size, shuffle=True)

# 模型定义
class XihuaModel(torch.nn.Module):
    def __init__(self, pretrained_model_name='F:/models/bert-base-chinese'):
        super(XihuaModel, self).__init__()
        self.bert = BertModel.from_pretrained(pretrained_model_name)
        self.classifier = torch.nn.Linear(self.bert.config.hidden_size, 1)

    def forward(self, input_ids, attention_mask):
        outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        pooled_output = outputs.pooler_output
        logits = self.classifier(pooled_output)
        return logits

# 训练函数
def train(model, data_loader, optimizer, criterion, device):
    model.train()
    total_loss = 0.0
    for batch in data_loader:
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        human_input_ids = batch['human_input_ids'].to(device)
        human_attention_mask = batch['human_attention_mask'].to(device)
        chatgpt_input_ids = batch['chatgpt_input_ids'].to(device)
        chatgpt_attention_mask = batch['chatgpt_attention_mask'].to(device)

        optimizer.zero_grad()
        human_logits = model(human_input_ids, human_attention_mask)
        chatgpt_logits = model(chatgpt_input_ids, chatgpt_attention_mask)

        human_labels = torch.ones(human_logits.size(0), 1).to(device)
        chatgpt_labels = torch.zeros(chatgpt_logits.size(0), 1).to(device)

        loss = criterion(human_logits, human_labels) + criterion(chatgpt_logits, chatgpt_labels)
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

    return total_loss / len(data_loader)

# 主训练函数
def main_train():
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    logging.info(f'Using device: {device}')

    tokenizer = BertTokenizer.from_pretrained('F:/models/bert-base-chinese')
    model = XihuaModel(pretrained_model_name='F:/models/bert-base-chinese').to(device)
    optimizer = optim.Adam(model.parameters(), lr=1e-5)
    criterion = torch.nn.BCEWithLogitsLoss()

    train_data_loader = get_data_loader(os.path.join(PROJECT_ROOT, 'data/train_data.jsonl'), tokenizer, batch_size=8, max_length=128)

    num_epochs = 5
    for epoch in range(num_epochs):
        train_loss = train(model, train_data_loader, optimizer, criterion, device)
        logging.info(f'Epoch [{epoch+1}/{num_epochs}], Loss: {train_loss:.4f}')

    torch.save(model.state_dict(), os.path.join(PROJECT_ROOT, 'models/xihua_model.pth'))
    logging.info("模型训练完成并保存")

# GUI界面
class XihuaChatbotGUI:
    def __init__(self, root):
        self.root = root
        self.root.title("羲和聊天机器人")

        self.tokenizer = BertTokenizer.from_pretrained('F:/models/bert-base-chinese')
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.model = XihuaModel(pretrained_model_name='F:/models/bert-base-chinese').to(self.device)
        self.model.load_state_dict(torch.load(os.path.join(PROJECT_ROOT, 'models/xihua_model.pth'), map_location=self.device, weights_only=True))
        self.model.eval()

        # 加载训练数据集以便在获取答案时使用
        self.data = self.load_data(os.path.join(PROJECT_ROOT, 'data/train_data.jsonl'))

        self.create_widgets()

    def create_widgets(self):
        self.question_label = tk.Label(self.root, text="问题:")
        self.question_label.pack()

        self.question_entry = tk.Entry(self.root, width=50)
        self.question_entry.pack()

        self.answer_button = tk.Button(self.root, text="获取回答", command=self.get_answer)
        self.answer_button.pack()

        self.answer_label = tk.Label(self.root, text="回答:")
        self.answer_label.pack()

        self.answer_text = tk.Text(self.root, height=10, width=50)
        self.answer_text.pack()

        self.train_button = tk.Button(self.root, text="训练模型", command=self.train_model)
        self.train_button.pack()

    def get_answer(self):
        question = self.question_entry.get()
        inputs = self.tokenizer(question, return_tensors='pt', padding='max_length', truncation=True, max_length=128)
        with torch.no_grad():
            input_ids = inputs['input_ids'].to(self.device)
            attention_mask = inputs['attention_mask'].to(self.device)
            logits = self.model(input_ids, attention_mask)
        
        if logits.item() > 0:
            answer_type = "人类回答"
        else:
            answer_type = "ChatGPT回答"

        specific_answer = self.get_specific_answer(question, answer_type)

        self.answer_text.delete(1.0, tk.END)
        self.answer_text.insert(tk.END, f"{answer_type}\n{specific_answer}")

    def get_specific_answer(self, question, answer_type):
        # 从数据集中查找具体的答案
        for item in self.data:
            if item['question'] == question:
                if answer_type == "人类回答":
                    return item['human_answers'][0]
                else:
                    return item['chatgpt_answers'][0]
        return "未找到具体答案"

    def load_data(self, file_path):
        data = []
        if file_path.endswith('.jsonl'):
            with jsonlines.open(file_path) as reader:
                for item in reader:
                    data.append(item)
        elif file_path.endswith('.json'):
            with open(file_path, 'r') as f:
                data = json.load(f)
        return data

    def train_model(self):
        file_path = filedialog.askopenfilename(filetypes=[("JSONL files", "*.jsonl"), ("JSON files", "*.json")])
        if file_path:
            dataset = XihuaDataset(file_path, self.tokenizer)
            data_loader = DataLoader(dataset, batch_size=8, shuffle=True)
            
            # 加载已训练的模型权重
            self.model.load_state_dict(torch.load(os.path.join(PROJECT_ROOT, 'models/xihua_model.pth'), map_location=self.device, weights_only=True))
            self.model.to(self.device)
            self.model.train()

            optimizer = torch.optim.Adam(self.model.parameters(), lr=1e-5)
            criterion = torch.nn.BCEWithLogitsLoss()
            num_epochs = 5
            for epoch in range(num_epochs):
                train_loss = train(self.model, data_loader, optimizer, criterion, self.device)
                logging.info(f'Epoch [{epoch+1}/{num_epochs}], Loss: {train_loss:.4f}')
            torch.save(self.model.state_dict(), os.path.join(PROJECT_ROOT, 'models/xihua_model.pth'))
            logging.info("模型训练完成并保存")

# 主函数
if __name__ == "__main__":
    # 训练模型
    main_train()

    # 启动GUI
    root = tk.Tk()
    app = XihuaChatbotGUI(root)
    root.mainloop()

注意事项

数据格式:

确保训练数据文件(JSONL或JSON)的格式正确,每条记录应包含question、human_answers和chatgpt_answers字段。
模型路径:

确保预训练的BERT模型路径(F:/models/bert-base-chinese)和模型保存路径(F:/models/xihua_model.pth)正确无误。
设备选择:

代码会自动选择CUDA设备(如果有),否则使用CPU。确保系统中有可用的CUDA设备或足够的CPU资源。
日志记录:

日志记录配置为INFO级别,记录训练过程中的信息。确保日志文件路径可写。
错误处理:

目前代码中没有详细的错误处理机制。

相关推荐
可涵不会debug1 分钟前
《“慧眼识障“:基于Rokid AI眼镜的智能维修记录自动归档系统开发实战》
人工智能
xieyan08114 分钟前
什么情况下使用强化学习
人工智能
腾飞开源4 分钟前
04_Spring AI 干货笔记之对话客户端 API
人工智能·元数据·检索增强生成·spring ai·chatclient·对话记忆·流式api
执笔论英雄5 分钟前
【RL】Slime异步原理(单例设计模式)6
人工智能·设计模式
da_vinci_x6 分钟前
PS 结构参考 + Firefly:零建模量产 2.5D 等轴游戏资产
人工智能·游戏·设计模式·prompt·aigc·技术美术·游戏美术
java1234_小锋6 分钟前
基于Python深度学习的车辆车牌识别系统(PyTorch2卷积神经网络CNN+OpenCV4实现)视频教程 - 切割车牌矩阵获取车牌字符
python·深度学习·cnn·车牌识别
是小崔啊12 分钟前
【SAA】01 - Spring Ai Alibaba快速入门
java·人工智能·spring
semantist@语校13 分钟前
第五十一篇|构建日本语言学校数据模型:埼玉国际学院的城市结构与行为变量分析
java·大数据·数据库·人工智能·百度·ai·github
想要成为计算机高手13 分钟前
π*0.6: 从实践中学习 -- 2025.11.17 -- Physical Intelligence (π) -- 未开源
人工智能·学习·机器人·多模态·具身智能·vla
陈文锦丫16 分钟前
Boundary Difference Over Union Loss For Medical Image Segmentation
深度学习