用三个模型洞察大模型NLP的基础能力
一、项目概述
在这个基于AI构建AI的思维探索项目中,我们实现了一个基于BERT的中文AI助手系统。该系统集成了文本分类、命名实体识别和知识库管理等功能,深入了解本项目可以让读者充分了解AI大模型训练和推理的基本原理,该项目使用了三个基础大模型:bert-base-chinese,ckiplab/bert-base-chinese-ner``,spacy.lang.zh.Chinese
旨在展示如何构建一个基础的AI应用系统。

1.1 技术栈
- 后端:Python + Flask
- 前端:HTML + JavaScript
- 模型:BERT中文预训练模型
- 框架:Transformers + PyTorch
1.2 核心功能
- 智能问答
- 知识管理
- 模型训练
- 实体识别
1.3 功能分析
这个应用程序试图实现一个简单的AI助手系统,包含以下功能:
-
知识管理
- 通过
KnowledgeBase
类管理知识库 - 支持添加和检索知识
- 知识以JSON格式持久化存储
- 通过
-
文本分类
- 使用
bert-base-chinese
模型 - 只能进行简单的二分类(0或1)
- 主要用于判断问题是否与AI相关
- 使用
-
命名实体识别
- 使用
ckiplab/bert-base-chinese-ner
模型 - 识别文本中的人名、地名等实体
- 结果并未充分利用
- 使用
-
Web界面
- 提供聊天界面
- 支持知识添加
- 支持模型训练
1.4 对大模型技术的帮助
这个项目对学习大模型技术有以下帮助:
-
基础概念学习
- 了解如何加载和使用预训练模型
- 理解模型训练的基本流程
- 掌握文本分类和NER的基本概念
-
实践价值
- 展示了如何构建一个完整的AI应用
- 演示了模型训练和预测的基本流程
- 展示了如何处理中文文本
-
改进方向
- 可以学习如何优化模型训练
- 了解如何更好地整合多个模型
- 学习如何设计更好的系统架构
二、系统架构
基于Python3.11.5虚拟环境,完整源代码见3.4。
2.1 核心模块
python
class KnowledgeBase:
"""知识库管理模块"""
def __init__(self):
self.knowledge = {}
self.load_knowledge()
class LearningModule:
"""模型学习模块"""
def __init__(self):
self.model_name = "bert-base-chinese"
self.model_path = "./trained_model"
class ThinkingFramework:
"""思维框架模块"""
def __init__(self):
self.cognitive_model = Chinese()
self.ner_tagger = AutoModelForTokenClassification.from_pretrained("ckiplab/bert-base-chinese-ner")
2.2 模块职责
- 知识库模块:管理系统的知识存储和检索
- 学习模块:处理模型训练和预测
- 思维框架:处理文本分析和实体识别
- 用户界面:提供Web交互界面
三、关键技术实现
3.1 模型加载与训练
python
def load_or_train_model(self):
"""加载或训练模型"""
try:
if os.path.exists(self.model_path):
self.model = AutoModelForTokenClassification.from_pretrained(self.model_path)
else:
self.model = AutoModelForTokenClassification.from_pretrained(self.model_name)
self.optimizer = torch.optim.AdamW(self.model.parameters(), lr=2e-5)
self.train_initial_data()
except Exception as e:
print(f"加载模型时出错: {str(e)}")
raise
3.2 知识检索实现
python
def retrieve(self, query):
"""检索与查询相关的知识"""
results = []
query_lower = query.lower()
for k, v in self.knowledge.items():
k_lower = k.lower()
if query_lower == k_lower:
results.append((k, v, 1.0)) # 完全匹配
elif query_lower in k_lower or k_lower in query_lower:
overlap = len(set(query_lower) & set(k_lower)) / max(len(query_lower), len(k_lower))
results.append((k, v, overlap)) # 部分匹配
3.3 实体识别实现
python
def get_ner(self, tokens):
"""获取命名实体识别结果"""
try:
tokenizer = AutoTokenizer.from_pretrained("ckiplab/bert-base-chinese-ner")
inputs = tokenizer(text, return_tensors="pt")
outputs = self.ner_tagger(**inputs)
predictions = outputs.logits.argmax(dim=2).tolist()[0]
tokens = tokenizer.convert_ids_to_tokens(inputs["input_ids"][0])
# 提取实体
entities = []
current_entity = []
current_type = None
for token, pred in zip(tokens, predictions):
if pred > 0: # 如果不是O标签
if not current_entity:
current_entity = [token]
current_type = pred
else:
current_entity.append(token)
3.4 完整源代码
系统运行截图:
代码文件1:main.py
bash
#运行成功
import os
import json
from flask import Flask, request, jsonify, render_template
from transformers import AutoTokenizer, AutoModelForTokenClassification
from spacy.lang.zh import Chinese
import numpy as np
import torch
import time
import torch.nn.functional as F
# 设置环境变量,使用镜像源
os.environ['HF_ENDPOINT'] = 'https://hf-mirror.com'
class KnowledgeBase:
def __init__(self):
self.knowledge = {}
self.load_knowledge() # 启动时加载知识
def add(self, key, value):
self.knowledge[key] = value
def retrieve(self, query):
#return [k for k, v in self.knowledge.items() if query.lower() in k.lower()]
"""检索与查询相关的知识
Args:
query: 用户查询字符串
Returns:
包含(key, value)元组的列表,按相关性排序
"""
results = []
query_lower = query.lower()
# 首先检查完全匹配
for k, v in self.knowledge.items():
k_lower = k.lower()
if query_lower == k_lower:
results.append((k, v, 1.0)) # 完全匹配,相关性为1.0
elif query_lower in k_lower or k_lower in query_lower:
# 部分匹配,计算相关性分数
overlap = len(set(query_lower) & set(k_lower)) / max(len(query_lower), len(k_lower))
results.append((k, v, overlap))
else:
# 检查内容中是否包含查询词
if query_lower in v.lower():
results.append((k, v, 0.5)) # 内容匹配,相关性为0.5
# 按相关性排序
results.sort(key=lambda x: x[2], reverse=True)
# 只返回相关性大于0.3的结果
filtered_results = [(k, v) for k, v, score in results if score > 0.3]
return filtered_results
def save_knowledge(self):
with open("knowledge_base.json", "w", encoding="utf-8") as f:
json.dump(self.knowledge, f, ensure_ascii=False, indent=4)
def load_knowledge(self):
try:
with open("knowledge_base.json", "r", encoding="utf-8") as f:
self.knowledge = json.load(f)
except FileNotFoundError:
self.knowledge = {}
class LearningModule:
def __init__(self):
self.model_name = "bert-base-chinese"
self.model_path = "./trained_model" # 修改为本地路径
self.optimizer = None # 添加优化器初始化
self.initial_training_data = [ # 添加初始训练数据
("我是一个AI助手", 1),
("Python是编程语言", 0),
("机器学习是人工智能的分支", 1),
("深度学习使用神经网络", 1),
("自然语言处理很重要", 1),
("计算机视觉处理图像", 1),
("强化学习用于决策", 1),
("数据科学使用Python", 0),
("Web开发需要HTML", 0),
("数据库存储数据", 0)
]
self.training_data = self.initial_training_data.copy() # 添加训练数据存储
self.load_or_train_model()
def save_model(self):
"""保存模型到本地目录"""
try:
# 确保目录存在
os.makedirs(self.model_path, exist_ok=True)
# 保存模型和分词器
self.model.save_pretrained(self.model_path)
self.tokenizer.save_pretrained(self.model_path)
print(f"模型已保存到: {self.model_path}")
except Exception as e:
print(f"保存模型时出错: {str(e)}")
def load_or_train_model(self):
"""加载或训练模型"""
try:
# 首先尝试从本地加载模型
if os.path.exists(self.model_path):
print(f"从本地加载模型: {self.model_path}")
self.model = AutoModelForTokenClassification.from_pretrained(self.model_path)
self.tokenizer = AutoTokenizer.from_pretrained(self.model_path)
else:
# 如果本地没有模型,从预训练模型加载
print(f"从预训练模型加载: {self.model_name}")
self.model = AutoModelForTokenClassification.from_pretrained(self.model_name)
self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)
# 初始化优化器
self.optimizer = torch.optim.AdamW(self.model.parameters(), lr=2e-5)
# 训练初始数据
self.train_initial_data()
except Exception as e:
print(f"加载模型时出错: {str(e)}")
raise
def train_initial_data(self):
"""使用初始数据进行训练"""
sentences, labels = zip(*self.initial_training_data)
self.train(list(sentences), list(labels))
def train(self, sentences, labels):
"""训练模型"""
try:
# 对所有句子一起进行tokenize,确保填充到相同长度
inputs = self.tokenizer(sentences, padding=True, truncation=True, return_tensors="pt")
input_ids = inputs["input_ids"]
attention_mask = inputs["attention_mask"]
labels_ids = torch.tensor(labels).long()
# 设置模型为训练模式
self.model.train()
# 前向传播
outputs = self.model(input_ids, attention_mask=attention_mask)
# 调整输出和标签的形状以匹配交叉熵损失函数的要求
# 假设我们只关心[CLS]标记的输出(序列的第一个标记)
logits = outputs.logits[:, 0, :] # 取每个序列的第一个标记的输出
loss = F.cross_entropy(logits, labels_ids)
# 反向传播和优化
self.optimizer.zero_grad()
loss.backward()
self.optimizer.step()
print(f"训练完成,损失: {loss.item():.4f}")
return loss.item()
except Exception as e:
print(f"训练过程中出错: {str(e)}")
import traceback
traceback.print_exc()
return None
def predict(self, sentence):
"""对输入句子进行预测"""
try:
# 对输入进行tokenize
tokenized = self.tokenizer(sentence, return_tensors="pt", padding=True)
input_ids = tokenized["input_ids"]
attention_mask = tokenized["attention_mask"]
# 获取模型输出
with torch.no_grad():
outputs = self.model(input_ids, attention_mask=attention_mask)
# 处理输出
logits = outputs.logits
predictions = torch.argmax(logits, dim=2).tolist()[0]
# 这里可以添加更多处理逻辑,例如将预测结果映射到标签
# 简单返回预测结果
return f"预测类别: {predictions[0]}"
except Exception as e:
print(f"预测错误: {str(e)}")
return "无法进行预测"
def train_model(self, sentence, label):
# 保存训练数据
self.training_data.append((sentence, label))
# 训练模型
loss = self.train([sentence], [label])
# 保存模型
self.save_model()
return loss
class ThinkingFramework:
def __init__(self):
self.cognitive_model = Chinese()
try:
print("正在加载NER模型...")
self.ner_tagger = AutoModelForTokenClassification.from_pretrained("ckiplab/bert-base-chinese-ner")
print("NER模型加载成功!")
except Exception as e:
print(f"NER模型加载失败: {str(e)}")
print("请确保网络连接正常,并已安装所有必要的依赖。")
raise
def parse(self, text):
return self.cognitive_model(text)
def get_ner(self, tokens):
"""获取命名实体识别结果"""
try:
if isinstance(tokens, str):
tokens = self.parse(tokens)
tokenizer = AutoTokenizer.from_pretrained("ckiplab/bert-base-chinese-ner")
text = tokens.text if hasattr(tokens, 'text') else str(tokens)
inputs = tokenizer(text, return_tensors="pt")
outputs = self.ner_tagger(**inputs)
# 获取预测结果和标签
predictions = outputs.logits.argmax(dim=2).tolist()[0]
tokens = tokenizer.convert_ids_to_tokens(inputs["input_ids"][0])
# 提取实体
entities = []
current_entity = []
current_type = None
for token, pred in zip(tokens, predictions):
if pred > 0: # 如果不是O标签
if not current_entity:
current_entity = [token]
current_type = pred
else:
current_entity.append(token)
elif current_entity:
entities.append((''.join(current_entity), current_type))
current_entity = []
current_type = None
return entities
except Exception as e:
print(f"NER处理错误: {str(e)}")
return []
class SelfConsciousness:
def __init__(self):
self.knowledge = {}
self.last_update = time.time()
def update(self, new_knowledge):
if new_knowledge:
self.knowledge.update({k: time.time() for k, _ in new_knowledge})
self.last_update = time.time()
print(f"自我意识已更新,新增知识点: {len(new_knowledge)}")
def save(self):
with open("consciousness.json", "w", encoding="utf-8") as f:
json.dump({
"knowledge": self.knowledge,
"last_update": self.last_update
}, f, ensure_ascii=False, indent=4)
print("自我意识已保存")
class UserInterface:
def __init__(self, thinking_framework, learning_module, knowledge_base):
self.thinking_framework = thinking_framework
self.learning_module = learning_module
self.knowledge_base = knowledge_base
self.app = Flask(__name__)
self.setup_routes()
def setup_routes(self):
@self.app.route('/')
def index():
return render_template('index.html')
@self.app.route('/chat', methods=['POST'])
def chat():
data = request.get_json()
user_message = data.get('message', '')
if not user_message:
return jsonify({'response': '请输入有效的问题。'})
try:
# 使用思维框架处理用户输入
doc = self.thinking_framework.parse(user_message)
# 正确处理NER标签
try:
# 首先解析文本,然后获取NER标签
tokens = self.thinking_framework.parse(user_message)
ner_tags = self.thinking_framework.get_ner(tokens)
except Exception as e:
print(f"NER处理错误: {str(e)}")
ner_tags = []
# 使用学习模块进行预测
try:
prediction = self.learning_module.predict(user_message)
except Exception as e:
print(f"预测错误: {str(e)}")
prediction = None
# 从知识库中检索相关信息
try:
relevant_knowledge = self.knowledge_base.retrieve(user_message)
print(f"检索到的知识: {relevant_knowledge}")
except Exception as e:
print(f"知识库检索错误: {str(e)}")
relevant_knowledge = []
# 生成回复
response = self.generate_response(user_message, doc, ner_tags, prediction, relevant_knowledge)
return jsonify({'response': response})
except Exception as e:
print(f"处理消息时出错: {str(e)}")
import traceback
traceback.print_exc()
return jsonify({'response': f'抱歉,处理您的消息时出现了错误: {str(e)}'})
@self.app.route('/add_knowledge', methods=['POST'])
def add_knowledge():
data = request.get_json()
key = data.get('key', '')
value = data.get('value', '')
if not key or not value:
return jsonify({'success': False, 'message': '请提供有效的知识关键词和内容'})
try:
# 添加知识到知识库
self.knowledge_base.add(key, value)
print(f"添加知识: {key} -> {value}")
return jsonify({'success': True})
except Exception as e:
print(f"添加知识时出错: {str(e)}")
return jsonify({'success': False, 'message': str(e)})
@self.app.route('/train_model', methods=['POST'])
def train_model():
data = request.get_json()
sentence = data.get('sentence', '')
label = data.get('label')
if not sentence or label is None:
return jsonify({'success': False, 'message': '请提供有效的训练句子和标签'})
try:
# 训练模型
loss = self.learning_module.train_model(sentence, label)
if loss is not None:
print(f"训练模型: '{sentence}' -> {label}, 损失: {loss}")
return jsonify({'success': True, 'loss': loss})
else:
return jsonify({'success': False, 'message': '训练失败,请查看服务器日志'})
except Exception as e:
print(f"训练模型时出错: {str(e)}")
return jsonify({'success': False, 'message': str(e)})
def generate_response(self, user_message, doc, ner_tags, prediction, relevant_knowledge):
"""生成对用户消息的回复"""
try:
response_parts = []
# 处理NER结果
if ner_tags:
entities = [entity for entity, _ in ner_tags]
if entities:
response_parts.append(f"我注意到您提到了:{', '.join(entities)}")
# 添加知识库中的相关信息
if relevant_knowledge:
for key, value in relevant_knowledge[:2]:
response_parts.append(f"关于{key},我知道:{value}")
# 根据预测结果调整回复
if prediction and "预测类别: 1" in prediction:
response_parts.append("这看起来是一个与AI相关的问题,我很乐意帮您解答。")
# 如果没有找到相关信息
if not response_parts:
if "你好" in user_message or "您好" in user_message:
response_parts.append("您好!我是AI助手,很高兴为您服务。")
elif any(word in user_message for word in ["什么", "怎么", "如何", "为什么"]):
response_parts.append("这是一个很好的问题。让我为您详细解释...")
else:
response_parts.append("我理解您的问题,让我为您提供相关信息...")
return "\n\n".join(response_parts)
except Exception as e:
print(f"生成回复时出错: {str(e)}")
return "我正在处理您的请求,但遇到了一些技术问题。请稍后再试。"
def start(self):
# 在Windows上,使用use_reloader=False可以避免WinError 10038
self.app.run(debug=True, use_reloader=False, threaded=True)
if __name__ == "__main__":
try:
# 初始化各个组件
print("初始化知识库...")
kb = KnowledgeBase()
print("初始化学习模块...")
lm = LearningModule()
print("初始化自我意识模块...")
sf = SelfConsciousness()
print("初始化思维框架...")
tf = ThinkingFramework()
print("初始化用户界面...")
uif = UserInterface(tf, lm, kb)
print("添加初始知识...")
# 添加一些初始知识
kb.add("人工智能", "人工智能是计算机科学的一个分支,它企图了解智能的实质,并生产出一种新的能以人类智能相似的方式做出反应的智能机器。")
kb.add("机器学习", "机器学习是人工智能的一个分支,它使用各种算法来解析数据、学习数据,然后对新数据进行预测。")
kb.add("深度学习", "深度学习是机器学习的一个分支,它使用多层神经网络来模拟人脑的工作方式。")
kb.add("自然语言处理", "自然语言处理(NLP)是人工智能的一个子领域,专注于使计算机能够理解、解释和生成人类语言。")
kb.add("计算机视觉", "计算机视觉是人工智能的一个领域,专注于使计算机能够从图像或视频中获取高级理解。")
kb.add("强化学习", "强化学习是机器学习的一种方法,通过让代理在环境中采取行动并获得奖励或惩罚来学习最佳策略。")
kb.add("神经网络", "神经网络是一种受人脑结构启发的计算模型,由多层相互连接的节点(神经元)组成,用于模式识别和机器学习。")
kb.add("大语言模型", "大语言模型是一种基于深度学习的自然语言处理模型,通过大量文本数据训练,能够生成、理解和翻译人类语言。")
kb.add("Python", "Python是一种高级编程语言,以其简洁、易读的语法和丰富的库而闻名,广泛应用于数据科学、人工智能和Web开发等领域。")
print("更新知识...")
# 模拟知识更新
knowledge_items = kb.retrieve("人工智能")
if knowledge_items:
print(f"找到相关知识: {knowledge_items}")
sf.update(knowledge_items)
# 保存更新后的知识
sf.save()
else:
print("未找到相关知识,跳过更新")
print("训练模型...")
# 开始训练模型
lm.train(["我是一个AI助手", "Python是编程语言"], [1, 0])
print("启动Web服务...")
# 启动服务
uif.start()
except Exception as e:
print(f"启动失败: {str(e)}")
import traceback
traceback.print_exc()
代码文件2:index.html(放在在 templates文件夹下)
bash
<!DOCTYPE html>
<html>
<head>
<title>AI助手</title>
<meta charset="UTF-8">
<style>
body {
font-family: Arial, sans-serif;
margin: 0;
padding: 20px;
background-color: #f5f5f5;
}
.container {
max-width: 1000px;
margin: 0 auto;
display: flex;
gap: 20px;
}
.chat-container {
flex: 2;
background-color: white;
border-radius: 10px;
box-shadow: 0 2px 10px rgba(0,0,0,0.1);
padding: 20px;
}
.training-container {
flex: 1;
background-color: white;
border-radius: 10px;
box-shadow: 0 2px 10px rgba(0,0,0,0.1);
padding: 20px;
}
.chat-messages {
height: 500px;
overflow-y: auto;
padding: 10px;
border: 1px solid #ddd;
border-radius: 5px;
margin-bottom: 20px;
}
.message {
margin: 10px 0;
padding: 10px;
border-radius: 5px;
}
.user-message {
background-color: #e3f2fd;
margin-left: 20%;
}
.ai-message {
background-color: #f5f5f5;
margin-right: 20%;
}
.input-container {
display: flex;
gap: 10px;
}
input[type="text"], textarea {
flex: 1;
padding: 10px;
border: 1px solid #ddd;
border-radius: 5px;
font-size: 16px;
}
button {
padding: 10px 20px;
background-color: #2196f3;
color: white;
border: none;
border-radius: 5px;
cursor: pointer;
font-size: 16px;
}
button:hover {
background-color: #1976d2;
}
h2 {
color: #333;
margin-top: 0;
}
.form-group {
margin-bottom: 15px;
}
label {
display: block;
margin-bottom: 5px;
font-weight: bold;
}
.status {
margin-top: 10px;
padding: 10px;
border-radius: 5px;
display: none;
}
.success {
background-color: #d4edda;
color: #155724;
display: block;
}
.error {
background-color: #f8d7da;
color: #721c24;
display: block;
}
</style>
</head>
<body>
<div class="container">
<div class="chat-container">
<h2>AI助手聊天</h2>
<div class="chat-messages" id="chat-messages"></div>
<div class="input-container">
<input type="text" id="user-input" placeholder="请输入您的问题...">
<button onclick="sendMessage()">发送</button>
</div>
</div>
<div class="training-container">
<h2>知识训练</h2>
<div class="form-group">
<label for="knowledge-key">知识关键词</label>
<input type="text" id="knowledge-key" placeholder="例如:人工智能">
</div>
<div class="form-group">
<label for="knowledge-value">知识内容</label>
<textarea id="knowledge-value" rows="5" placeholder="例如:人工智能是计算机科学的一个分支..."></textarea>
</div>
<button onclick="addKnowledge()">添加知识</button>
<div id="knowledge-status" class="status"></div>
<h2 style="margin-top: 20px;">模型训练</h2>
<div class="form-group">
<label for="train-sentence">训练句子</label>
<input type="text" id="train-sentence" placeholder="例如:我是一个AI助手">
</div>
<div class="form-group">
<label for="train-label">标签 (0 或 1)</label>
<input type="number" id="train-label" min="0" max="1" value="1">
</div>
<button onclick="trainModel()">训练模型</button>
<div id="train-status" class="status"></div>
</div>
</div>
<script>
// 添加欢迎消息
window.onload = function() {
addMessage("您好!我是AI助手,有什么可以帮助您的吗?", false);
};
function addMessage(message, isUser) {
const messagesDiv = document.getElementById('chat-messages');
const messageDiv = document.createElement('div');
messageDiv.className = `message ${isUser ? 'user-message' : 'ai-message'}`;
messageDiv.textContent = message;
messagesDiv.appendChild(messageDiv);
messagesDiv.scrollTop = messagesDiv.scrollHeight;
}
function sendMessage() {
const input = document.getElementById('user-input');
const message = input.value.trim();
if (message) {
addMessage(message, true);
input.value = '';
// 显示加载状态
const loadingId = showLoading();
fetch('/chat', {
method: 'POST',
headers: {
'Content-Type': 'application/json',
},
body: JSON.stringify({ message: message })
})
.then(response => response.json())
.then(data => {
// 移除加载状态
removeLoading(loadingId);
addMessage(data.response, false);
})
.catch(error => {
// 移除加载状态
removeLoading(loadingId);
console.error('Error:', error);
addMessage('抱歉,发生了一些错误。请稍后再试。', false);
});
}
}
function showLoading() {
const messagesDiv = document.getElementById('chat-messages');
const loadingDiv = document.createElement('div');
loadingDiv.className = 'message ai-message';
loadingDiv.textContent = '正在思考...';
loadingDiv.id = 'loading-' + Date.now();
messagesDiv.appendChild(loadingDiv);
messagesDiv.scrollTop = messagesDiv.scrollHeight;
return loadingDiv.id;
}
function removeLoading(id) {
const loadingDiv = document.getElementById(id);
if (loadingDiv) {
loadingDiv.remove();
}
}
function addKnowledge() {
const key = document.getElementById('knowledge-key').value.trim();
const value = document.getElementById('knowledge-value').value.trim();
const statusDiv = document.getElementById('knowledge-status');
if (!key || !value) {
statusDiv.textContent = '请填写完整的知识信息';
statusDiv.className = 'status error';
return;
}
fetch('/add_knowledge', {
method: 'POST',
headers: {
'Content-Type': 'application/json',
},
body: JSON.stringify({ key: key, value: value })
})
.then(response => response.json())
.then(data => {
if (data.success) {
statusDiv.textContent = '知识添加成功!';
statusDiv.className = 'status success';
// 清空输入框
document.getElementById('knowledge-key').value = '';
document.getElementById('knowledge-value').value = '';
} else {
statusDiv.textContent = data.message || '知识添加失败';
statusDiv.className = 'status error';
}
})
.catch(error => {
console.error('Error:', error);
statusDiv.textContent = '发生错误,请稍后再试';
statusDiv.className = 'status error';
});
}
function trainModel() {
const sentence = document.getElementById('train-sentence').value.trim();
const label = parseInt(document.getElementById('train-label').value);
const statusDiv = document.getElementById('train-status');
if (!sentence || isNaN(label) || (label !== 0 && label !== 1)) {
statusDiv.textContent = '请填写正确的训练数据';
statusDiv.className = 'status error';
return;
}
fetch('/train_model', {
method: 'POST',
headers: {
'Content-Type': 'application/json',
},
body: JSON.stringify({ sentence: sentence, label: label })
})
.then(response => response.json())
.then(data => {
if (data.success) {
statusDiv.textContent = `训练成功!损失: ${data.loss}`;
statusDiv.className = 'status success';
// 清空输入框
document.getElementById('train-sentence').value = '';
} else {
statusDiv.textContent = data.message || '训练失败';
statusDiv.className = 'status error';
}
})
.catch(error => {
console.error('Error:', error);
statusDiv.textContent = '发生错误,请稍后再试';
statusDiv.className = 'status error';
});
}
// 按回车发送消息
document.getElementById('user-input').addEventListener('keypress', function(e) {
if (e.key === 'Enter') {
sendMessage();
}
});
</script>
</body>
</html>
四、系统特点与优势
4.1 特点
- 模块化设计:各个功能模块独立,便于维护和扩展
- 知识持久化:支持知识的保存和加载
- 实时训练:支持动态添加训练数据
- 多模型集成:结合了分类和NER模型
4.2 优势
- 易于理解:代码结构清晰,适合学习
- 功能完整:包含了AI应用的基本要素
- 可扩展性:便于添加新功能
五、存在的问题与改进方向
5.1 存在的问题
-
功能混杂
- 同时使用了多个模型,但功能不明确
- 分类模型和NER模型的结果没有很好地整合
- 知识库检索和模型预测结果没有很好地结合
-
技术实现问题
- 分类模型训练数据太少(只有10个样本)
- 模型训练过程过于简单
- 没有充分利用BERT模型的优势
-
架构设计问题
- 各个模块之间耦合度高
- 错误处理不够完善
- 缺乏完整的测试机制
5.2 改进方向
- 扩充训练数据
python
def add_training_data(self, sentences, labels):
"""添加更多训练数据"""
self.training_data.extend(zip(sentences, labels))
self.train(list(sentences), list(labels))
- 优化模型训练
python
def train(self, sentences, labels):
"""优化训练过程"""
self.model.train()
for epoch in range(self.num_epochs):
# 添加学习率调度
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(self.optimizer)
# 添加早停机制
if self.should_early_stop():
break
- 改进知识检索
python
def retrieve(self, query):
"""改进知识检索算法"""
# 添加语义相似度计算
# 使用向量检索
# 添加缓存机制
六、总结与展望
6.1 项目总结
这个项目展示了如何构建一个基础的AI助手系统,并非一个实用的生产系统, 但作为学习案例具有很好的参考价值,它展示了:
-
基础架构
- 如何组织AI应用的基本结构
- 如何处理用户输入和输出
- 如何管理数据和模型
-
技术要点
- 预训练模型的使用
- 基本的模型训练流程
- 文本处理的基本方法
-
改进空间
- 模型训练需要优化
- 系统架构需要重构
- 功能需要聚焦
对于想要学习大模型技术的开发者来说,这个项目可以作为一个起点,但需要在此基础上进行大量的改进和优化才能成为一个实用的系统。
6.2 未来展望
- 引入更先进的模型架构
- 优化系统性能
- 增加更多实用功能
- 改进用户体验
七、参考资料
- BERT论文:Attention Is All You Need
- Transformers库文档
- PyTorch官方文档
- Flask Web框架文档
这篇博文全面介绍了项目的实现过程、技术细节和未来展望,适合作为技术学习资料。您可以根据需要调整内容的深度和广度。