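Project Overview
Spam filtering is a classic application of machine learning in natural language processing. This project builds a complete spam filtering system from scratch, covering the full stack: data processing, model training, a backend API, and a frontend interface.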
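Tech Stack
Backend
- Python 3.8+: primary development language
- Flask: lightweight web framework
- Flask-CORS: cross-origin request support (imported in app.py)
- scikit-learn: machine learning library
- pandas & numpy: data processing
- nltk: natural language processing
Frontend
- HTML/CSS/JavaScript: core frontend technologies
- Bootstrap: UI framework
- Axios: HTTP client
Database
- SQLite: lightweight database for storing email records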
Project Structure
spam-filter/
├── backend/
│ ├── app.py # Flask application entry point
│ ├── model.py # Machine learning model
│ ├── preprocessor.py # Text preprocessing
│ └── database.py # Database operations
├── frontend/
│ ├── index.html # Main page
│ ├── style.css # Stylesheet
│ └── script.js # Frontend logic
├── models/
│ └── spam_classifier.pkl # Trained model
├── data/
│ └── emails.csv # Training dataset
└── requirements.txt # Dependencies
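Core Implementation
1. Data Preprocessing Module
Text preprocessing is a key step for improving model performance. It consists mainly of text cleaning, tokenization, stop-word removal, and stemming.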
import re

import nltk
from nltk.corpus import stopwords
from nltk.stem import PorterStemmer


class TextPreprocessor:
    def __init__(self):
        # Fetch the stop-word list (a no-op if it is already downloaded)
        nltk.download('stopwords', quiet=True)
        self.stop_words = set(stopwords.words('english'))
        self.stemmer = PorterStemmer()

    def clean_text(self, text):
        # Convert to lowercase
        text = text.lower()
        # Remove URLs
        text = re.sub(r'http\S+|www\S+', '', text)
        # Remove email addresses
        text = re.sub(r'\S+@\S+', '', text)
        # Keep only letters and whitespace
        text = re.sub(r'[^a-zA-Z\s]', '', text)
        # Collapse repeated whitespace
        text = re.sub(r'\s+', ' ', text).strip()
        return text

    def preprocess(self, text):
        # Clean the text
        text = self.clean_text(text)
        # Tokenize on whitespace
        words = text.split()
        # Remove stop words and stem the remaining tokens
        words = [self.stemmer.stem(word) for word in words
                 if word not in self.stop_words]
        return ' '.join(words)
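To make the cleaning steps concrete, here is a minimal standard-library sketch that traces one message through the same regular expressions. The stop-word set `DEMO_STOP_WORDS` and the sample message are made up for illustration; the class above uses NLTK's full English list and also applies stemming, which this sketch leaves out.

```python
import re

# Hypothetical, deliberately tiny stop-word list for illustration only;
# TextPreprocessor uses NLTK's full English list instead.
DEMO_STOP_WORDS = {"a", "or", "the", "to"}


def clean_text(text: str) -> str:
    """Apply the same five cleaning steps as TextPreprocessor.clean_text."""
    text = text.lower()
    text = re.sub(r"http\S+|www\S+", "", text)   # drop URLs
    text = re.sub(r"\S+@\S+", "", text)          # drop email addresses
    text = re.sub(r"[^a-zA-Z\s]", "", text)      # keep letters and whitespace
    return re.sub(r"\s+", " ", text).strip()     # collapse whitespace


def remove_stop_words(text: str) -> str:
    """Drop tokens found in the demo stop-word set (no stemming here)."""
    return " ".join(w for w in text.split() if w not in DEMO_STOP_WORDS)


raw = "WIN a FREE iPhone!!! Visit http://spam.example.com or mail win@spam.example.com NOW"
cleaned = clean_text(raw)
print(cleaned)                     # win a free iphone visit or mail now
print(remove_stop_words(cleaned))  # win free iphone visit mail now
```

Note that the URL and the address disappear entirely, so any signal they carry (for example a suspicious domain) is lost at this stage.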
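2. Training the Machine Learning Model
We build the classifier from a Naive Bayes model on top of TF-IDF features. Naive Bayes is a strong, widely used baseline for text classification and is fast to train.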
import pickle

import pandas as pd
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.metrics import accuracy_score, classification_report
from sklearn.model_selection import train_test_split
from sklearn.naive_bayes import MultinomialNB

from preprocessor import TextPreprocessor  # backend/preprocessor.py


class SpamClassifier:
    def __init__(self):
        self.vectorizer = TfidfVectorizer(max_features=3000)
        self.model = MultinomialNB()
        self.preprocessor = TextPreprocessor()

    def train(self, data_path):
        # Load the data (expects 'text' and 'label' columns)
        df = pd.read_csv(data_path)
        # Preprocess the text
        df['processed_text'] = df['text'].apply(
            self.preprocessor.preprocess
        )
        # Split into training and test sets
        X_train, X_test, y_train, y_test = train_test_split(
            df['processed_text'],
            df['label'],
            test_size=0.2,
            random_state=42
        )
        # Extract TF-IDF features (fit on the training split only)
        X_train_tfidf = self.vectorizer.fit_transform(X_train)
        X_test_tfidf = self.vectorizer.transform(X_test)
        # Train the model
        self.model.fit(X_train_tfidf, y_train)
        # Evaluate the model
        y_pred = self.model.predict(X_test_tfidf)
        accuracy = accuracy_score(y_test, y_pred)
        print(f"Model accuracy: {accuracy:.4f}")
        print("\nClassification report:")
        print(classification_report(y_test, y_pred))
        return accuracy

    def predict(self, text):
        # Preprocess
        processed_text = self.preprocessor.preprocess(text)
        # Extract features
        text_tfidf = self.vectorizer.transform([processed_text])
        # Predict
        prediction = self.model.predict(text_tfidf)[0]
        probability = self.model.predict_proba(text_tfidf)[0]
        return {
            # Assumes labels are encoded as 0 (ham) and 1 (spam);
            # string labels such as 'ham' would all evaluate to True here
            'is_spam': bool(prediction),
            'confidence': float(max(probability))
        }

    def save_model(self, path):
        with open(path, 'wb') as f:
            pickle.dump({
                'vectorizer': self.vectorizer,
                'model': self.model,
                'preprocessor': self.preprocessor
            }, f)

    def load_model(self, path):
        # Only load pickle files from a trusted source
        with open(path, 'rb') as f:
            data = pickle.load(f)
        self.vectorizer = data['vectorizer']
        self.model = data['model']
        self.preprocessor = data['preprocessor']
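For readers who want to see what a multinomial Naive Bayes classifier computes, the following is a from-scratch sketch using only the standard library. It is not scikit-learn's implementation: it works on raw word counts rather than TF-IDF weights, and the function names (`train_nb`, `predict_proba`) and the four training messages are invented for illustration. The smoothing parameter `alpha` plays the same role as `alpha` in `MultinomialNB` (Laplace smoothing when it is 1.0).

```python
import math
from collections import Counter, defaultdict


def train_nb(docs, labels, alpha=1.0):
    """Fit a multinomial Naive Bayes model on whitespace-tokenized documents."""
    vocab = {w for d in docs for w in d.split()}
    class_docs = Counter(labels)
    word_counts = defaultdict(Counter)
    for doc, label in zip(docs, labels):
        word_counts[label].update(doc.split())

    model = {"vocab": vocab, "log_prior": {}, "log_lik": {}}
    for c, n_docs in class_docs.items():
        # Class prior: fraction of training documents in class c
        model["log_prior"][c] = math.log(n_docs / len(docs))
        # Smoothed per-word likelihoods P(word | class)
        total = sum(word_counts[c].values()) + alpha * len(vocab)
        model["log_lik"][c] = {
            w: math.log((word_counts[c][w] + alpha) / total) for w in vocab
        }
    return model


def predict_proba(model, doc):
    """Return {class: posterior probability}; out-of-vocabulary words are ignored."""
    scores = {}
    for c in model["log_prior"]:
        scores[c] = model["log_prior"][c] + sum(
            model["log_lik"][c][w] for w in doc.split() if w in model["vocab"]
        )
    # Normalize in log space (log-sum-exp) to avoid numeric underflow
    m = max(scores.values())
    norm = m + math.log(sum(math.exp(s - m) for s in scores.values()))
    return {c: math.exp(s - norm) for c, s in scores.items()}


# Toy training set: 1 = spam, 0 = ham
docs = [
    "win free prize now",
    "free cash win",
    "meeting schedule tomorrow",
    "project meeting notes",
]
labels = [1, 1, 0, 0]

model = train_nb(docs, labels)
probs = predict_proba(model, "free prize")
print(probs)  # the spam class (1) receives the larger probability
```

The larger of the two posteriors is what the `predict` method above reports as `confidence`.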
3. Flask Backend API
The backend exposes a RESTful API that handles email classification requests, history queries, and summary statistics.
import sqlite3
from datetime import datetime

from flask import Flask, request, jsonify
from flask_cors import CORS

from model import SpamClassifier  # backend/model.py

app = Flask(__name__)
CORS(app)

# Load the trained model
classifier = SpamClassifier()
classifier.load_model('models/spam_classifier.pkl')


# Initialize the database
def init_db():
    conn = sqlite3.connect('emails.db')
    c = conn.cursor()
    c.execute('''
        CREATE TABLE IF NOT EXISTS emails (
            id INTEGER PRIMARY KEY AUTOINCREMENT,
            subject TEXT,
            content TEXT,
            is_spam INTEGER,
            confidence REAL,
            timestamp TEXT
        )
    ''')
    conn.commit()
    conn.close()


init_db()
@app.route('/api/classify', methods=['POST'])
def classify_email():
    try:
        # Tolerate a missing or non-JSON body instead of failing on None
        data = request.get_json(silent=True) or {}
        subject = data.get('subject', '')
        content = data.get('content', '')
        # Combine subject and body
        full_text = f"{subject} {content}"
        # Predict
        result = classifier.predict(full_text)
        # Save the record to the database
        conn = sqlite3.connect('emails.db')
        c = conn.cursor()
        c.execute('''
            INSERT INTO emails (subject, content, is_spam, confidence, timestamp)
            VALUES (?, ?, ?, ?, ?)
        ''', (
            subject,
            content,
            int(result['is_spam']),
            result['confidence'],
            datetime.now().isoformat()
        ))
        conn.commit()
        conn.close()
        return jsonify({
            'success': True,
            'result': result
        })
    except Exception as e:
        return jsonify({
            'success': False,
            'error': str(e)
        }), 500
@app.route('/api/history', methods=['GET'])
def get_history():
    try:
        limit = request.args.get('limit', 50, type=int)
        conn = sqlite3.connect('emails.db')
        c = conn.cursor()
        c.execute('''
            SELECT id, subject, content, is_spam, confidence, timestamp
            FROM emails
            ORDER BY timestamp DESC
            LIMIT ?
        ''', (limit,))
        rows = c.fetchall()
        conn.close()
        history = []
        for row in rows:
            history.append({
                'id': row[0],
                'subject': row[1],
                'content': row[2],
                'is_spam': bool(row[3]),
                'confidence': row[4],
                'timestamp': row[5]
            })
        return jsonify({
            'success': True,
            'history': history
        })
    except Exception as e:
        return jsonify({
            'success': False,
            'error': str(e)
        }), 500
@app.route('/api/stats', methods=['GET'])
def get_stats():
    try:
        conn = sqlite3.connect('emails.db')
        c = conn.cursor()
        c.execute('SELECT COUNT(*) FROM emails')
        total = c.fetchone()[0]
        c.execute('SELECT COUNT(*) FROM emails WHERE is_spam = 1')
        spam_count = c.fetchone()[0]
        conn.close()
        return jsonify({
            'success': True,
            'stats': {
                'total': total,
                'spam': spam_count,
                'ham': total - spam_count,
                'spam_rate': spam_count / total if total > 0 else 0
            }
        })
    except Exception as e:
        return jsonify({
            'success': False,
            'error': str(e)
        }), 500


if __name__ == '__main__':
    # debug=True is meant for local development only
    app.run(debug=True, port=5000)
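The statistics query can be tried without starting Flask. The sketch below runs the same aggregation against an in-memory SQLite database with the `emails` schema from `init_db`. The helper name `compute_stats` and the four sample rows are assumptions for illustration, and it folds the two `COUNT` queries into one statement, which is a variation on the endpoint above rather than the original code.

```python
import sqlite3
from contextlib import closing


def compute_stats(conn: sqlite3.Connection) -> dict:
    """Mirror the /api/stats aggregation, guarding against an empty table."""
    total, spam = conn.execute(
        "SELECT COUNT(*), COALESCE(SUM(is_spam), 0) FROM emails"
    ).fetchone()
    return {
        "total": total,
        "spam": spam,
        "ham": total - spam,
        "spam_rate": spam / total if total else 0,
    }


# closing() guarantees the connection is released even if a query raises
with closing(sqlite3.connect(":memory:")) as conn:
    conn.execute(
        "CREATE TABLE emails (id INTEGER PRIMARY KEY AUTOINCREMENT, "
        "subject TEXT, content TEXT, is_spam INTEGER, confidence REAL, timestamp TEXT)"
    )
    empty_stats = compute_stats(conn)
    conn.executemany(
        "INSERT INTO emails (subject, content, is_spam, confidence, timestamp) "
        "VALUES (?, ?, ?, ?, ?)",
        [
            ("Win big", "free prize", 1, 0.97, "2024-01-01T10:00:00"),
            ("Standup", "notes attached", 0, 0.91, "2024-01-01T11:00:00"),
            ("Cheap loans", "act now", 1, 0.88, "2024-01-01T12:00:00"),
            ("Invoice", "see attachment", 0, 0.93, "2024-01-01T13:00:00"),
        ],
    )
    stats = compute_stats(conn)

print(empty_stats)  # {'total': 0, 'spam': 0, 'ham': 0, 'spam_rate': 0}
print(stats)        # {'total': 4, 'spam': 2, 'ham': 2, 'spam_rate': 0.5}
```

Wrapping the connection in `contextlib.closing` is one way to avoid leaking connections when a handler raises before reaching `conn.close()`.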
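4. Frontend Interface
The frontend provides a simple, user-friendly interface for classifying emails and browsing the classification history.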
<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="UTF-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<title>Spam Filter</title>
<link href="https://cdn.jsdelivr.net/npm/bootstrap@5.1.3/dist/css/bootstrap.min.css" rel="stylesheet">
<style>
body {
background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
min-height: 100vh;
padding: 20px;
}
.main-container {
max-width: 900px;
margin: 0 auto;
}
.card {
border-radius: 15px;
box-shadow: 0 10px 30px rgba(0,0,0,0.2);
}
.result-spam {
background-color: #dc3545;
color: white;
}
.result-ham {
background-color: #28a745;
color: white;
}
</style>
</head>
<body>
<div class="main-container">
<h1 class="text-center text-white mb-4">🛡️ Spam Filter</h1>
<!-- Statistics -->
<div class="row mb-4">
<div class="col-md-4">
<div class="card text-center">
<div class="card-body">
<h5>Total Emails</h5>
<h2 id="totalCount">0</h2>
</div>
</div>
</div>
<div class="col-md-4">
<div class="card text-center">
<div class="card-body">
<h5>Spam</h5>
<h2 id="spamCount" class="text-danger">0</h2>
</div>
</div>
</div>
<div class="col-md-4">
<div class="card text-center">
<div class="card-body">
<h5>Legitimate</h5>
<h2 id="hamCount" class="text-success">0</h2>
</div>
</div>
</div>
</div>
<!-- Classification form -->
<div class="card mb-4">
<div class="card-body">
<h3 class="card-title">Check an Email</h3>
<form id="emailForm">
<div class="mb-3">
<label class="form-label" for="subject">Subject</label>
<input type="text" class="form-control" id="subject" required>
</div>
<div class="mb-3">
<label class="form-label" for="content">Content</label>
<textarea class="form-control" id="content" rows="5" required></textarea>
</div>
<button type="submit" class="btn btn-primary w-100">Analyze Email</button>
</form>
<!-- Result -->
<div id="result" class="mt-3" style="display:none;">
<div class="alert" id="resultAlert">
<h4 id="resultText"></h4>
<p id="confidenceText"></p>
</div>
</div>
</div>
</div>
<!-- History -->
<div class="card">
<div class="card-body">
<h3 class="card-title">Classification History</h3>
<div id="history" class="table-responsive">
<table class="table">
<thead>
<tr>
<th>Time</th>
<th>Subject</th>
<th>Result</th>
<th>Confidence</th>
</tr>
</thead>
<tbody id="historyBody"></tbody>
</table>
</div>
</div>
</div>
</div>
<script src="https://cdn.jsdelivr.net/npm/axios/dist/axios.min.js"></script>
<script>
const API_URL = 'http://localhost:5000/api';

// Load summary statistics
async function loadStats() {
    const response = await axios.get(`${API_URL}/stats`);
    const stats = response.data.stats;
    document.getElementById('totalCount').textContent = stats.total;
    document.getElementById('spamCount').textContent = stats.spam;
    document.getElementById('hamCount').textContent = stats.ham;
}

// Load classification history
async function loadHistory() {
    const response = await axios.get(`${API_URL}/history?limit=10`);
    const history = response.data.history;
    const tbody = document.getElementById('historyBody');
    tbody.innerHTML = '';
    history.forEach(item => {
        const row = tbody.insertRow();
        // Build cells with textContent rather than innerHTML so that a
        // subject containing HTML cannot inject markup into the page
        row.insertCell().textContent = new Date(item.timestamp).toLocaleString('zh-CN');
        row.insertCell().textContent = item.subject;
        const badge = document.createElement('span');
        badge.className = `badge ${item.is_spam ? 'bg-danger' : 'bg-success'}`;
        badge.textContent = item.is_spam ? 'Spam' : 'Legitimate';
        row.insertCell().appendChild(badge);
        row.insertCell().textContent = `${(item.confidence * 100).toFixed(2)}%`;
    });
}

// Handle form submission
document.getElementById('emailForm').addEventListener('submit', async (e) => {
    e.preventDefault();
    const subject = document.getElementById('subject').value;
    const content = document.getElementById('content').value;
    try {
        const response = await axios.post(`${API_URL}/classify`, {
            subject: subject,
            content: content
        });
        const result = response.data.result;
        const resultDiv = document.getElementById('result');
        const resultAlert = document.getElementById('resultAlert');
        const resultText = document.getElementById('resultText');
        const confidenceText = document.getElementById('confidenceText');
        resultDiv.style.display = 'block';
        if (result.is_spam) {
            resultAlert.className = 'alert result-spam';
            resultText.textContent = '⚠️ This email is spam!';
        } else {
            resultAlert.className = 'alert result-ham';
            resultText.textContent = '✅ This email looks legitimate';
        }
        confidenceText.textContent = `Confidence: ${(result.confidence * 100).toFixed(2)}%`;
        // Refresh statistics and history
        loadStats();
        loadHistory();
    } catch (error) {
        alert('Classification failed: ' + error.message);
    }
});

// Initialize on page load
loadStats();
loadHistory();
</script>
</body>
</html>
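Model Optimization Tips
1. Feature Engineering
Additional features can improve model performance, for example message length and the ratios of special characters, uppercase letters, and digits.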
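def extract_features(text):
    # Compute these on the raw text: clean_text() lowercases and strips
    # digits and punctuation, which would zero out every ratio below
    n = max(len(text), 1)  # avoid division by zero on empty input
    features = {}
    features['length'] = len(text)
    features['capital_ratio'] = sum(1 for c in text if c.isupper()) / n
    features['digit_ratio'] = sum(1 for c in text if c.isdigit()) / n
    # Note: whitespace also counts as "special" under this definition
    features['special_char_ratio'] = sum(1 for c in text if not c.isalnum()) / n
    return features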
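As a quick illustration of how such features separate messages, the sketch below restates the function with an empty-input guard so the block is self-contained, then turns two made-up messages into fixed-order numeric vectors. The names `FEATURE_ORDER` and `to_vector` are hypothetical helpers, not part of the project code.

```python
# Fixed column order so every message maps to the same vector layout
FEATURE_ORDER = ("length", "capital_ratio", "digit_ratio", "special_char_ratio")


def extract_features(text: str) -> dict:
    """Hand-crafted features computed on the raw (uncleaned) text."""
    n = max(len(text), 1)  # avoid division by zero on empty input
    return {
        "length": len(text),
        "capital_ratio": sum(c.isupper() for c in text) / n,
        "digit_ratio": sum(c.isdigit() for c in text) / n,
        # Whitespace counts as "special" because it is not alphanumeric
        "special_char_ratio": sum(not c.isalnum() for c in text) / n,
    }


def to_vector(text: str) -> list:
    """Return the features as a list in FEATURE_ORDER."""
    feats = extract_features(text)
    return [feats[name] for name in FEATURE_ORDER]


spam_like = "FREE $$$ CALL 555 NOW"
ham_like = "See you at lunch tomorrow"

print(to_vector(spam_like))
print(to_vector(ham_like))
```

In the spam-like message 11 of 21 characters are uppercase, against 1 of 25 in the ordinary one. Vectors like these could be stacked next to the TF-IDF matrix, for example with `scipy.sparse.hstack`, though `MultinomialNB` requires non-negative inputs and the features may need rescaling first.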
2. Ensemble Learning
Several classifiers can vote on each prediction, which can improve accuracy over any single model.
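from sklearn.ensemble import VotingClassifier
from sklearn.linear_model import LogisticRegression
from sklearn.naive_bayes import MultinomialNB
from sklearn.svm import SVC

ensemble = VotingClassifier(
    estimators=[
        ('nb', MultinomialNB()),
        # A higher max_iter helps convergence on high-dimensional TF-IDF input
        ('lr', LogisticRegression(max_iter=1000)),
        # probability=True is required for soft voting
        ('svc', SVC(probability=True))
    ],
    # 'soft' averages the predicted class probabilities
    voting='soft'
)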
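To show what soft voting does differently from a simple majority vote, here is a small standard-library sketch. The probability rows are invented outputs from three hypothetical models, in the order [P(ham), P(spam)]; `soft_vote` and `hard_vote` are illustrative helpers, not scikit-learn functions.

```python
from collections import Counter


def soft_vote(probas, weights=None):
    """Average per-class probabilities across models and pick the argmax."""
    if weights is None:
        weights = [1.0] * len(probas)
    total = sum(weights)
    n_classes = len(probas[0])
    avg = [
        sum(w * p[k] for w, p in zip(weights, probas)) / total
        for k in range(n_classes)
    ]
    return avg.index(max(avg)), avg


def hard_vote(probas):
    """Each model casts one vote for its most likely class; majority wins."""
    votes = [p.index(max(p)) for p in probas]
    return Counter(votes).most_common(1)[0][0]


# Made-up outputs: one very confident "ham", two lukewarm "spam"
probas = [[0.90, 0.10], [0.40, 0.60], [0.45, 0.55]]

soft_label, avg = soft_vote(probas)
hard_label = hard_vote(probas)
print(soft_label, hard_label)  # 0 1
```

Two of the three models lean towards spam, so the majority vote says spam, but the single confident model pulls the averaged probability towards ham. This sensitivity to confidence is the reason soft voting needs calibrated probability estimates from every member.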
3. Hyperparameter Tuning
Use grid search to find the best parameter combination.
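from sklearn.model_selection import GridSearchCV
from sklearn.naive_bayes import MultinomialNB

param_grid = {
    'alpha': [0.1, 0.5, 1.0, 2.0],
    'fit_prior': [True, False]
}
grid_search = GridSearchCV(
    MultinomialNB(),
    param_grid,
    cv=5,
    scoring='accuracy'
)
# Run the search with grid_search.fit(X_train_tfidf, y_train),
# then read the winning combination from grid_search.best_params_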
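It helps to know how much work a grid implies before running it. This standard-library sketch expands the same `param_grid` into its individual candidates; `expand_grid` is a hypothetical helper written for illustration, not a scikit-learn function.

```python
from itertools import product

param_grid = {
    "alpha": [0.1, 0.5, 1.0, 2.0],
    "fit_prior": [True, False],
}


def expand_grid(grid):
    """Return every combination of parameter values as a list of dicts."""
    keys = list(grid)
    return [
        dict(zip(keys, values))
        for values in product(*(grid[k] for k in keys))
    ]


candidates = expand_grid(param_grid)
cv_folds = 5

print(len(candidates))             # 8
print(len(candidates) * cv_folds)  # 40
print(candidates[0])               # {'alpha': 0.1, 'fit_prior': True}
```

With 4 values of `alpha`, 2 values of `fit_prior`, and `cv=5`, the search performs 40 cross-validation fits, plus one final refit on the full training data under the default `refit=True`. Naive Bayes trains quickly, so this is cheap, but the count grows multiplicatively with every parameter added.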
Deployment
Local Deployment
Run the Flask application directly from the project root:
python backend/app.py
Docker Deployment
Create a Dockerfile for containerized deployment:
FROM python:3.9-slim
WORKDIR /app
COPY requirements.txt .
RUN pip install -r requirements.txt
COPY . .
EXPOSE 5000
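# app.py lives in backend/; the Flask app must also listen on 0.0.0.0
# (not the default 127.0.0.1) to be reachable from outside the container
CMD ["python", "backend/app.py"]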
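Cloud Deployment
The application can be deployed to platforms such as Heroku, AWS, or Alibaba Cloud. Pay attention to environment variable configuration and database connections.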
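One common way to handle environment-specific settings is to read them from environment variables with local defaults. The sketch below assumes the variable names `PORT`, `FLASK_DEBUG`, `DATABASE_PATH`, and `MODEL_PATH`; apart from `PORT`, which Heroku sets for web processes, these names are illustrative choices and not something the project defines.

```python
import os


def load_config(env=None):
    """Read deployment settings from environment variables, with local defaults."""
    env = os.environ if env is None else env
    return {
        # Many platforms inject the port to listen on
        "port": int(env.get("PORT", "5000")),
        # Debug mode stays off unless explicitly enabled
        "debug": env.get("FLASK_DEBUG", "0") == "1",
        # Hypothetical variable names for the SQLite file and model pickle
        "db_path": env.get("DATABASE_PATH", "emails.db"),
        "model_path": env.get("MODEL_PATH", "models/spam_classifier.pkl"),
    }


local = load_config({})
cloud = load_config({"PORT": "8080", "DATABASE_PATH": "/data/emails.db"})

print(local["port"], local["debug"], local["db_path"])  # 5000 False emails.db
print(cloud["port"], cloud["db_path"])                  # 8080 /data/emails.db
```

The resulting values could then be passed to `app.run(host="0.0.0.0", port=config["port"], debug=config["debug"])` and to `sqlite3.connect(config["db_path"])`. Keep in mind that a SQLite file on an ephemeral container filesystem is lost on redeploy, so a mounted volume or a managed database is usually needed in the cloud.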
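Future Directions
- Multilingual support: extend detection to Chinese-language spam
- Real-time monitoring: watch the mailbox and filter the inbox automatically
- Deep learning: use models such as LSTM or BERT to improve performance
- User feedback: let users flag misclassifications so the model keeps improving
- Visualization: add word clouds, feature importance charts, and similar views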
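Summary
This project walks through the full workflow from data processing and model training to web application development. Working through it teaches how machine learning is applied in a practical setting, along with the basics of frontend and backend development. The code is short and readable, which makes it a suitable starter project for learning full-stack development with Python.
A real deployment also has to address ongoing model updates, scalability, and security. Spam techniques keep evolving, so the model needs to be retrained regularly to stay effective.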