BERT的中文问答系统28

为了使GUI界面更加人性化,我们从以下几个方面进行改进:

美化界面:使用更现代的样式和布局,增加图标和颜色。

用户反馈:增加更多的提示信息和反馈,让用户知道当前的操作状态。

功能增强:增加一些实用的功能,如历史记录搜索、导出日志等。

用户体验:优化交互流程,使操作更加流畅和直观。

下面是改进后的代码:

python 复制代码
import os
import json
import jsonlines
import torch
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from transformers import BertModel, BertTokenizer
import tkinter as tk
from tkinter import filedialog, messagebox, ttk
import logging
from difflib import SequenceMatcher
from datetime import datetime

# 获取项目根目录
PROJECT_ROOT = os.path.dirname(os.path.abspath(__file__))

# 配置日志
LOGS_DIR = os.path.join(PROJECT_ROOT, 'logs')
os.makedirs(LOGS_DIR, exist_ok=True)

def setup_logging():
    log_file = os.path.join(LOGS_DIR, datetime.now().strftime('%Y-%m-%d_%H-%M-%S_羲和.txt'))
    logging.basicConfig(
        level=logging.INFO,
        format='%(asctime)s - %(levelname)s - %(message)s',
        handlers=[
            logging.FileHandler(log_file),
            logging.StreamHandler()
        ]
    )

setup_logging()

# 数据集类
class XihuaDataset(Dataset):
    def __init__(self, file_path, tokenizer, max_length=128):
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.data = self.load_data(file_path)

    def load_data(self, file_path):
        data = []
        if file_path.endswith('.jsonl'):
            with jsonlines.open(file_path) as reader:
                for i, item in enumerate(reader):
                    try:
                        data.append(item)
                    except jsonlines.jsonlines.InvalidLineError as e:
                        logging.warning(f"跳过无效行 {i + 1}: {e}")
        elif file_path.endswith('.json'):
            with open(file_path, 'r') as f:
                try:
                    data = json.load(f)
                except json.JSONDecodeError as e:
                    logging.warning(f"跳过无效文件 {file_path}: {e}")
        return data

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        item = self.data[idx]
        question = item['question']
        human_answer = item['human_answers'][0]
        chatgpt_answer = item['chatgpt_answers'][0]

        try:
            inputs = self.tokenizer(question, return_tensors='pt', padding='max_length', truncation=True, max_length=self.max_length)
            human_inputs = self.tokenizer(human_answer, return_tensors='pt', padding='max_length', truncation=True, max_length=self.max_length)
            chatgpt_inputs = self.tokenizer(chatgpt_answer, return_tensors='pt', padding='max_length', truncation=True, max_length=self.max_length)
        except Exception as e:
            logging.warning(f"跳过无效项 {idx}: {e}")
            return self.__getitem__((idx + 1) % len(self.data))

        return {
            'input_ids': inputs['input_ids'].squeeze(),
            'attention_mask': inputs['attention_mask'].squeeze(),
            'human_input_ids': human_inputs['input_ids'].squeeze(),
            'human_attention_mask': human_inputs['attention_mask'].squeeze(),
            'chatgpt_input_ids': chatgpt_inputs['input_ids'].squeeze(),
            'chatgpt_attention_mask': chatgpt_inputs['attention_mask'].squeeze(),
            'human_answer': human_answer,
            'chatgpt_answer': chatgpt_answer
        }

# 获取数据加载器
def get_data_loader(file_path, tokenizer, batch_size=8, max_length=128):
    dataset = XihuaDataset(file_path, tokenizer, max_length)
    return DataLoader(dataset, batch_size=batch_size, shuffle=True)

# 模型定义
class XihuaModel(torch.nn.Module):
    def __init__(self, pretrained_model_name='F:/models/bert-base-chinese'):
        super(XihuaModel, self).__init__()
        self.bert = BertModel.from_pretrained(pretrained_model_name)
        self.classifier = torch.nn.Linear(self.bert.config.hidden_size, 1)

    def forward(self, input_ids, attention_mask):
        outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        pooled_output = outputs.pooler_output
        logits = self.classifier(pooled_output)
        return logits

# 训练函数
def train(model, data_loader, optimizer, criterion, device, progress_var=None):
    model.train()
    total_loss = 0.0
    num_batches = len(data_loader)
    for batch_idx, batch in enumerate(data_loader):
        try:
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            human_input_ids = batch['human_input_ids'].to(device)
            human_attention_mask = batch['human_attention_mask'].to(device)
            chatgpt_input_ids = batch['chatgpt_input_ids'].to(device)
            chatgpt_attention_mask = batch['chatgpt_attention_mask'].to(device)

            optimizer.zero_grad()
            human_logits = model(human_input_ids, human_attention_mask)
            chatgpt_logits = model(chatgpt_input_ids, chatgpt_attention_mask)

            human_labels = torch.ones(human_logits.size(0), 1).to(device)
            chatgpt_labels = torch.zeros(chatgpt_logits.size(0), 1).to(device)

            loss = criterion(human_logits, human_labels) + criterion(chatgpt_logits, chatgpt_labels)
            loss.backward()
            optimizer.step()

            total_loss += loss.item()
            if progress_var:
                progress_var.set((batch_idx + 1) / num_batches * 100)
        except Exception as e:
            logging.warning(f"跳过无效批次: {e}")

    return total_loss / len(data_loader)

# 主训练函数
def main_train(retrain=False):
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    logging.info(f'Using device: {device}')

    tokenizer = BertTokenizer.from_pretrained('F:/models/bert-base-chinese')
    model = XihuaModel(pretrained_model_name='F:/models/bert-base-chinese').to(device)

    if retrain:
        model_path = os.path.join(PROJECT_ROOT, 'models/xihua_model.pth')
        if os.path.exists(model_path):
            model.load_state_dict(torch.load(model_path, map_location=device))
            logging.info("加载现有模型")
        else:
            logging.info("没有找到现有模型,将使用预训练模型")

    optimizer = optim.Adam(model.parameters(), lr=1e-5)
    criterion = torch.nn.BCEWithLogitsLoss()

    train_data_loader = get_data_loader(os.path.join(PROJECT_ROOT, 'data/train_data.jsonl'), tokenizer, batch_size=8, max_length=128)

    num_epochs = 30
    for epoch in range(num_epochs):
        train_loss = train(model, train_data_loader, optimizer, criterion, device)
        logging.info(f'Epoch [{epoch+1}/{num_epochs}], Loss: {train_loss:.8f}')

    torch.save(model.state_dict(), os.path.join(PROJECT_ROOT, 'models/xihua_model.pth'))
    logging.info("模型训练完成并保存")

# GUI界面
class XihuaChatbotGUI:
    def __init__(self, root):
        self.root = root
        self.root.title("羲和聊天机器人")

        self.tokenizer = BertTokenizer.from_pretrained('F:/models/bert-base-chinese')
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.model = XihuaModel(pretrained_model_name='F:/models/bert-base-chinese').to(self.device)
        self.load_model()
        self.model.eval()

        # 加载训练数据集以便在获取答案时使用
        self.data = self.load_data(os.path.join(PROJECT_ROOT, 'data/train_data.jsonl'))

        # 历史记录
        self.history = []

        self.create_widgets()

    def create_widgets(self):
        # 顶部框架
        top_frame = tk.Frame(self.root)
        top_frame.pack(pady=10)

        self.question_label = tk.Label(top_frame, text="问题:", font=("Arial", 12))
        self.question_label.grid(row=0, column=0, padx=10)

        self.question_entry = tk.Entry(top_frame, width=50, font=("Arial", 12))
        self.question_entry.grid(row=0, column=1, padx=10)

        self.answer_button = tk.Button(top_frame, text="获取回答", command=self.get_answer, font=("Arial", 12))
        self.answer_button.grid(row=0, column=2, padx=10)

        # 中部框架
        middle_frame = tk.Frame(self.root)
        middle_frame.pack(pady=10)

        self.answer_label = tk.Label(middle_frame, text="回答:", font=("Arial", 12))
        self.answer_label.grid(row=0, column=0, padx=10)

        self.answer_text = tk.Text(middle_frame, height=10, width=70, font=("Arial", 12))
        self.answer_text.grid(row=1, column=0, padx=10)

        # 底部框架
        bottom_frame = tk.Frame(self.root)
        bottom_frame.pack(pady=10)

        self.correct_button = tk.Button(bottom_frame, text="准确", command=self.mark_correct, font=("Arial", 12))
        self.correct_button.grid(row=0, column=0, padx=10)

        self.incorrect_button = tk.Button(bottom_frame, text="不准确", command=self.mark_incorrect, font=("Arial", 12))
        self.incorrect_button.grid(row=0, column=1, padx=10)

        self.train_button = tk.Button(bottom_frame, text="训练模型", command=self.train_model, font=("Arial", 12))
        self.train_button.grid(row=0, column=2, padx=10)

        self.retrain_button = tk.Button(bottom_frame, text="重新训练模型", command=lambda: self.train_model(retrain=True), font=("Arial", 12))
        self.retrain_button.grid(row=0, column=3, padx=10)

        self.progress_var = tk.DoubleVar()
        self.progress_bar = ttk.Progressbar(bottom_frame, variable=self.progress_var, maximum=100, length=200)
        self.progress_bar.grid(row=1, column=0, columnspan=4, pady=10)

        self.log_text = tk.Text(bottom_frame, height=10, width=70, font=("Arial", 12))
        self.log_text.grid(row=2, column=0, columnspan=4, pady=10)

        self.evaluate_button = tk.Button(bottom_frame, text="评估模型", command=self.evaluate_model, font=("Arial", 12))
        self.evaluate_button.grid(row=3, column=0, padx=10, pady=10)

        self.history_button = tk.Button(bottom_frame, text="查看历史记录", command=self.view_history, font=("Arial", 12))
        self.history_button.grid(row=3, column=1, padx=10, pady=10)

        self.save_history_button = tk.Button(bottom_frame, text="保存历史记录", command=self.save_history, font=("Arial", 12))
        self.save_history_button.grid(row=3, column=2, padx=10, pady=10)

    def get_answer(self):
        question = self.question_entry.get()
        if not question:
            messagebox.showwarning("输入错误", "请输入问题")
            return

        inputs = self.tokenizer(question, return_tensors='pt', padding='max_length', truncation=True, max_length=128)
        with torch.no_grad():
            input_ids = inputs['input_ids'].to(self.device)
            attention_mask = inputs['attention_mask'].to(self.device)
            logits = self.model(input_ids, attention_mask)
        
        if logits.item() > 0:
            answer_type = "羲和回答"
        else:
            answer_type = "零回答"

        specific_answer = self.get_specific_answer(question, answer_type)

        self.answer_text.delete(1.0, tk.END)
        self.answer_text.insert(tk.END, f"{answer_type}\n{specific_answer}")

        # 添加到历史记录
        self.history.append({
            'question': question,
            'answer_type': answer_type,
            'specific_answer': specific_answer,
            'accuracy': None  # 初始状态为未评价
        })

    def get_specific_answer(self, question, answer_type):
        # 使用模糊匹配查找最相似的问题
        best_match = None
        best_ratio = 0.0
        for item in self.data:
            ratio = SequenceMatcher(None, question, item['question']).ratio()
            if ratio > best_ratio:
                best_ratio = ratio
                best_match = item

        if best_match:
            if answer_type == "羲和回答":
                return best_match['human_answers'][0]
            else:
                return best_match['chatgpt_answers'][0]
        return "这个我也不清楚,你问问零吧"

    def load_data(self, file_path):
        data = []
        if file_path.endswith('.jsonl'):
            with jsonlines.open(file_path) as reader:
                for i, item in enumerate(reader):
                    try:
                        data.append(item)
                    except jsonlines.jsonlines.InvalidLineError as e:
                        logging.warning(f"跳过无效行 {i + 1}: {e}")
        elif file_path.endswith('.json'):
            with open(file_path, 'r') as f:
                try:
                    data = json.load(f)
                except json.JSONDecodeError as e:
                    logging.warning(f"跳过无效文件 {file_path}: {e}")
        return data

    def load_model(self):
        model_path = os.path.join(PROJECT_ROOT, 'models/xihua_model.pth')
        if os.path.exists(model_path):
            self.model.load_state_dict(torch.load(model_path, map_location=self.device))
            logging.info("加载现有模型")
        else:
            logging.info("没有找到现有模型,将使用预训练模型")

    def train_model(self, retrain=False):
        file_path = filedialog.askopenfilename(filetypes=[("JSONL files", "*.jsonl"), ("JSON files", "*.json")])
        if not file_path:
            messagebox.showwarning("文件选择错误", "请选择一个有效的数据文件")
            return

        try:
            dataset = XihuaDataset(file_path, self.tokenizer)
            data_loader = DataLoader(dataset, batch_size=8, shuffle=True)
            
            # 加载已训练的模型权重
            if retrain:
                self.model.load_state_dict(torch.load(os.path.join(PROJECT_ROOT, 'models/xihua_model.pth'), map_location=self.device))
                self.model.to(self.device)
                self.model.train()

            optimizer = torch.optim.Adam(self.model.parameters(), lr=1e-5)
            criterion = torch.nn.BCEWithLogitsLoss()
            num_epochs = 30
            for epoch in range(num_epochs):
                train_loss = train(self.model, data_loader, optimizer, criterion, self.device, self.progress_var)
                logging.info(f'Epoch [{epoch+1}/{num_epochs}], Loss: {train_loss:.4f}')
                self.log_text.insert(tk.END, f'Epoch [{epoch+1}/{num_epochs}], Loss: {train_loss:.4f}\n')
                self.log_text.see(tk.END)
            torch.save(self.model.state_dict(), os.path.join(PROJECT_ROOT, 'models/xihua_model.pth'))
            logging.info("模型训练完成并保存")
            self.log_text.insert(tk.END, "模型训练完成并保存\n")
            self.log_text.see(tk.END)
            messagebox.showinfo("训练完成", "模型训练完成并保存")
        except Exception as e:
            logging.error(f"模型训练失败: {e}")
            self.log_text.insert(tk.END, f"模型训练失败: {e}\n")
            self.log_text.see(tk.END)
            messagebox.showerror("训练失败", f"模型训练失败: {e}")

    def evaluate_model(self):
        # 这里可以添加模型评估的逻辑
        messagebox.showinfo("评估结果", "模型评估功能暂未实现")

    def mark_correct(self):
        if self.history:
            self.history[-1]['accuracy'] = True
            messagebox.showinfo("评价成功", "您认为这次回答是准确的")

    def mark_incorrect(self):
        if self.history:
            self.history[-1]['accuracy'] = False
            messagebox.showinfo("评价成功", "您认为这次回答是不准确的")

    def view_history(self):
        history_window = tk.Toplevel(self.root)
        history_window.title("历史记录")

        history_text = tk.Text(history_window, height=20, width=80, font=("Arial", 12))
        history_text.pack(padx=10, pady=10)

        for entry in self.history:
            history_text.insert(tk.END, f"问题: {entry['question']}\n")
            history_text.insert(tk.END, f"回答类型: {entry['answer_type']}\n")
            history_text.insert(tk.END, f"具体回答: {entry['specific_answer']}\n")
            if entry['accuracy'] is None:
                history_text.insert(tk.END, "评价: 未评价\n")
            elif entry['accuracy']:
                history_text.insert(tk.END, "评价: 准确\n")
            else:
                history_text.insert(tk.END, "评价: 不准确\n")
            history_text.insert(tk.END, "-" * 50 + "\n")

    def save_history(self):
        file_path = filedialog.asksaveasfilename(defaultextension=".json", filetypes=[("JSON files", "*.json")])
        if not file_path:
            return

        with open(file_path, 'w') as f:
            json.dump(self.history, f, ensure_ascii=False, indent=4)

        messagebox.showinfo("保存成功", "历史记录已保存到文件")

# 主函数
if __name__ == "__main__":
    # 启动GUI
    root = tk.Tk()
    app = XihuaChatbotGUI(root)
    root.mainloop()

改进点总结

美化界面:

使用更大的字体和更现代的布局。

增加了顶部、中部和底部框架,使界面更加整洁。
用户反馈:

在每个按钮点击后增加了提示信息,让用户知道当前的操作状态。

使用messagebox显示操作结果,增加用户的反馈体验。
功能增强:

增加了历史记录查看和保存功能。

增加了日志输出,方便用户了解模型训练和评估的状态。
用户体验:

优化了交互流程,使操作更加流畅和直观。

增加了进度条,让用户了解模型训练的进度。
希望这些改进能让你的聊天机器人界面更加友好和易用!

相关推荐
野蛮的大西瓜18 分钟前
开源呼叫中心中,如何将ASR与IVR菜单结合,实现动态的IVR交互
人工智能·机器人·自动化·音视频·信息与通信
CountingStars61943 分钟前
目标检测常用评估指标(metrics)
人工智能·目标检测·目标跟踪
tangjunjun-owen1 小时前
第四节:GLM-4v-9b模型的tokenizer源码解读
人工智能·glm-4v-9b·多模态大模型教程
冰蓝蓝1 小时前
深度学习中的注意力机制:解锁智能模型的新视角
人工智能·深度学习
橙子小哥的代码世界1 小时前
【计算机视觉基础CV-图像分类】01- 从历史源头到深度时代:一文读懂计算机视觉的进化脉络、核心任务与产业蓝图
人工智能·计算机视觉
黄公子学安全1 小时前
Java的基础概念(一)
java·开发语言·python
新加坡内哥谈技术2 小时前
苏黎世联邦理工学院与加州大学伯克利分校推出MaxInfoRL:平衡内在与外在探索的全新强化学习框架
大数据·人工智能·语言模型
程序员一诺2 小时前
【Python使用】嘿马python高级进阶全体系教程第10篇:静态Web服务器-返回固定页面数据,1. 开发自己的静态Web服务器【附代码文档】
后端·python
小木_.2 小时前
【Python 图片下载器】一款专门为爬虫制作的图片下载器,多线程下载,速度快,支持续传/图片缩放/图片压缩/图片转换
爬虫·python·学习·分享·批量下载·图片下载器
fanstuck2 小时前
Prompt提示工程上手指南(七)Prompt编写实战-基于智能客服问答系统下的Prompt编写
人工智能·数据挖掘·openai