BERT的中文问答系统35

优化GUI布局、显示问答内容、增加自动搜索功能等。以下是完整的项目结构和代码:

项目结构

python 复制代码
xihe241117/
├── data/
│   └── train_data.jsonl
├── logs/
├── models/
│   └── xihua_model.pth
├── requirements.txt
├── README.md
└── 智能羲和.py

README.md

python 复制代码
# 羲和聊天机器人

## 项目介绍
羲和聊天机器人是一个基于BERT的中文问答系统,支持用户提问并获取回答。如果模型提供的回答不满意,用户可以选择"不正确",机器人将自动从百度搜索相关信息并提供更详细的答案。

## 目录结构
xihe241117/
├── data/
│ └── train_data.jsonl
├── logs/
├── models/
│ └── xihua_model.pth
├── requirements.txt
├── README.md
└── 智能羲和.py

## 安装依赖

pip install -r requirements.txt
运行项目

python 智能羲和.py
功能
用户提问
模型提供回答
用户评价回答(正确/不正确)
如果回答不正确,自动从百度搜索相关信息
查看历史记录
保存历史记录
训练模型
重新训练模型
评估模型(暂未实现)
联系我们
如有任何问题或建议,请联系 [[email protected]]。

requirements.txt

txt 复制代码
torch
transformers
jsonlines
tkinter
requests
beautifulsoup4

智能羲和.py

python 复制代码
import os
import json
import jsonlines
import torch
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from transformers import BertModel, BertTokenizer
import tkinter as tk
from tkinter import filedialog, messagebox, ttk
import logging
from difflib import SequenceMatcher
from datetime import datetime
import requests
from bs4 import BeautifulSoup

# 获取项目根目录
PROJECT_ROOT = os.path.dirname(os.path.abspath(__file__))

# 配置日志
LOGS_DIR = os.path.join(PROJECT_ROOT, 'logs')
os.makedirs(LOGS_DIR, exist_ok=True)

def setup_logging():
    log_file = os.path.join(LOGS_DIR, datetime.now().strftime('%Y-%m-%d_%H-%M-%S_羲和.txt'))
    logging.basicConfig(
        level=logging.INFO,
        format='%(asctime)s - %(levelname)s - %(message)s',
        handlers=[
            logging.FileHandler(log_file),
            logging.StreamHandler()
        ]
    )

setup_logging()

# 数据集类
class XihuaDataset(Dataset):
    def __init__(self, file_path, tokenizer, max_length=128):
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.data = self.load_data(file_path)

    def load_data(self, file_path):
        data = []
        if file_path.endswith('.jsonl'):
            with jsonlines.open(file_path) as reader:
                for i, item in enumerate(reader):
                    try:
                        data.append(item)
                    except jsonlines.jsonlines.InvalidLineError as e:
                        logging.warning(f"跳过无效行 {i + 1}: {e}")
        elif file_path.endswith('.json'):
            with open(file_path, 'r') as f:
                try:
                    data = json.load(f)
                except json.JSONDecodeError as e:
                    logging.warning(f"跳过无效文件 {file_path}: {e}")
        return data

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        item = self.data[idx]
        question = item['question']
        human_answer = item['human_answers'][0]
        chatgpt_answer = item['chatgpt_answers'][0]

        try:
            inputs = self.tokenizer(question, return_tensors='pt', padding='max_length', truncation=True, max_length=self.max_length)
            human_inputs = self.tokenizer(human_answer, return_tensors='pt', padding='max_length', truncation=True, max_length=self.max_length)
            chatgpt_inputs = self.tokenizer(chatgpt_answer, return_tensors='pt', padding='max_length', truncation=True, max_length=self.max_length)
        except Exception as e:
            logging.warning(f"跳过无效项 {idx}: {e}")
            return self.__getitem__((idx + 1) % len(self.data))

        return {
            'input_ids': inputs['input_ids'].squeeze(),
            'attention_mask': inputs['attention_mask'].squeeze(),
            'human_input_ids': human_inputs['input_ids'].squeeze(),
            'human_attention_mask': human_inputs['attention_mask'].squeeze(),
            'chatgpt_input_ids': chatgpt_inputs['input_ids'].squeeze(),
            'chatgpt_attention_mask': chatgpt_inputs['attention_mask'].squeeze(),
            'human_answer': human_answer,
            'chatgpt_answer': chatgpt_answer
        }

# 获取数据加载器
def get_data_loader(file_path, tokenizer, batch_size=8, max_length=128):
    dataset = XihuaDataset(file_path, tokenizer, max_length)
    return DataLoader(dataset, batch_size=batch_size, shuffle=True)

# 模型定义
class XihuaModel(torch.nn.Module):
    def __init__(self, pretrained_model_name='F:/models/bert-base-chinese'):
        super(XihuaModel, self).__init__()
        self.bert = BertModel.from_pretrained(pretrained_model_name)
        self.classifier = torch.nn.Linear(self.bert.config.hidden_size, 1)

    def forward(self, input_ids, attention_mask):
        outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        pooled_output = outputs.pooler_output
        logits = self.classifier(pooled_output)
        return logits

# 训练函数
def train(model, data_loader, optimizer, criterion, device, progress_var=None):
    model.train()
    total_loss = 0.0
    num_batches = len(data_loader)
    for batch_idx, batch in enumerate(data_loader):
        try:
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            human_input_ids = batch['human_input_ids'].to(device)
            human_attention_mask = batch['human_attention_mask'].to(device)
            chatgpt_input_ids = batch['chatgpt_input_ids'].to(device)
            chatgpt_attention_mask = batch['chatgpt_attention_mask'].to(device)

            optimizer.zero_grad()
            human_logits = model(human_input_ids, human_attention_mask)
            chatgpt_logits = model(chatgpt_input_ids, chatgpt_attention_mask)

            human_labels = torch.ones(human_logits.size(0), 1).to(device)
            chatgpt_labels = torch.zeros(chatgpt_logits.size(0), 1).to(device)

            loss = criterion(human_logits, human_labels) + criterion(chatgpt_logits, chatgpt_labels)
            loss.backward()
            optimizer.step()

            total_loss += loss.item()
            if progress_var:
                progress_var.set((batch_idx + 1) / num_batches * 100)
        except Exception as e:
            logging.warning(f"跳过无效批次: {e}")

    return total_loss / len(data_loader)

# 主训练函数
def main_train(retrain=False):
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    logging.info(f'使用设备: {device}')

    tokenizer = BertTokenizer.from_pretrained('F:/models/bert-base-chinese')
    model = XihuaModel(pretrained_model_name='F:/models/bert-base-chinese').to(device)

    if retrain:
        model_path = os.path.join(PROJECT_ROOT, 'models/xihua_model.pth')
        if os.path.exists(model_path):
            model.load_state_dict(torch.load(model_path, map_location=device))
            logging.info("加载现有模型")
        else:
            logging.info("没有找到现有模型,将使用预训练模型")

    optimizer = optim.Adam(model.parameters(), lr=1e-5)
    criterion = torch.nn.BCEWithLogitsLoss()

    train_data_loader = get_data_loader(os.path.join(PROJECT_ROOT, 'data/train_data.jsonl'), tokenizer, batch_size=8, max_length=128)

    num_epochs = 30
    for epoch in range(num_epochs):
        train_loss = train(model, train_data_loader, optimizer, criterion, device)
        logging.info(f'Epoch [{epoch+1}/{num_epochs}], Loss: {train_loss:.8f}')

    torch.save(model.state_dict(), os.path.join(PROJECT_ROOT, 'models/xihua_model.pth'))
    logging.info("模型训练完成并保存")

# 网络搜索函数
def search_baidu(query):
    url = f"https://www.baidu.com/s?wd={query}"
    headers = {
        'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/58.0.3029.110 Safari/537.3'
    }
    response = requests.get(url, headers=headers)
    soup = BeautifulSoup(response.text, 'html.parser')
    results = soup.find_all('div', class_='c-abstract')
    if results:
        return results[0].get_text().strip()
    return "没有找到相关信息"

# GUI界面
class XihuaChatbotGUI:
    def __init__(self, root):
        self.root = root
        self.root.title("羲和聊天机器人")

        self.tokenizer = BertTokenizer.from_pretrained('F:/models/bert-base-chinese')
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.model = XihuaModel(pretrained_model_name='F:/models/bert-base-chinese').to(self.device)
        self.load_model()
        self.model.eval()

        # 加载训练数据集以便在获取答案时使用
        self.data = self.load_data(os.path.join(PROJECT_ROOT, 'data/train_data.jsonl'))

        # 历史记录
        self.history = []

        self.create_widgets()

    def create_widgets(self):
        # 设置样式
        style = ttk.Style()
        style.theme_use('clam')

        # 顶部框架
        top_frame = ttk.Frame(self.root)
        top_frame.pack(pady=10)

        self.question_label = ttk.Label(top_frame, text="问题:", font=("Arial", 12))
        self.question_label.grid(row=0, column=0, padx=10)

        self.question_entry = ttk.Entry(top_frame, width=50, font=("Arial", 12))
        self.question_entry.grid(row=0, column=1, padx=10)

        self.answer_button = ttk.Button(top_frame, text="获取回答", command=self.get_answer, style='TButton')
        self.answer_button.grid(row=0, column=2, padx=10)

        # 中部框架
        middle_frame = ttk.Frame(self.root)
        middle_frame.pack(pady=10)

        self.question_text = tk.Text(middle_frame, height=10, width=35, font=("Arial", 12), wrap='word')
        self.question_text.grid(row=0, column=0, padx=10, pady=10)
        self.question_text.tag_configure("right", justify='right')

        self.answer_text = tk.Text(middle_frame, height=10, width=35, font=("Arial", 12), wrap='word')
        self.answer_text.grid(row=0, column=1, padx=10, pady=10)
        self.answer_text.tag_configure("left", justify='left')

        # 底部框架
        bottom_frame = ttk.Frame(self.root)
        bottom_frame.pack(pady=10)

        self.correct_button = ttk.Button(bottom_frame, text="准确", command=self.mark_correct, style='TButton')
        self.correct_button.grid(row=0, column=0, padx=10)

        self.incorrect_button = ttk.Button(bottom_frame, text="不准确", command=self.mark_incorrect, style='TButton')
        self.incorrect_button.grid(row=0, column=1, padx=10)

        self.train_button = ttk.Button(bottom_frame, text="训练模型", command=self.train_model, style='TButton')
        self.train_button.grid(row=0, column=2, padx=10)

        self.retrain_button = ttk.Button(bottom_frame, text="重新训练模型", command=lambda: self.train_model(retrain=True), style='TButton')
        self.retrain_button.grid(row=0, column=3, padx=10)

        self.progress_var = tk.DoubleVar()
        self.progress_bar = ttk.Progressbar(bottom_frame, variable=self.progress_var, maximum=100, length=200, mode='determinate')
        self.progress_bar.grid(row=1, column=0, columnspan=4, pady=10)

        self.log_text = tk.Text(bottom_frame, height=10, width=70, font=("Arial", 12))
        self.log_text.grid(row=2, column=0, columnspan=4, pady=10)

        self.evaluate_button = ttk.Button(bottom_frame, text="评估模型", command=self.evaluate_model, style='TButton')
        self.evaluate_button.grid(row=3, column=0, padx=10, pady=10)

        self.history_button = ttk.Button(bottom_frame, text="查看历史记录", command=self.view_history, style='TButton')
        self.history_button.grid(row=3, column=1, padx=10, pady=10)

        self.save_history_button = ttk.Button(bottom_frame, text="保存历史记录", command=self.save_history, style='TButton')
        self.save_history_button.grid(row=3, column=2, padx=10, pady=10)

    def get_answer(self):
        question = self.question_entry.get()
        if not question:
            messagebox.showwarning("输入错误", "请输入问题")
            return

        inputs = self.tokenizer(question, return_tensors='pt', padding='max_length', truncation=True, max_length=128)
        with torch.no_grad():
            input_ids = inputs['input_ids'].to(self.device)
            attention_mask = inputs['attention_mask'].to(self.device)
            logits = self.model(input_ids, attention_mask)
        
        if logits.item() > 0:
            answer_type = "羲和回答"
        else:
            answer_type = "零回答"

        specific_answer = self.get_specific_answer(question, answer_type)

        self.question_text.insert(tk.END, f"用户: {question}\n", "right")
        self.answer_text.insert(tk.END, f"羲和: {specific_answer}\n", "left")

        # 添加到历史记录
        self.history.append({
            'question': question,
            'answer_type': answer_type,
            'specific_answer': specific_answer,
            'accuracy': None  # 初始状态为未评价
        })

    def get_specific_answer(self, question, answer_type):
        # 使用模糊匹配查找最相似的问题
        best_match = None
        best_ratio = 0.0
        for item in self.data:
            ratio = SequenceMatcher(None, question, item['question']).ratio()
            if ratio > best_ratio:
                best_ratio = ratio
                best_match = item

        if best_match:
            if answer_type == "羲和回答":
                return best_match['human_answers'][0]
            else:
                return best_match['chatgpt_answers'][0]
        return "这个我也不清楚,你问问零吧"

    def load_data(self, file_path):
        data = []
        if file_path.endswith('.jsonl'):
            with jsonlines.open(file_path) as reader:
                for i, item in enumerate(reader):
                    try:
                        data.append(item)
                    except jsonlines.jsonlines.InvalidLineError as e:
                        logging.warning(f"跳过无效行 {i + 1}: {e}")
        elif file_path.endswith('.json'):
            with open(file_path, 'r') as f:
                try:
                    data = json.load(f)
                except json.JSONDecodeError as e:
                    logging.warning(f"跳过无效文件 {file_path}: {e}")
        return data

    def load_model(self):
        model_path = os.path.join(PROJECT_ROOT, 'models/xihua_model.pth')
        if os.path.exists(model_path):
            self.model.load_state_dict(torch.load(model_path, map_location=self.device))
            logging.info("加载现有模型")
        else:
            logging.info("没有找到现有模型,将使用预训练模型")

    def train_model(self, retrain=False):
        file_path = filedialog.askopenfilename(filetypes=[("JSONL files", "*.jsonl"), ("JSON files", "*.json")])
        if not file_path:
            messagebox.showwarning("文件选择错误", "请选择一个有效的数据文件")
            return

        try:
            dataset = XihuaDataset(file_path, self.tokenizer)
            data_loader = DataLoader(dataset, batch_size=8, shuffle=True)
            
            # 加载已训练的模型权重
            if retrain:
                self.model.load_state_dict(torch.load(os.path.join(PROJECT_ROOT, 'models/xihua_model.pth'), map_location=self.device))
                self.model.to(self.device)
                self.model.train()

            optimizer = torch.optim.Adam(self.model.parameters(), lr=1e-5)
            criterion = torch.nn.BCEWithLogitsLoss()
            num_epochs = 30
            for epoch in range(num_epochs):
                train_loss = train(self.model, data_loader, optimizer, criterion, self.device, self.progress_var)
                logging.info(f'Epoch [{epoch+1}/{num_epochs}], Loss: {train_loss:.4f}')
                self.log_text.insert(tk.END, f'Epoch [{epoch+1}/{num_epochs}], Loss: {train_loss:.4f}\n')
                self.log_text.see(tk.END)
            torch.save(self.model.state_dict(), os.path.join(PROJECT_ROOT, 'models/xihua_model.pth'))
            logging.info("模型训练完成并保存")
            self.log_text.insert(tk.END, "模型训练完成并保存\n")
            self.log_text.see(tk.END)
            messagebox.showinfo("训练完成", "模型训练完成并保存")
        except Exception as e:
            logging.error(f"模型训练失败: {e}")
            self.log_text.insert(tk.END, f"模型训练失败: {e}\n")
            self.log_text.see(tk.END)
            messagebox.showerror("训练失败", f"模型训练失败: {e}")

    def evaluate_model(self):
        # 这里可以添加模型评估的逻辑
        messagebox.showinfo("评估结果", "模型评估功能暂未实现")

    def mark_correct(self):
        if self.history:
            self.history[-1]['accuracy'] = True
            messagebox.showinfo("评价成功", "您认为这次回答是准确的")

    def mark_incorrect(self):
        if self.history:
            self.history[-1]['accuracy'] = False
            question = self.history[-1]['question']
            answer = search_baidu(question)
            self.answer_text.insert(tk.END, f"搜索引擎结果: {answer}\n", "left")
            messagebox.showinfo("评价成功", "您认为这次回答是不准确的")

    def view_history(self):
        history_window = tk.Toplevel(self.root)
        history_window.title("历史记录")

        history_text = tk.Text(history_window, height=20, width=80, font=("Arial", 12))
        history_text.pack(padx=10, pady=10)

        for entry in self.history:
            history_text.insert(tk.END, f"问题: {entry['question']}\n")
            history_text.insert(tk.END, f"回答类型: {entry['answer_type']}\n")
            history_text.insert(tk.END, f"具体回答: {entry['specific_answer']}\n")
            if entry['accuracy'] is None:
                history_text.insert(tk.END, "评价: 未评价\n")
            elif entry['accuracy']:
                history_text.insert(tk.END, "评价: 准确\n")
            else:
                history_text.insert(tk.END, "评价: 不准确\n")
            history_text.insert(tk.END, "-" * 50 + "\n")

    def save_history(self):
        file_path = filedialog.asksaveasfilename(defaultextension=".json", filetypes=[("JSON files", "*.json")])
        if not file_path:
            return

        with open(file_path, 'w') as f:
            json.dump(self.history, f, ensure_ascii=False, indent=4)

        messagebox.showinfo("保存成功", "历史记录已保存到文件")

# 主函数
if __name__ == "__main__":
    # 启动GUI
    root = tk.Tk()
    app = XihuaChatbotGUI(root)
    root.mainloop()

说明

优化布局:将问题和回答分别显示在两个 Text 组件中,左边显示用户的输入,右边显示羲和的回答。

增加提示信息:在按钮和输入框上增加了提示信息,帮助用户更好地理解如何使用。

自动搜索功能:在 mark_incorrect 方法中,如果用户标记回答为不准确,将调用 search_baidu 函数获取更详细的信息并显示在回答框中。

固定显示为中文:确保所有界面元素的文本都显示为中文。

这样,您的聊天机器人的GUI将更加美观和用户友好。希望这些改进能满足您的需求!如果有任何其他问题或需要进一步的定制,请告诉我。

相关推荐
CopyLower23 分钟前
Java与AI技术结合:从机器学习到生成式AI的实践
java·人工智能·机器学习
workflower32 分钟前
使用谱聚类将相似度矩阵分为2类
人工智能·深度学习·算法·机器学习·设计模式·软件工程·软件需求
jndingxin36 分钟前
OpenCV CUDA 模块中在 GPU 上对图像或矩阵进行 翻转(镜像)操作的一个函数 flip()
人工智能·opencv
囚生CY1 小时前
【速写】TRL:Trainer的细节与思考(PPO/DPO+LoRA可行性)
人工智能
杨德兴1 小时前
3.3 阶数的作用
人工智能·学习
望获linux1 小时前
医疗实时操作系统方案:手术机器人的微秒级运动控制
人工智能·机器人·实时操作系统·rtos·嵌入式软件·医疗自动化
仓颉编程语言1 小时前
仓颉Magic亮相GOSIM AI Paris 2025:掀起开源AI框架新热潮
人工智能·华为·开源·鸿蒙·仓颉编程语言
攻城狮7号1 小时前
一文理清人工智能,机器学习,深度学习的概念
人工智能·深度学习·机器学习·ai
智慧地球(AI·Earth)2 小时前
当 Manus AI 遇上 OpenAI Operator,谁能更胜一筹?
人工智能
小森77672 小时前
(七)深度学习---神经网络原理与实现
人工智能·深度学习·神经网络·算法