BERT的中文问答系统28

为了使GUI界面更加人性化,我们从以下几个方面进行改进:

美化界面:使用更现代的样式和布局,增加图标和颜色。

用户反馈:增加更多的提示信息和反馈,让用户知道当前的操作状态。

功能增强:增加一些实用的功能,如历史记录搜索、导出日志等。

用户体验:优化交互流程,使操作更加流畅和直观。

下面是改进后的代码:

python 复制代码
import os
import json
import jsonlines
import torch
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from transformers import BertModel, BertTokenizer
import tkinter as tk
from tkinter import filedialog, messagebox, ttk
import logging
from difflib import SequenceMatcher
from datetime import datetime

# 获取项目根目录
PROJECT_ROOT = os.path.dirname(os.path.abspath(__file__))

# 配置日志
LOGS_DIR = os.path.join(PROJECT_ROOT, 'logs')
os.makedirs(LOGS_DIR, exist_ok=True)

def setup_logging():
    log_file = os.path.join(LOGS_DIR, datetime.now().strftime('%Y-%m-%d_%H-%M-%S_羲和.txt'))
    logging.basicConfig(
        level=logging.INFO,
        format='%(asctime)s - %(levelname)s - %(message)s',
        handlers=[
            logging.FileHandler(log_file),
            logging.StreamHandler()
        ]
    )

setup_logging()

# 数据集类
class XihuaDataset(Dataset):
    def __init__(self, file_path, tokenizer, max_length=128):
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.data = self.load_data(file_path)

    def load_data(self, file_path):
        data = []
        if file_path.endswith('.jsonl'):
            with jsonlines.open(file_path) as reader:
                for i, item in enumerate(reader):
                    try:
                        data.append(item)
                    except jsonlines.jsonlines.InvalidLineError as e:
                        logging.warning(f"跳过无效行 {i + 1}: {e}")
        elif file_path.endswith('.json'):
            with open(file_path, 'r') as f:
                try:
                    data = json.load(f)
                except json.JSONDecodeError as e:
                    logging.warning(f"跳过无效文件 {file_path}: {e}")
        return data

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        item = self.data[idx]
        question = item['question']
        human_answer = item['human_answers'][0]
        chatgpt_answer = item['chatgpt_answers'][0]

        try:
            inputs = self.tokenizer(question, return_tensors='pt', padding='max_length', truncation=True, max_length=self.max_length)
            human_inputs = self.tokenizer(human_answer, return_tensors='pt', padding='max_length', truncation=True, max_length=self.max_length)
            chatgpt_inputs = self.tokenizer(chatgpt_answer, return_tensors='pt', padding='max_length', truncation=True, max_length=self.max_length)
        except Exception as e:
            logging.warning(f"跳过无效项 {idx}: {e}")
            return self.__getitem__((idx + 1) % len(self.data))

        return {
            'input_ids': inputs['input_ids'].squeeze(),
            'attention_mask': inputs['attention_mask'].squeeze(),
            'human_input_ids': human_inputs['input_ids'].squeeze(),
            'human_attention_mask': human_inputs['attention_mask'].squeeze(),
            'chatgpt_input_ids': chatgpt_inputs['input_ids'].squeeze(),
            'chatgpt_attention_mask': chatgpt_inputs['attention_mask'].squeeze(),
            'human_answer': human_answer,
            'chatgpt_answer': chatgpt_answer
        }

# 获取数据加载器
def get_data_loader(file_path, tokenizer, batch_size=8, max_length=128):
    dataset = XihuaDataset(file_path, tokenizer, max_length)
    return DataLoader(dataset, batch_size=batch_size, shuffle=True)

# 模型定义
class XihuaModel(torch.nn.Module):
    def __init__(self, pretrained_model_name='F:/models/bert-base-chinese'):
        super(XihuaModel, self).__init__()
        self.bert = BertModel.from_pretrained(pretrained_model_name)
        self.classifier = torch.nn.Linear(self.bert.config.hidden_size, 1)

    def forward(self, input_ids, attention_mask):
        outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        pooled_output = outputs.pooler_output
        logits = self.classifier(pooled_output)
        return logits

# 训练函数
def train(model, data_loader, optimizer, criterion, device, progress_var=None):
    model.train()
    total_loss = 0.0
    num_batches = len(data_loader)
    for batch_idx, batch in enumerate(data_loader):
        try:
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            human_input_ids = batch['human_input_ids'].to(device)
            human_attention_mask = batch['human_attention_mask'].to(device)
            chatgpt_input_ids = batch['chatgpt_input_ids'].to(device)
            chatgpt_attention_mask = batch['chatgpt_attention_mask'].to(device)

            optimizer.zero_grad()
            human_logits = model(human_input_ids, human_attention_mask)
            chatgpt_logits = model(chatgpt_input_ids, chatgpt_attention_mask)

            human_labels = torch.ones(human_logits.size(0), 1).to(device)
            chatgpt_labels = torch.zeros(chatgpt_logits.size(0), 1).to(device)

            loss = criterion(human_logits, human_labels) + criterion(chatgpt_logits, chatgpt_labels)
            loss.backward()
            optimizer.step()

            total_loss += loss.item()
            if progress_var:
                progress_var.set((batch_idx + 1) / num_batches * 100)
        except Exception as e:
            logging.warning(f"跳过无效批次: {e}")

    return total_loss / len(data_loader)

# 主训练函数
def main_train(retrain=False):
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    logging.info(f'Using device: {device}')

    tokenizer = BertTokenizer.from_pretrained('F:/models/bert-base-chinese')
    model = XihuaModel(pretrained_model_name='F:/models/bert-base-chinese').to(device)

    if retrain:
        model_path = os.path.join(PROJECT_ROOT, 'models/xihua_model.pth')
        if os.path.exists(model_path):
            model.load_state_dict(torch.load(model_path, map_location=device))
            logging.info("加载现有模型")
        else:
            logging.info("没有找到现有模型,将使用预训练模型")

    optimizer = optim.Adam(model.parameters(), lr=1e-5)
    criterion = torch.nn.BCEWithLogitsLoss()

    train_data_loader = get_data_loader(os.path.join(PROJECT_ROOT, 'data/train_data.jsonl'), tokenizer, batch_size=8, max_length=128)

    num_epochs = 30
    for epoch in range(num_epochs):
        train_loss = train(model, train_data_loader, optimizer, criterion, device)
        logging.info(f'Epoch [{epoch+1}/{num_epochs}], Loss: {train_loss:.8f}')

    torch.save(model.state_dict(), os.path.join(PROJECT_ROOT, 'models/xihua_model.pth'))
    logging.info("模型训练完成并保存")

# GUI界面
class XihuaChatbotGUI:
    def __init__(self, root):
        self.root = root
        self.root.title("羲和聊天机器人")

        self.tokenizer = BertTokenizer.from_pretrained('F:/models/bert-base-chinese')
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.model = XihuaModel(pretrained_model_name='F:/models/bert-base-chinese').to(self.device)
        self.load_model()
        self.model.eval()

        # 加载训练数据集以便在获取答案时使用
        self.data = self.load_data(os.path.join(PROJECT_ROOT, 'data/train_data.jsonl'))

        # 历史记录
        self.history = []

        self.create_widgets()

    def create_widgets(self):
        # 顶部框架
        top_frame = tk.Frame(self.root)
        top_frame.pack(pady=10)

        self.question_label = tk.Label(top_frame, text="问题:", font=("Arial", 12))
        self.question_label.grid(row=0, column=0, padx=10)

        self.question_entry = tk.Entry(top_frame, width=50, font=("Arial", 12))
        self.question_entry.grid(row=0, column=1, padx=10)

        self.answer_button = tk.Button(top_frame, text="获取回答", command=self.get_answer, font=("Arial", 12))
        self.answer_button.grid(row=0, column=2, padx=10)

        # 中部框架
        middle_frame = tk.Frame(self.root)
        middle_frame.pack(pady=10)

        self.answer_label = tk.Label(middle_frame, text="回答:", font=("Arial", 12))
        self.answer_label.grid(row=0, column=0, padx=10)

        self.answer_text = tk.Text(middle_frame, height=10, width=70, font=("Arial", 12))
        self.answer_text.grid(row=1, column=0, padx=10)

        # 底部框架
        bottom_frame = tk.Frame(self.root)
        bottom_frame.pack(pady=10)

        self.correct_button = tk.Button(bottom_frame, text="准确", command=self.mark_correct, font=("Arial", 12))
        self.correct_button.grid(row=0, column=0, padx=10)

        self.incorrect_button = tk.Button(bottom_frame, text="不准确", command=self.mark_incorrect, font=("Arial", 12))
        self.incorrect_button.grid(row=0, column=1, padx=10)

        self.train_button = tk.Button(bottom_frame, text="训练模型", command=self.train_model, font=("Arial", 12))
        self.train_button.grid(row=0, column=2, padx=10)

        self.retrain_button = tk.Button(bottom_frame, text="重新训练模型", command=lambda: self.train_model(retrain=True), font=("Arial", 12))
        self.retrain_button.grid(row=0, column=3, padx=10)

        self.progress_var = tk.DoubleVar()
        self.progress_bar = ttk.Progressbar(bottom_frame, variable=self.progress_var, maximum=100, length=200)
        self.progress_bar.grid(row=1, column=0, columnspan=4, pady=10)

        self.log_text = tk.Text(bottom_frame, height=10, width=70, font=("Arial", 12))
        self.log_text.grid(row=2, column=0, columnspan=4, pady=10)

        self.evaluate_button = tk.Button(bottom_frame, text="评估模型", command=self.evaluate_model, font=("Arial", 12))
        self.evaluate_button.grid(row=3, column=0, padx=10, pady=10)

        self.history_button = tk.Button(bottom_frame, text="查看历史记录", command=self.view_history, font=("Arial", 12))
        self.history_button.grid(row=3, column=1, padx=10, pady=10)

        self.save_history_button = tk.Button(bottom_frame, text="保存历史记录", command=self.save_history, font=("Arial", 12))
        self.save_history_button.grid(row=3, column=2, padx=10, pady=10)

    def get_answer(self):
        question = self.question_entry.get()
        if not question:
            messagebox.showwarning("输入错误", "请输入问题")
            return

        inputs = self.tokenizer(question, return_tensors='pt', padding='max_length', truncation=True, max_length=128)
        with torch.no_grad():
            input_ids = inputs['input_ids'].to(self.device)
            attention_mask = inputs['attention_mask'].to(self.device)
            logits = self.model(input_ids, attention_mask)
        
        if logits.item() > 0:
            answer_type = "羲和回答"
        else:
            answer_type = "零回答"

        specific_answer = self.get_specific_answer(question, answer_type)

        self.answer_text.delete(1.0, tk.END)
        self.answer_text.insert(tk.END, f"{answer_type}\n{specific_answer}")

        # 添加到历史记录
        self.history.append({
            'question': question,
            'answer_type': answer_type,
            'specific_answer': specific_answer,
            'accuracy': None  # 初始状态为未评价
        })

    def get_specific_answer(self, question, answer_type):
        # 使用模糊匹配查找最相似的问题
        best_match = None
        best_ratio = 0.0
        for item in self.data:
            ratio = SequenceMatcher(None, question, item['question']).ratio()
            if ratio > best_ratio:
                best_ratio = ratio
                best_match = item

        if best_match:
            if answer_type == "羲和回答":
                return best_match['human_answers'][0]
            else:
                return best_match['chatgpt_answers'][0]
        return "这个我也不清楚,你问问零吧"

    def load_data(self, file_path):
        data = []
        if file_path.endswith('.jsonl'):
            with jsonlines.open(file_path) as reader:
                for i, item in enumerate(reader):
                    try:
                        data.append(item)
                    except jsonlines.jsonlines.InvalidLineError as e:
                        logging.warning(f"跳过无效行 {i + 1}: {e}")
        elif file_path.endswith('.json'):
            with open(file_path, 'r') as f:
                try:
                    data = json.load(f)
                except json.JSONDecodeError as e:
                    logging.warning(f"跳过无效文件 {file_path}: {e}")
        return data

    def load_model(self):
        model_path = os.path.join(PROJECT_ROOT, 'models/xihua_model.pth')
        if os.path.exists(model_path):
            self.model.load_state_dict(torch.load(model_path, map_location=self.device))
            logging.info("加载现有模型")
        else:
            logging.info("没有找到现有模型,将使用预训练模型")

    def train_model(self, retrain=False):
        file_path = filedialog.askopenfilename(filetypes=[("JSONL files", "*.jsonl"), ("JSON files", "*.json")])
        if not file_path:
            messagebox.showwarning("文件选择错误", "请选择一个有效的数据文件")
            return

        try:
            dataset = XihuaDataset(file_path, self.tokenizer)
            data_loader = DataLoader(dataset, batch_size=8, shuffle=True)
            
            # 加载已训练的模型权重
            if retrain:
                self.model.load_state_dict(torch.load(os.path.join(PROJECT_ROOT, 'models/xihua_model.pth'), map_location=self.device))
                self.model.to(self.device)
                self.model.train()

            optimizer = torch.optim.Adam(self.model.parameters(), lr=1e-5)
            criterion = torch.nn.BCEWithLogitsLoss()
            num_epochs = 30
            for epoch in range(num_epochs):
                train_loss = train(self.model, data_loader, optimizer, criterion, self.device, self.progress_var)
                logging.info(f'Epoch [{epoch+1}/{num_epochs}], Loss: {train_loss:.4f}')
                self.log_text.insert(tk.END, f'Epoch [{epoch+1}/{num_epochs}], Loss: {train_loss:.4f}\n')
                self.log_text.see(tk.END)
            torch.save(self.model.state_dict(), os.path.join(PROJECT_ROOT, 'models/xihua_model.pth'))
            logging.info("模型训练完成并保存")
            self.log_text.insert(tk.END, "模型训练完成并保存\n")
            self.log_text.see(tk.END)
            messagebox.showinfo("训练完成", "模型训练完成并保存")
        except Exception as e:
            logging.error(f"模型训练失败: {e}")
            self.log_text.insert(tk.END, f"模型训练失败: {e}\n")
            self.log_text.see(tk.END)
            messagebox.showerror("训练失败", f"模型训练失败: {e}")

    def evaluate_model(self):
        # 这里可以添加模型评估的逻辑
        messagebox.showinfo("评估结果", "模型评估功能暂未实现")

    def mark_correct(self):
        if self.history:
            self.history[-1]['accuracy'] = True
            messagebox.showinfo("评价成功", "您认为这次回答是准确的")

    def mark_incorrect(self):
        if self.history:
            self.history[-1]['accuracy'] = False
            messagebox.showinfo("评价成功", "您认为这次回答是不准确的")

    def view_history(self):
        history_window = tk.Toplevel(self.root)
        history_window.title("历史记录")

        history_text = tk.Text(history_window, height=20, width=80, font=("Arial", 12))
        history_text.pack(padx=10, pady=10)

        for entry in self.history:
            history_text.insert(tk.END, f"问题: {entry['question']}\n")
            history_text.insert(tk.END, f"回答类型: {entry['answer_type']}\n")
            history_text.insert(tk.END, f"具体回答: {entry['specific_answer']}\n")
            if entry['accuracy'] is None:
                history_text.insert(tk.END, "评价: 未评价\n")
            elif entry['accuracy']:
                history_text.insert(tk.END, "评价: 准确\n")
            else:
                history_text.insert(tk.END, "评价: 不准确\n")
            history_text.insert(tk.END, "-" * 50 + "\n")

    def save_history(self):
        file_path = filedialog.asksaveasfilename(defaultextension=".json", filetypes=[("JSON files", "*.json")])
        if not file_path:
            return

        with open(file_path, 'w') as f:
            json.dump(self.history, f, ensure_ascii=False, indent=4)

        messagebox.showinfo("保存成功", "历史记录已保存到文件")

# 主函数
if __name__ == "__main__":
    # 启动GUI
    root = tk.Tk()
    app = XihuaChatbotGUI(root)
    root.mainloop()

改进点总结

美化界面:

使用更大的字体和更现代的布局。

增加了顶部、中部和底部框架,使界面更加整洁。
用户反馈:

在每个按钮点击后增加了提示信息,让用户知道当前的操作状态。

使用messagebox显示操作结果,增加用户的反馈体验。
功能增强:

增加了历史记录查看和保存功能。

增加了日志输出,方便用户了解模型训练和评估的状态。
用户体验:

优化了交互流程,使操作更加流畅和直观。

增加了进度条,让用户了解模型训练的进度。
希望这些改进能让你的聊天机器人界面更加友好和易用!

相关推荐
MidJourney中文版33 分钟前
深度报告:中老年AI陪伴机器人需求分析
人工智能·机器人
王上上1 小时前
【论文阅读41】-LSTM-PINN预测人口
论文阅读·人工智能·lstm
智慧化智能化数字化方案1 小时前
69页全面预算管理体系的框架与落地【附全文阅读】
大数据·人工智能·全面预算管理·智慧财务·智慧预算
PyAIExplorer1 小时前
图像旋转:从原理到 OpenCV 实践
人工智能·opencv·计算机视觉
Wilber的技术分享1 小时前
【机器学习实战笔记 14】集成学习:XGBoost算法(一) 原理简介与快速应用
人工智能·笔记·算法·随机森林·机器学习·集成学习·xgboost
巴里巴气1 小时前
selenium基础知识 和 模拟登录selenium版本
爬虫·python·selenium·爬虫模拟登录
19891 小时前
【零基础学AI】第26讲:循环神经网络(RNN)与LSTM - 文本生成
人工智能·python·rnn·神经网络·机器学习·tensorflow·lstm
JavaEdge在掘金2 小时前
Redis 数据倾斜?别慌!从成因到解决方案,一文帮你搞定
python
ansurfen2 小时前
我的第一个AI项目:从零搭建RAG知识库的踩坑之旅
python·llm
前端付豪2 小时前
20、用 Python + API 打造终端天气预报工具(支持城市查询、天气图标、美化输出🧊
后端·python