BERT的中文问答系统22

明白你的问题了。日志目录已经存在、日志也能正常保存,却仍然报错,原因很可能是路径字符串的格式有问题,或者路径在某些调用中没有被正确传递。我们可以进一步优化路径处理逻辑,确保各处使用的路径始终正确。

以下是修订后的代码,重点在于确保路径处理的正确性和一致性:

python

python 复制代码
import os
import json
import jsonlines
import torch
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, DistributedSampler
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.tensorboard import SummaryWriter
from transformers import BertModel, BertTokenizer
import tkinter as tk
from tkinter import filedialog, messagebox, scrolledtext, ttk
import logging
from difflib import SequenceMatcher
from datetime import datetime
from torch.cuda.amp import GradScaler, autocast
import torch.multiprocessing as mp
import psutil
import torch.distributed as dist

# Absolute directory that contains this script; the data, model and log
# paths used below are all built from it.
PROJECT_ROOT = os.path.split(os.path.abspath(__file__))[0]

# Root folder for log output, created at import time when missing.
LOGS_DIR = os.path.join(PROJECT_ROOT, 'logs')
os.makedirs(LOGS_DIR, exist_ok=True)

def setup_logging():
    """Configure root logging to a fresh timestamped file plus the console.

    The log file is ``LOGS_DIR/<YYYY-MM-DD>/<HH-MM-SS>/羲和.txt``; missing
    parent directories are created first.
    """
    started = datetime.now()
    # Only the ASCII date/time parts go through strftime(). Passing the
    # non-ASCII file name through it can raise or produce mangled text on
    # platforms whose locale encoding cannot represent those characters.
    # Joining the parts separately also yields native path separators.
    log_file = os.path.join(
        LOGS_DIR,
        started.strftime('%Y-%m-%d'),
        started.strftime('%H-%M-%S'),
        '羲和.txt',
    )
    os.makedirs(os.path.dirname(log_file), exist_ok=True)
    logging.basicConfig(
        level=logging.INFO,
        format='%(asctime)s - %(levelname)s - %(message)s',
        handlers=[
            # Explicit UTF-8: the log messages are Chinese and must not
            # depend on the platform default encoding.
            logging.FileHandler(log_file, encoding='utf-8'),
            logging.StreamHandler()
        ]
    )


setup_logging()

# Dataset class
class XihuaDataset(Dataset):
    """Question/answer dataset backed by a ``.jsonl`` or ``.json`` file.

    ``__getitem__`` reads the ``question`` key and the first element of
    ``human_answers`` and ``chatgpt_answers`` from each record.
    """

    def __init__(self, file_path, tokenizer, max_length=128):
        """Load the records from ``file_path``.

        Args:
            file_path: Path to a ``.jsonl`` or ``.json`` data file.
            tokenizer: Callable used to encode text; it is invoked with
                ``return_tensors``, ``padding``, ``truncation`` and
                ``max_length`` keyword arguments.
            max_length: Sequence length passed to the tokenizer.
        """
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.data = self.load_data(file_path)

    def load_data(self, file_path):
        """Read all records from ``file_path``.

        For ``.jsonl`` files every non-blank line is parsed on its own, so
        one malformed line is logged and skipped instead of aborting the
        whole load. A malformed ``.json`` file is logged and yields no
        records. Files are decoded as UTF-8 (a leading BOM is tolerated).

        Returns:
            The parsed records; an empty list when nothing could be read
            or the extension is not supported.
        """
        data = []
        if file_path.endswith('.jsonl'):
            with open(file_path, 'r', encoding='utf-8-sig') as f:
                for line_no, line in enumerate(f, start=1):
                    line = line.strip()
                    if not line:
                        continue
                    try:
                        data.append(json.loads(line))
                    except json.JSONDecodeError as e:
                        logging.warning(f"跳过无效行 {line_no}: {e}")
        elif file_path.endswith('.json'):
            with open(file_path, 'r', encoding='utf-8-sig') as f:
                try:
                    data = json.load(f)
                except json.JSONDecodeError as e:
                    logging.warning(f"跳过无效文件 {file_path}: {e}")
        else:
            logging.warning(f"不支持的数据文件类型: {file_path}")
        return data

    def __len__(self):
        """Return the number of loaded records."""
        return len(self.data)

    def _encode(self, text):
        """Tokenize ``text`` to fixed-length tensors (with a batch dim of 1)."""
        return self.tokenizer(
            text,
            return_tensors='pt',
            padding='max_length',
            truncation=True,
            max_length=self.max_length,
        )

    def __getitem__(self, idx):
        """Return the encoded sample at ``idx``.

        If that record is malformed or cannot be tokenized it is logged and
        the next record is tried instead, wrapping around at the end. Every
        record is tried at most once, so a dataset with no usable record
        raises instead of recursing forever.

        Raises:
            IndexError: ``idx`` is outside the dataset.
            RuntimeError: no record in the dataset could be encoded.
        """
        total = len(self.data)
        # Keep normal sequence semantics: out-of-range indexes must raise
        # IndexError rather than wrap around.
        if not -total <= idx < total:
            raise IndexError(f"index {idx} out of range for dataset of size {total}")
        idx %= total

        for offset in range(total):
            current = (idx + offset) % total
            try:
                item = self.data[current]
                question = item['question']
                human_answer = item['human_answers'][0]
                chatgpt_answer = item['chatgpt_answers'][0]

                inputs = self._encode(question)
                human_inputs = self._encode(human_answer)
                chatgpt_inputs = self._encode(chatgpt_answer)
            except Exception as e:
                logging.warning(f"跳过无效项 {current}: {e}")
                continue

            # squeeze(0) removes only the tokenizer's batch dimension, even
            # when max_length happens to be 1.
            return {
                'input_ids': inputs['input_ids'].squeeze(0),
                'attention_mask': inputs['attention_mask'].squeeze(0),
                'human_input_ids': human_inputs['input_ids'].squeeze(0),
                'human_attention_mask': human_inputs['attention_mask'].squeeze(0),
                'chatgpt_input_ids': chatgpt_inputs['input_ids'].squeeze(0),
                'chatgpt_attention_mask': chatgpt_inputs['attention_mask'].squeeze(0),
                'human_answer': human_answer,
                'chatgpt_answer': chatgpt_answer
            }

        raise RuntimeError("数据集中没有可用的有效样本")

# Build a data loader
def get_data_loader(file_path, tokenizer, batch_size=8, max_length=128, distributed=False, num_workers=4):
    """Create a DataLoader over an XihuaDataset built from ``file_path``.

    With ``distributed`` set, samples are drawn through a
    DistributedSampler; otherwise the loader shuffles the dataset.
    """
    loader_kwargs = {'batch_size': batch_size, 'num_workers': num_workers}
    dataset = XihuaDataset(file_path, tokenizer, max_length)
    if not distributed:
        return DataLoader(dataset, shuffle=True, **loader_kwargs)
    return DataLoader(dataset, sampler=DistributedSampler(dataset), **loader_kwargs)

# Model definition
class XihuaModel(torch.nn.Module):
    """BERT encoder followed by a linear head that emits one logit."""

    def __init__(self, pretrained_model_name='F:/models/bert-base-chinese'):
        super().__init__()
        encoder = BertModel.from_pretrained(pretrained_model_name)
        self.bert = encoder
        # The head maps the encoder's hidden size to a single output value.
        self.classifier = torch.nn.Linear(encoder.config.hidden_size, 1)

    def forward(self, input_ids, attention_mask):
        """Return the classifier output for BERT's pooled representation."""
        encoded = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        return self.classifier(encoded.pooler_output)

# Training function
def train(model, data_loader, optimizer, criterion, device, scaler=None, gradient_accumulation_steps=1):
    """Run one training epoch and return the mean per-batch loss.

    For every batch the model scores the human answers (target 1) and the
    ChatGPT answers (target 0); the two criterion values are summed.

    Args:
        model: Callable as ``model(input_ids, attention_mask)``.
        data_loader: Iterable of dict batches containing the
            ``human_*`` and ``chatgpt_*`` id/mask tensors.
        optimizer: Optimizer over the model parameters.
        criterion: Loss taking ``(logits, labels)``.
        device: Device the batch tensors are moved to.
        scaler: Optional gradient scaler. When ``None`` a plain
            ``backward()`` / ``optimizer.step()`` is used.
        gradient_accumulation_steps: Number of batches whose gradients are
            accumulated before each optimizer step (values below 1 are
            treated as 1).

    Returns:
        The mean un-scaled loss over the batches that trained
        successfully, or ``float('inf')`` when no batch did (so a caller
        comparing against a best loss never treats that epoch as an
        improvement).
    """
    from contextlib import nullcontext

    accum_steps = max(1, int(gradient_accumulation_steps))
    # Mixed precision is only meaningful on CUDA; elsewhere run in full
    # precision instead of entering a CUDA autocast context.
    use_amp = torch.device(device).type == 'cuda'

    def apply_optimizer_step():
        if scaler is not None:
            scaler.step(optimizer)
            scaler.update()
        else:
            optimizer.step()
        optimizer.zero_grad()

    model.train()
    total_loss = 0.0
    trained_batches = 0
    pending = 0  # batches accumulated since the last optimizer step
    optimizer.zero_grad()
    for batch in data_loader:
        try:
            human_input_ids = batch['human_input_ids'].to(device)
            human_attention_mask = batch['human_attention_mask'].to(device)
            chatgpt_input_ids = batch['chatgpt_input_ids'].to(device)
            chatgpt_attention_mask = batch['chatgpt_attention_mask'].to(device)

            amp_context = torch.autocast(device_type='cuda') if use_amp else nullcontext()
            with amp_context:
                human_logits = model(human_input_ids, human_attention_mask)
                chatgpt_logits = model(chatgpt_input_ids, chatgpt_attention_mask)

                human_labels = torch.ones(human_logits.size(0), 1, device=device)
                chatgpt_labels = torch.zeros(chatgpt_logits.size(0), 1, device=device)

                loss = criterion(human_logits, human_labels) + criterion(chatgpt_logits, chatgpt_labels)

            # Only the value used for backward is divided; the reported
            # loss stays un-scaled so it is comparable across settings.
            scaled_loss = loss / accum_steps
            if scaler is not None:
                scaler.scale(scaled_loss).backward()
            else:
                scaled_loss.backward()

            pending += 1
            if pending == accum_steps:
                apply_optimizer_step()
                pending = 0

            total_loss += loss.item()
            trained_batches += 1
        except Exception as e:
            # Best effort: a bad batch is logged (with traceback) and skipped.
            logging.warning(f"跳过无效批次: {e}", exc_info=True)

    # Apply gradients left over when the number of batches is not a
    # multiple of the accumulation factor.
    if pending:
        apply_optimizer_step()

    if trained_batches == 0:
        logging.warning("本轮没有成功训练任何批次")
        return float('inf')
    return total_loss / trained_batches

# Main training entry point
def main_train(rank, world_size, retrain=False, multi_gpu=False):
    """Train the model on ``data/train_data.jsonl`` and save the best epoch.

    Args:
        rank: Index of this process; also used as the CUDA device index.
        world_size: Total number of training processes.
        retrain: When true, start from the previously saved checkpoint.
        multi_gpu: When true, join a process group and wrap the model in
            DistributedDataParallel.

    The checkpoint is always written from the un-wrapped model so its keys
    carry no ``module.`` prefix and it can be loaded into a plain
    ``XihuaModel``.
    """
    use_cuda = torch.cuda.is_available()
    writer = None
    try:
        if multi_gpu:
            # The env:// rendezvous reads these variables; keep whatever a
            # launcher already exported.
            os.environ.setdefault('MASTER_ADDR', '127.0.0.1')
            os.environ.setdefault('MASTER_PORT', '29500')
            # Fall back to gloo where this PyTorch build has no NCCL.
            backend = 'nccl' if dist.is_nccl_available() else 'gloo'
            dist.init_process_group(backend=backend, init_method='env://', world_size=world_size, rank=rank)
            if use_cuda:
                torch.cuda.set_device(rank)

        device = torch.device(f'cuda:{rank}' if use_cuda else 'cpu')
        logging.info(f'Using device: {device}')

        tokenizer = BertTokenizer.from_pretrained('F:/models/bert-base-chinese')
        base_model = XihuaModel(pretrained_model_name='F:/models/bert-base-chinese').to(device)

        best_model_path = os.path.join(PROJECT_ROOT, 'models/xihua_model.pth')
        os.makedirs(os.path.dirname(best_model_path), exist_ok=True)  # make sure the model directory exists

        # Everything that needs the raw module (loading weights, enabling
        # checkpointing, probing the batch size) happens before DDP wrapping:
        # the wrapper prefixes state-dict keys with ``module.`` and does not
        # expose ``.bert``.
        if retrain:
            base_model.load_state_dict(torch.load(best_model_path, map_location=device, weights_only=True))

        base_model.bert.gradient_checkpointing_enable()

        max_memory = torch.cuda.get_device_properties(device).total_memory * 0.9 if use_cuda else float('inf')
        batch_size = max(1, get_max_batch_size(base_model, device, max_memory))
        logging.info(f'Using batch size: {batch_size}')

        if multi_gpu:
            model = DDP(base_model, device_ids=[rank] if use_cuda else None)
        else:
            model = base_model

        optimizer = optim.Adam(model.parameters(), lr=1e-5)
        criterion = torch.nn.BCEWithLogitsLoss()
        scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=2)
        scaler = torch.amp.GradScaler('cuda')

        train_data_loader = get_data_loader(os.path.join(PROJECT_ROOT, 'data/train_data.jsonl'), tokenizer, batch_size=batch_size, max_length=128, distributed=multi_gpu, num_workers=4)

        num_epochs = 3
        gradient_accumulation_steps = 2  # batches accumulated per optimizer step
        best_loss = float('inf')

        # Only rank 0 writes TensorBoard events, so processes do not write
        # competing event files into the same directory.
        if rank == 0:
            writer = SummaryWriter(log_dir=os.path.join(PROJECT_ROOT, 'logs/tensorboard'))

        sampler = train_data_loader.sampler
        for epoch in range(num_epochs):
            # A DistributedSampler reshuffles only when told the epoch.
            if isinstance(sampler, DistributedSampler):
                sampler.set_epoch(epoch)

            train_loss = train(model, train_data_loader, optimizer, criterion, device, scaler, gradient_accumulation_steps)
            logging.info(f'Epoch [{epoch+1}/{num_epochs}], Loss: {train_loss:.8f}')
            if writer is not None:
                writer.add_scalar('Training Loss', train_loss, epoch)
            scheduler.step(train_loss)

            if rank == 0:
                if train_loss < best_loss:
                    best_loss = train_loss
                    torch.save(base_model.state_dict(), best_model_path)
                    logging.info(f"模型在 Epoch {epoch+1} 更新,Loss: {train_loss:.8f}")

        if rank == 0:
            logging.info("模型训练完成并保存")
    finally:
        # Release resources even when training fails part-way.
        if writer is not None:
            writer.close()
        if multi_gpu and dist.is_initialized():
            dist.destroy_process_group()

# Probe for a usable batch size
def get_max_batch_size(model, device, max_memory=1024 * 1024 * 1024, seq_length=128, max_batch_size=256):
    """Return the largest power-of-two batch size the model can run forward.

    Batch sizes 1, 2, 4, ... are tried with random inputs under
    ``torch.no_grad()`` until a forward pass raises ``RuntimeError``, the
    CUDA peak allocation exceeds ``max_memory``, or ``max_batch_size`` is
    passed.

    Args:
        model: Callable as ``model(input_ids, attention_mask)``.
        device: Device on which the probe tensors are created.
        max_memory: Allowed peak allocation in bytes (default 1 GiB). It is
            only checked on CUDA devices.
        seq_length: Sequence length of the probe inputs.
        max_batch_size: Upper bound for the probe, so the search also
            terminates on devices that never raise (e.g. CPU).

    Returns:
        The largest batch size that succeeded, never less than 1.

    Note:
        The probe runs without gradients, so the memory needed for an
        actual training step is higher than what is measured here.
    """
    device = torch.device(device)
    on_cuda = device.type == 'cuda'

    best = 0
    candidate = 1
    while candidate <= max_batch_size:
        try:
            if on_cuda:
                torch.cuda.reset_peak_memory_stats(device)
            input_ids = torch.randint(0, 100, (candidate, seq_length), device=device)
            attention_mask = torch.ones(candidate, seq_length, device=device)
            with torch.no_grad():
                model(input_ids, attention_mask)
            if on_cuda and torch.cuda.max_memory_allocated(device) > max_memory:
                break
        except RuntimeError:
            # Out-of-memory (or another runtime failure): stop growing.
            break
        best = candidate
        candidate *= 2

    if on_cuda:
        # Hand cached probe allocations back before real training starts.
        torch.cuda.empty_cache()
    return max(1, best)

# Launch training (multi-GPU when possible)
def launch_training(retrain=False, multi_gpu=False):
    """Start training, spawning one process per GPU when that is possible.

    Args:
        retrain: Forwarded to ``main_train``.
        multi_gpu: Request distributed training. It is honoured only when
            more than one CUDA device is present.
    """
    gpu_count = torch.cuda.device_count()
    if multi_gpu and gpu_count > 1:
        mp.spawn(main_train, args=(gpu_count, retrain, True), nprocs=gpu_count, join=True)
        return

    if multi_gpu:
        logging.warning("未检测到多块GPU,改为单进程训练")
    # The fallback must run as a plain single process: passing
    # multi_gpu=True here would still try to set up a process group.
    main_train(0, 1, retrain, False)

# GUI
class XihuaChatbotGUI:
    """Tkinter front end for asking questions and (re)training the model."""

    def __init__(self, root):
        """Load tokenizer, model and reference data, then build the widgets."""
        self.root = root
        self.root.title("羲和聊天机器人")

        self.tokenizer = BertTokenizer.from_pretrained('F:/models/bert-base-chinese')
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.model = XihuaModel(pretrained_model_name='F:/models/bert-base-chinese').to(self.device)
        self.load_model()
        self.model.eval()

        # Reference records searched when looking up the answer text.
        self.data = self.load_data(os.path.join(PROJECT_ROOT, 'data/train_data.jsonl'))

        self.create_widgets()

    def create_widgets(self):
        """Create and pack all widgets in top-to-bottom order."""
        self.question_label = tk.Label(self.root, text="问题:")
        self.question_label.pack()

        self.question_entry = tk.Entry(self.root, width=50)
        self.question_entry.pack()

        self.answer_button = tk.Button(self.root, text="获取回答", command=self.get_answer)
        self.answer_button.pack()

        self.answer_label = tk.Label(self.root, text="回答:")
        self.answer_label.pack()

        self.answer_text = scrolledtext.ScrolledText(self.root, height=10, width=50)
        self.answer_text.pack()

        self.train_button = tk.Button(self.root, text="训练模型", command=self.train_model)
        self.train_button.pack()

        self.retrain_button = tk.Button(self.root, text="重新训练模型", command=lambda: self.train_model(retrain=True))
        self.retrain_button.pack()

        self.multi_gpu_var = tk.BooleanVar()
        self.multi_gpu_checkbox = tk.Checkbutton(self.root, text="使用多GPU", variable=self.multi_gpu_var)
        self.multi_gpu_checkbox.pack()

        self.log_text = scrolledtext.ScrolledText(self.root, height=10, width=50)
        self.log_text.pack()

        self.progress_bar = ttk.Progressbar(self.root, orient='horizontal', length=300, mode='determinate')
        self.progress_bar.pack()

    def get_answer(self):
        """Score the entered question and show the matching stored answer."""
        question = self.question_entry.get().strip()
        if not question:
            messagebox.showwarning("输入错误", "请输入问题")
            return

        inputs = self.tokenizer(question, return_tensors='pt', padding='max_length', truncation=True, max_length=128)
        with torch.no_grad():
            input_ids = inputs['input_ids'].to(self.device)
            attention_mask = inputs['attention_mask'].to(self.device)
            logits = self.model(input_ids, attention_mask)

        # A positive logit selects the human answer, otherwise the ChatGPT one.
        if logits.item() > 0:
            answer_type = "羲和回答"
        else:
            answer_type = "零回答"

        specific_answer = self.get_specific_answer(question, answer_type)

        self.answer_text.delete(1.0, tk.END)
        self.answer_text.insert(tk.END, f"{answer_type}\n{specific_answer}")

    def get_specific_answer(self, question, answer_type):
        """Return the stored answer of the record most similar to ``question``.

        Similarity is ``SequenceMatcher(...).ratio()`` against each record's
        ``question``. Records without a usable question or answer are
        ignored instead of raising.

        Returns:
            The first human answer when ``answer_type`` is ``"羲和回答"``,
            otherwise the first ChatGPT answer; a fixed fallback sentence
            when nothing matches.
        """
        # Fuzzy-match the question against every stored question.
        best_match = None
        best_ratio = 0.0
        for item in self.data:
            candidate = item.get('question') if isinstance(item, dict) else None
            if not candidate:
                continue
            ratio = SequenceMatcher(None, question, candidate).ratio()
            if ratio > best_ratio:
                best_ratio = ratio
                best_match = item

        if best_match:
            key = 'human_answers' if answer_type == "羲和回答" else 'chatgpt_answers'
            answers = best_match.get(key) or []
            if answers:
                return answers[0]
        return "这个我也不清楚,你问问零吧"

    def load_data(self, file_path):
        """Read records from a ``.jsonl`` or ``.json`` file.

        A missing file, a malformed ``.jsonl`` line or a malformed ``.json``
        file is logged and skipped, so a data problem cannot prevent the
        window from opening. Files are decoded as UTF-8 (a leading BOM is
        tolerated).

        Returns:
            The parsed records, or an empty list when nothing was read.
        """
        data = []
        if not os.path.exists(file_path):
            logging.warning(f"数据文件不存在: {file_path}")
            return data

        if file_path.endswith('.jsonl'):
            with open(file_path, 'r', encoding='utf-8-sig') as f:
                for line_no, line in enumerate(f, start=1):
                    line = line.strip()
                    if not line:
                        continue
                    try:
                        data.append(json.loads(line))
                    except json.JSONDecodeError as e:
                        logging.warning(f"跳过无效行 {line_no}: {e}")
        elif file_path.endswith('.json'):
            with open(file_path, 'r', encoding='utf-8-sig') as f:
                try:
                    data = json.load(f)
                except json.JSONDecodeError as e:
                    logging.warning(f"跳过无效文件 {file_path}: {e}")
        return data

    def load_model(self):
        """Load the saved checkpoint into ``self.model`` when one exists."""
        model_path = os.path.join(PROJECT_ROOT, 'models/xihua_model.pth')
        os.makedirs(os.path.dirname(model_path), exist_ok=True)  # make sure the model directory exists
        if os.path.exists(model_path):
            self.model.load_state_dict(torch.load(model_path, map_location=self.device, weights_only=True))
            logging.info("加载现有模型")
        else:
            logging.info("没有找到现有模型,将使用预训练模型")

    def train_model(self, retrain=False):
        """Train on a user-selected data file and save the best epoch.

        Args:
            retrain: When true, reload the saved checkpoint before training.

        Training runs in this process on the GUI thread, so the window does
        not respond until it finishes. The model is always returned to
        eval mode afterwards so later answers are not affected by dropout.
        """
        file_path = filedialog.askopenfilename(filetypes=[("JSONL files", "*.jsonl"), ("JSON files", "*.json")])
        if not file_path:
            messagebox.showwarning("文件选择错误", "请选择一个有效的数据文件")
            return

        # No process group is initialised in the GUI process, so a
        # DistributedSampler cannot be used here; train single-process.
        if self.multi_gpu_var.get():
            logging.warning("GUI 内训练在单进程中运行,已忽略多GPU选项")

        writer = None
        try:
            best_model_path = os.path.join(PROJECT_ROOT, 'models/xihua_model.pth')
            os.makedirs(os.path.dirname(best_model_path), exist_ok=True)  # make sure the model directory exists

            # Start from the previously trained weights when asked to.
            if retrain:
                self.model.load_state_dict(torch.load(best_model_path, map_location=self.device, weights_only=True))

            optimizer = torch.optim.Adam(self.model.parameters(), lr=1e-5)
            criterion = torch.nn.BCEWithLogitsLoss()
            scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=2)
            scaler = torch.amp.GradScaler('cuda')

            # Enable gradient checkpointing on the encoder.
            self.model.bert.gradient_checkpointing_enable()

            max_memory = torch.cuda.get_device_properties(self.device).total_memory * 0.9 if torch.cuda.is_available() else float('inf')
            batch_size = max(1, get_max_batch_size(self.model, self.device, max_memory))
            logging.info(f'Using batch size: {batch_size}')
            # Build the loader once, after the batch size is known.
            data_loader = get_data_loader(file_path, self.tokenizer, batch_size=batch_size, max_length=128, distributed=False, num_workers=4)

            writer = SummaryWriter(log_dir=os.path.join(PROJECT_ROOT, 'logs/tensorboard'))
            best_loss = float('inf')

            num_epochs = 3
            self.progress_bar['value'] = 0
            self.root.update_idletasks()
            for epoch in range(num_epochs):
                train_loss = train(self.model, data_loader, optimizer, criterion, self.device, scaler, 2)
                # Advance the bar only once the epoch has actually finished.
                self.progress_bar['value'] = (epoch + 1) / num_epochs * 100
                self.root.update_idletasks()
                logging.info(f'Epoch [{epoch+1}/3], Loss: {train_loss:.4f}')
                self.log_text.insert(tk.END, f'Epoch [{epoch+1}/3], Loss: {train_loss:.4f}\n')
                self.log_text.yview(tk.END)
                writer.add_scalar('Training Loss', train_loss, epoch)
                scheduler.step(train_loss)

                if train_loss < best_loss:
                    best_loss = train_loss
                    torch.save(self.model.state_dict(), best_model_path)
                    logging.info(f"模型在 Epoch {epoch+1} 更新,Loss: {train_loss:.4f}")

            logging.info("模型训练完成并保存")
            messagebox.showinfo("训练完成", "模型训练完成并保存")
        except Exception as e:
            logging.exception(f"模型训练失败: {e}")
            messagebox.showerror("训练失败", f"模型训练失败: {e}")
        finally:
            if writer is not None:
                writer.close()
            self.model.eval()

# Script entry point
if __name__ == "__main__":
    # Create the Tk root window, attach the chatbot UI to it, then hand
    # control to the Tk event loop.
    window = tk.Tk()
    chatbot = XihuaChatbotGUI(window)
    window.mainloop()

关键改进点:

确保路径处理的正确性:

在 setup_logging、load_model 和 train_model 中,所有路径都以 PROJECT_ROOT 为基准、通过 os.path.join 拼接,避免因当前工作目录不同或手工拼接分隔符而产生路径格式问题。

日志目录和模型目录的创建:

在 setup_logging 和 load_model 中,确保日志目录和模型目录存在,使用 os.makedirs 创建目录。

日志文件路径的格式:

日志文件路径由 datetime.now().strftime 按“日期/时间”生成两级子目录(形如 logs/2024-01-01/12-00-00/羲和.txt),因此每次启动程序都会写入一个新的日志文件。

这些改进应该能解决你遇到的路径问题,并确保日志和模型文件能够正确保存。如果有其他问题或需要进一步的帮助,请随时告诉我!

相关推荐
退休钓鱼选手14 分钟前
[ Pytorch教程 ] 神经网络的基本骨架 torch.nn -Neural Network
pytorch·深度学习·神经网络
Java面试题总结1 小时前
基于 Java 的 PDF 文本水印实现方案(iText7 示例)
java·python·pdf
不懒不懒1 小时前
【决策树算法实战指南:从原理到Python实现】
python·决策树·id3·c4.5·catr
马猴烧酒.1 小时前
【面试八股|Java集合】Java集合常考面试题详解
java·开发语言·python·面试·八股
DeniuHe1 小时前
用 PyTorch 库创建了一个随机张量,并演示了多种张量取整和分解操作
pytorch
天空属于哈夫克31 小时前
Java 版:利用外部群 API 实现自动“技术开课”倒计时提醒
数据库·python·mysql
喵手1 小时前
Python爬虫实战:全站 Sitemap 自动发现 - 解析 sitemap.xml → 自动生成抓取队列的工业级实现!
爬虫·python·爬虫实战·零基础python爬虫教学·sitemap·解析sitemap.xml·自动生成抓取队列实现
luoluoal1 小时前
基于深度学习的web端多格式纠错系统(源码+文档)
python·mysql·django·毕业设计·源码
深蓝海拓2 小时前
PySide6从0开始学习的笔记(二十七) 日志管理
笔记·python·学习·pyqt
天天进步20152 小时前
Python全栈项目:实时数据处理平台
开发语言·python