BERT-Based Chinese Question-Answering System (28)

To make the GUI more user-friendly, we improve it in the following areas:

UI polish: use a more modern style and layout, with icons and colors.

User feedback: add more prompts and feedback so the user always knows the current operation state.

Feature enhancements: add practical features such as history search and log export (a sketch of these appears after the summary below).

User experience: streamline the interaction flow so operations feel smoother and more intuitive.

Here is the improved code:

python
import os
import json
import jsonlines
import torch
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from transformers import BertModel, BertTokenizer
import tkinter as tk
from tkinter import filedialog, messagebox, ttk
import logging
from difflib import SequenceMatcher
from datetime import datetime

# Project root directory
PROJECT_ROOT = os.path.dirname(os.path.abspath(__file__))

# Logging configuration
LOGS_DIR = os.path.join(PROJECT_ROOT, 'logs')
os.makedirs(LOGS_DIR, exist_ok=True)

# Make sure the model output directory exists before torch.save is called later
MODELS_DIR = os.path.join(PROJECT_ROOT, 'models')
os.makedirs(MODELS_DIR, exist_ok=True)

def setup_logging():
    log_file = os.path.join(LOGS_DIR, datetime.now().strftime('%Y-%m-%d_%H-%M-%S_羲和.txt'))
    logging.basicConfig(
        level=logging.INFO,
        format='%(asctime)s - %(levelname)s - %(message)s',
        handlers=[
            logging.FileHandler(log_file, encoding='utf-8'),
            logging.StreamHandler()
        ]
    )

setup_logging()

# Dataset class
class XihuaDataset(Dataset):
    def __init__(self, file_path, tokenizer, max_length=128):
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.data = self.load_data(file_path)

    def load_data(self, file_path):
        data = []
        if file_path.endswith('.jsonl'):
            # skip_invalid=True tells jsonlines to skip malformed lines instead of raising
            with jsonlines.open(file_path) as reader:
                for item in reader.iter(skip_invalid=True):
                    data.append(item)
        elif file_path.endswith('.json'):
            with open(file_path, 'r', encoding='utf-8') as f:
                try:
                    data = json.load(f)
                except json.JSONDecodeError as e:
                    logging.warning(f"跳过无效文件 {file_path}: {e}")
        return data

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        item = self.data[idx]
        try:
            question = item['question']
            human_answer = item['human_answers'][0]
            chatgpt_answer = item['chatgpt_answers'][0]
            inputs = self.tokenizer(question, return_tensors='pt', padding='max_length', truncation=True, max_length=self.max_length)
            human_inputs = self.tokenizer(human_answer, return_tensors='pt', padding='max_length', truncation=True, max_length=self.max_length)
            chatgpt_inputs = self.tokenizer(chatgpt_answer, return_tensors='pt', padding='max_length', truncation=True, max_length=self.max_length)
        except Exception as e:
            # Missing fields, empty answer lists, or tokenization errors: fall back to the next sample
            logging.warning(f"跳过无效项 {idx}: {e}")
            return self.__getitem__((idx + 1) % len(self.data))

        return {
            'input_ids': inputs['input_ids'].squeeze(),
            'attention_mask': inputs['attention_mask'].squeeze(),
            'human_input_ids': human_inputs['input_ids'].squeeze(),
            'human_attention_mask': human_inputs['attention_mask'].squeeze(),
            'chatgpt_input_ids': chatgpt_inputs['input_ids'].squeeze(),
            'chatgpt_attention_mask': chatgpt_inputs['attention_mask'].squeeze(),
            'human_answer': human_answer,
            'chatgpt_answer': chatgpt_answer
        }

# Build a DataLoader for a data file
def get_data_loader(file_path, tokenizer, batch_size=8, max_length=128):
    dataset = XihuaDataset(file_path, tokenizer, max_length)
    return DataLoader(dataset, batch_size=batch_size, shuffle=True)

# Model definition
class XihuaModel(torch.nn.Module):
    def __init__(self, pretrained_model_name='F:/models/bert-base-chinese'):
        super(XihuaModel, self).__init__()
        self.bert = BertModel.from_pretrained(pretrained_model_name)
        self.classifier = torch.nn.Linear(self.bert.config.hidden_size, 1)

    def forward(self, input_ids, attention_mask):
        outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        pooled_output = outputs.pooler_output
        logits = self.classifier(pooled_output)
        return logits

# Training loop for one epoch
def train(model, data_loader, optimizer, criterion, device, progress_var=None):
    model.train()
    total_loss = 0.0
    num_batches = len(data_loader)
    for batch_idx, batch in enumerate(data_loader):
        try:
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            human_input_ids = batch['human_input_ids'].to(device)
            human_attention_mask = batch['human_attention_mask'].to(device)
            chatgpt_input_ids = batch['chatgpt_input_ids'].to(device)
            chatgpt_attention_mask = batch['chatgpt_attention_mask'].to(device)

            optimizer.zero_grad()
            human_logits = model(human_input_ids, human_attention_mask)
            chatgpt_logits = model(chatgpt_input_ids, chatgpt_attention_mask)

            human_labels = torch.ones(human_logits.size(0), 1).to(device)
            chatgpt_labels = torch.zeros(chatgpt_logits.size(0), 1).to(device)

            loss = criterion(human_logits, human_labels) + criterion(chatgpt_logits, chatgpt_labels)
            loss.backward()
            optimizer.step()

            total_loss += loss.item()
            if progress_var:
                progress_var.set((batch_idx + 1) / num_batches * 100)
        except Exception as e:
            logging.warning(f"跳过无效批次: {e}")

    return total_loss / len(data_loader)

# Main training entry point
def main_train(retrain=False):
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    logging.info(f'Using device: {device}')

    tokenizer = BertTokenizer.from_pretrained('F:/models/bert-base-chinese')
    model = XihuaModel(pretrained_model_name='F:/models/bert-base-chinese').to(device)

    if retrain:
        model_path = os.path.join(PROJECT_ROOT, 'models/xihua_model.pth')
        if os.path.exists(model_path):
            model.load_state_dict(torch.load(model_path, map_location=device))
            logging.info("加载现有模型")
        else:
            logging.info("没有找到现有模型,将使用预训练模型")

    optimizer = optim.Adam(model.parameters(), lr=1e-5)
    criterion = torch.nn.BCEWithLogitsLoss()

    train_data_loader = get_data_loader(os.path.join(PROJECT_ROOT, 'data/train_data.jsonl'), tokenizer, batch_size=8, max_length=128)

    num_epochs = 30
    for epoch in range(num_epochs):
        train_loss = train(model, train_data_loader, optimizer, criterion, device)
        logging.info(f'Epoch [{epoch+1}/{num_epochs}], Loss: {train_loss:.8f}')

    torch.save(model.state_dict(), os.path.join(PROJECT_ROOT, 'models/xihua_model.pth'))
    logging.info("模型训练完成并保存")

# GUI
class XihuaChatbotGUI:
    def __init__(self, root):
        self.root = root
        self.root.title("羲和聊天机器人")

        self.tokenizer = BertTokenizer.from_pretrained('F:/models/bert-base-chinese')
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.model = XihuaModel(pretrained_model_name='F:/models/bert-base-chinese').to(self.device)
        self.load_model()
        self.model.eval()

        # Load the training data so it can be reused when answering questions
        self.data = self.load_data(os.path.join(PROJECT_ROOT, 'data/train_data.jsonl'))

        # Conversation history
        self.history = []

        self.create_widgets()

    def create_widgets(self):
        # Top frame: question input and answer button
        top_frame = tk.Frame(self.root)
        top_frame.pack(pady=10)

        self.question_label = tk.Label(top_frame, text="问题:", font=("Arial", 12))
        self.question_label.grid(row=0, column=0, padx=10)

        self.question_entry = tk.Entry(top_frame, width=50, font=("Arial", 12))
        self.question_entry.grid(row=0, column=1, padx=10)

        self.answer_button = tk.Button(top_frame, text="获取回答", command=self.get_answer, font=("Arial", 12))
        self.answer_button.grid(row=0, column=2, padx=10)

        # Middle frame: answer display
        middle_frame = tk.Frame(self.root)
        middle_frame.pack(pady=10)

        self.answer_label = tk.Label(middle_frame, text="回答:", font=("Arial", 12))
        self.answer_label.grid(row=0, column=0, padx=10)

        self.answer_text = tk.Text(middle_frame, height=10, width=70, font=("Arial", 12))
        self.answer_text.grid(row=1, column=0, padx=10)

        # Bottom frame: controls, progress bar, and log output
        bottom_frame = tk.Frame(self.root)
        bottom_frame.pack(pady=10)

        self.correct_button = tk.Button(bottom_frame, text="准确", command=self.mark_correct, font=("Arial", 12))
        self.correct_button.grid(row=0, column=0, padx=10)

        self.incorrect_button = tk.Button(bottom_frame, text="不准确", command=self.mark_incorrect, font=("Arial", 12))
        self.incorrect_button.grid(row=0, column=1, padx=10)

        self.train_button = tk.Button(bottom_frame, text="训练模型", command=self.train_model, font=("Arial", 12))
        self.train_button.grid(row=0, column=2, padx=10)

        self.retrain_button = tk.Button(bottom_frame, text="重新训练模型", command=lambda: self.train_model(retrain=True), font=("Arial", 12))
        self.retrain_button.grid(row=0, column=3, padx=10)

        self.progress_var = tk.DoubleVar()
        self.progress_bar = ttk.Progressbar(bottom_frame, variable=self.progress_var, maximum=100, length=200)
        self.progress_bar.grid(row=1, column=0, columnspan=4, pady=10)

        self.log_text = tk.Text(bottom_frame, height=10, width=70, font=("Arial", 12))
        self.log_text.grid(row=2, column=0, columnspan=4, pady=10)

        self.evaluate_button = tk.Button(bottom_frame, text="评估模型", command=self.evaluate_model, font=("Arial", 12))
        self.evaluate_button.grid(row=3, column=0, padx=10, pady=10)

        self.history_button = tk.Button(bottom_frame, text="查看历史记录", command=self.view_history, font=("Arial", 12))
        self.history_button.grid(row=3, column=1, padx=10, pady=10)

        self.save_history_button = tk.Button(bottom_frame, text="保存历史记录", command=self.save_history, font=("Arial", 12))
        self.save_history_button.grid(row=3, column=2, padx=10, pady=10)

    def get_answer(self):
        question = self.question_entry.get()
        if not question:
            messagebox.showwarning("输入错误", "请输入问题")
            return

        inputs = self.tokenizer(question, return_tensors='pt', padding='max_length', truncation=True, max_length=128)
        with torch.no_grad():
            input_ids = inputs['input_ids'].to(self.device)
            attention_mask = inputs['attention_mask'].to(self.device)
            logits = self.model(input_ids, attention_mask)
        
        if logits.item() > 0:
            answer_type = "羲和回答"
        else:
            answer_type = "零回答"

        specific_answer = self.get_specific_answer(question, answer_type)

        self.answer_text.delete(1.0, tk.END)
        self.answer_text.insert(tk.END, f"{answer_type}\n{specific_answer}")

        # Append to history
        self.history.append({
            'question': question,
            'answer_type': answer_type,
            'specific_answer': specific_answer,
            'accuracy': None  # not yet rated by the user
        })

    def get_specific_answer(self, question, answer_type):
        # Fuzzy-match the most similar question in the training data
        best_match = None
        best_ratio = 0.0
        for item in self.data:
            ratio = SequenceMatcher(None, question, item['question']).ratio()
            if ratio > best_ratio:
                best_ratio = ratio
                best_match = item

        if best_match:
            if answer_type == "羲和回答":
                return best_match['human_answers'][0]
            else:
                return best_match['chatgpt_answers'][0]
        return "这个我也不清楚,你问问零吧"

    def load_data(self, file_path):
        data = []
        if file_path.endswith('.jsonl'):
            # skip_invalid=True tells jsonlines to skip malformed lines instead of raising
            with jsonlines.open(file_path) as reader:
                for item in reader.iter(skip_invalid=True):
                    data.append(item)
        elif file_path.endswith('.json'):
            with open(file_path, 'r', encoding='utf-8') as f:
                try:
                    data = json.load(f)
                except json.JSONDecodeError as e:
                    logging.warning(f"跳过无效文件 {file_path}: {e}")
        return data

    def load_model(self):
        model_path = os.path.join(PROJECT_ROOT, 'models/xihua_model.pth')
        if os.path.exists(model_path):
            self.model.load_state_dict(torch.load(model_path, map_location=self.device))
            logging.info("加载现有模型")
        else:
            logging.info("没有找到现有模型,将使用预训练模型")

    def train_model(self, retrain=False):
        file_path = filedialog.askopenfilename(filetypes=[("JSONL files", "*.jsonl"), ("JSON files", "*.json")])
        if not file_path:
            messagebox.showwarning("文件选择错误", "请选择一个有效的数据文件")
            return

        try:
            dataset = XihuaDataset(file_path, self.tokenizer)
            data_loader = DataLoader(dataset, batch_size=8, shuffle=True)
            
            # Reload previously trained weights before continuing training
            if retrain:
                model_path = os.path.join(PROJECT_ROOT, 'models/xihua_model.pth')
                if os.path.exists(model_path):
                    self.model.load_state_dict(torch.load(model_path, map_location=self.device))
                self.model.to(self.device)
                self.model.train()

            optimizer = torch.optim.Adam(self.model.parameters(), lr=1e-5)
            criterion = torch.nn.BCEWithLogitsLoss()
            num_epochs = 30
            for epoch in range(num_epochs):
                train_loss = train(self.model, data_loader, optimizer, criterion, self.device, self.progress_var)
                logging.info(f'Epoch [{epoch+1}/{num_epochs}], Loss: {train_loss:.4f}')
                self.log_text.insert(tk.END, f'Epoch [{epoch+1}/{num_epochs}], Loss: {train_loss:.4f}\n')
                self.log_text.see(tk.END)
            torch.save(self.model.state_dict(), os.path.join(PROJECT_ROOT, 'models/xihua_model.pth'))
            self.model.eval()  # switch back to inference mode for get_answer
            logging.info("模型训练完成并保存")
            self.log_text.insert(tk.END, "模型训练完成并保存\n")
            self.log_text.see(tk.END)
            messagebox.showinfo("训练完成", "模型训练完成并保存")
        except Exception as e:
            logging.error(f"模型训练失败: {e}")
            self.log_text.insert(tk.END, f"模型训练失败: {e}\n")
            self.log_text.see(tk.END)
            messagebox.showerror("训练失败", f"模型训练失败: {e}")

    def evaluate_model(self):
        # Model evaluation logic could be added here; a sketch follows this code listing
        messagebox.showinfo("评估结果", "模型评估功能暂未实现")

    def mark_correct(self):
        if self.history:
            self.history[-1]['accuracy'] = True
            messagebox.showinfo("评价成功", "您认为这次回答是准确的")

    def mark_incorrect(self):
        if self.history:
            self.history[-1]['accuracy'] = False
            messagebox.showinfo("评价成功", "您认为这次回答是不准确的")

    def view_history(self):
        history_window = tk.Toplevel(self.root)
        history_window.title("历史记录")

        history_text = tk.Text(history_window, height=20, width=80, font=("Arial", 12))
        history_text.pack(padx=10, pady=10)

        for entry in self.history:
            history_text.insert(tk.END, f"问题: {entry['question']}\n")
            history_text.insert(tk.END, f"回答类型: {entry['answer_type']}\n")
            history_text.insert(tk.END, f"具体回答: {entry['specific_answer']}\n")
            if entry['accuracy'] is None:
                history_text.insert(tk.END, "评价: 未评价\n")
            elif entry['accuracy']:
                history_text.insert(tk.END, "评价: 准确\n")
            else:
                history_text.insert(tk.END, "评价: 不准确\n")
            history_text.insert(tk.END, "-" * 50 + "\n")

    def save_history(self):
        file_path = filedialog.asksaveasfilename(defaultextension=".json", filetypes=[("JSON files", "*.json")])
        if not file_path:
            return

        with open(file_path, 'w', encoding='utf-8') as f:
            json.dump(self.history, f, ensure_ascii=False, indent=4)

        messagebox.showinfo("保存成功", "历史记录已保存到文件")

# Entry point
if __name__ == "__main__":
    # Launch the GUI
    root = tk.Tk()
    app = XihuaChatbotGUI(root)
    root.mainloop()
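
The evaluate_model method in the listing above is still a stub. As a rough idea of what it could do, here is a minimal sketch (an assumption, not part of the original program) that could replace the stub inside XihuaChatbotGUI: it scores a held-out JSONL file with the same fields as the training data and reports how often the human answer receives a higher logit than the ChatGPT answer. The validation file path data/val_data.jsonl is a hypothetical choice.

python
    def evaluate_model(self):
        # Hypothetical held-out set in the same JSONL format as train_data.jsonl
        val_path = os.path.join(PROJECT_ROOT, 'data/val_data.jsonl')
        if not os.path.exists(val_path):
            messagebox.showwarning("评估失败", f"未找到验证集: {val_path}")
            return

        data_loader = get_data_loader(val_path, self.tokenizer, batch_size=8)
        self.model.eval()
        correct, total = 0, 0
        with torch.no_grad():
            for batch in data_loader:
                human_logits = self.model(batch['human_input_ids'].to(self.device),
                                          batch['human_attention_mask'].to(self.device))
                chatgpt_logits = self.model(batch['chatgpt_input_ids'].to(self.device),
                                            batch['chatgpt_attention_mask'].to(self.device))
                # A pair counts as correct when the human answer scores higher
                correct += (human_logits > chatgpt_logits).sum().item()
                total += human_logits.size(0)

        accuracy = correct / total if total else 0.0
        logging.info(f"评估完成,成对准确率: {accuracy:.4f}")
        messagebox.showinfo("评估结果", f"成对准确率: {accuracy:.4f}")

Pairwise accuracy is only one possible metric; per-answer threshold accuracy (logit above or below 0) would also fit the BCEWithLogitsLoss training objective.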

Summary of improvements

UI polish:

Larger fonts and a more modern layout.

Top, middle, and bottom frames that keep the window tidy.

User feedback:

Prompt messages after each button click, so the user knows the current operation state.

messagebox dialogs that report operation results, improving feedback.

Feature enhancements:

History viewing and saving.

Log output in the window, so the user can follow training and evaluation status.

User experience:

A streamlined interaction flow that feels smoother and more intuitive.

A progress bar that shows training progress.

Hopefully these improvements make your chatbot's interface friendlier and easier to use!
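
The intro also lists history search and log export among the planned features, but the code above only implements viewing and saving the history plus an in-window log. Below is a minimal sketch (an assumption, not taken from the original program) of two extra methods that could be added to XihuaChatbotGUI and wired to additional buttons: a keyword search over the history list, and an export of the log text box.

python
    def search_history(self, keyword):
        # Return history entries whose question or answer contains the keyword
        keyword = keyword.strip()
        if not keyword:
            return []
        return [
            entry for entry in self.history
            if keyword in entry['question'] or keyword in entry['specific_answer']
        ]

    def export_log(self):
        # Save the contents of the log text box to a user-chosen file
        file_path = filedialog.asksaveasfilename(defaultextension=".txt", filetypes=[("Text files", "*.txt")])
        if not file_path:
            return
        with open(file_path, 'w', encoding='utf-8') as f:
            f.write(self.log_text.get(1.0, tk.END))
        messagebox.showinfo("导出成功", "日志已导出")

Both methods could be hooked up in create_widgets with additional tk.Button widgets, following the same pattern as the existing buttons.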
