【基于XGBoost的钓鱼邮件智能识别与拦截系统】

包含邮件特征提取、XGBoost 训练、实时判定、拦截规则和简单 API。

python 复制代码
import re
import os
import json
import joblib
import numpy as np
import pandas as pd
from typing import Dict, List, Any
from dataclasses import dataclass, asdict
from urllib.parse import urlparse

from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report, roc_auc_score
from sklearn.feature_extraction.text import TfidfVectorizer
from scipy.sparse import hstack, csr_matrix

from xgboost import XGBClassifier
from flask import Flask, request, jsonify


# =========================
# 1. Configuration
# =========================
# Directory and file paths used by PhishingDetector.save() / load().
MODEL_DIR = "artifacts"
MODEL_PATH = os.path.join(MODEL_DIR, "xgb_phishing_model.joblib")
TFIDF_PATH = os.path.join(MODEL_DIR, "tfidf_vectorizer.joblib")
FEATURE_META_PATH = os.path.join(MODEL_DIR, "feature_meta.json")

# Import-time side effect: make sure the artifact directory exists.
os.makedirs(MODEL_DIR, exist_ok=True)

# Lowercase keywords; the feature extractor counts how many of them occur as
# substrings of the lowercased subject + body.
SUSPICIOUS_WORDS = [
    "urgent", "verify", "suspend", "password", "bank", "invoice", "click",
    "limited", "winner", "login", "confirm", "security", "alert", "reset",
    "account", "payment", "free", "prize", "immediately", "risk"
]

# Allow-list of domains, checked by exact membership (a subdomain such as
# "mail.google.com" is not covered by "google.com").
TRUSTED_DOMAINS = {
    "google.com", "microsoft.com", "apple.com", "amazon.com", "github.com",
    "openai.com", "outlook.com", "qq.com", "163.com"
}


# =========================
# 2. 数据结构
# =========================
@dataclass
class EmailSample:
    """One email as consumed by the feature extractor and the detector."""

    # Raw sender string; the domain is taken from the part after "@".
    sender: str
    subject: str
    body: str
    # Links supplied by the caller; when empty, links are extracted from the text.
    urls: List[str]
    has_attachment: int = 0
    attachment_count: int = 0
    label: int = 0  # 1=phishing, 0=legit


# =========================
# 3. 特征工程
# =========================
class EmailFeatureExtractor:
    """Turn ``EmailSample`` objects into a sparse feature matrix.

    Every row is the TF-IDF vector of the combined email text followed by the
    13 hand-crafted numeric columns produced by ``numeric_features``.
    """

    # http(s) links and bare "www." links, ending at whitespace, an angle
    # bracket or a double quote (single-quoted so the inner '"' is legal).
    _URL_PATTERN = re.compile(r'https?://[^\s<>"]+|www\.[^\s<>"]+')
    # A real URI scheme prefix such as "http://" (not merely the text "http").
    _SCHEME_PATTERN = re.compile(r"[a-z][a-z0-9+.\-]*://", re.IGNORECASE)
    _SENDER_DOMAIN_PATTERN = re.compile(r"@([A-Za-z0-9.-]+)")

    def __init__(self):
        self.tfidf = TfidfVectorizer(
            max_features=3000,
            ngram_range=(1, 2),
            stop_words="english",
        )
        self.fitted = False

    @staticmethod
    def extract_urls(text: str) -> List[str]:
        """Return every link found in ``text``, lowercased, in order of appearance.

        Returns an empty list for empty or ``None`` input.
        """
        if not text:
            return []
        return EmailFeatureExtractor._URL_PATTERN.findall(text.lower())

    @staticmethod
    def sender_domain(sender: str) -> str:
        """Return the lowercased domain after the first ``@``, or ``"unknown"``."""
        match = EmailFeatureExtractor._SENDER_DOMAIN_PATTERN.search(sender or "")
        return match.group(1).lower() if match else "unknown"

    @staticmethod
    def url_domain(url: str) -> str:
        """Return the lowercased host of ``url`` without a leading ``www.``.

        Userinfo and port are dropped, so ``http://google.com@evil.example``
        resolves to ``evil.example``. Returns ``""`` when no host can be found
        and ``"invalid"`` when the value cannot be parsed at all.
        """
        try:
            url = url.strip()
            if not EmailFeatureExtractor._SCHEME_PATTERN.match(url):
                url = "http://" + url
            host = urlparse(url).hostname or ""
        except (AttributeError, TypeError, ValueError):
            return "invalid"
        # Strip only a leading "www.", never an occurrence inside the host.
        return host[4:] if host.startswith("www.") else host

    def _resolve_urls(self, email: "EmailSample") -> List[str]:
        """Return the caller-supplied URLs, else those found in subject + body."""
        if email.urls:
            return email.urls
        return self.extract_urls(f"{email.subject or ''} {email.body or ''}")

    def numeric_features(self, email: "EmailSample") -> np.ndarray:
        """Compute the 13 hand-crafted features for one email.

        Column order: suspicious word count, exclamation count, digit count,
        uppercase ratio, URL count, URL domains differing from the sender
        domain, URL domains outside ``TRUSTED_DOMAINS``, sender-untrusted flag,
        has_attachment, attachment_count, body length, subject length,
        HTML hint flag.

        A ``None`` subject or body is treated as an empty string.
        """
        subject = email.subject or ""
        body = email.body or ""
        joined = subject + body
        text = f"{subject} {body}".lower()

        sender_domain = self.sender_domain(email.sender)
        urls = self._resolve_urls(email)
        url_domains = [self.url_domain(u) for u in urls]

        suspicious_word_count = sum(1 for w in SUSPICIOUS_WORDS if w in text)
        uppercase_ratio = sum(1 for c in joined if c.isupper()) / max(1, len(joined))
        mismatched_domains = sum(1 for d in url_domains if d and d != sender_domain)
        untrusted_url_domains = sum(1 for d in url_domains if d not in TRUSTED_DOMAINS)
        sender_untrusted = 0 if sender_domain in TRUSTED_DOMAINS else 1
        has_html_hint = int("<html" in text or "href=" in text)

        return np.array([
            suspicious_word_count,
            text.count("!"),
            sum(c.isdigit() for c in text),
            uppercase_ratio,
            len(urls),
            mismatched_domains,
            untrusted_url_domains,
            sender_untrusted,
            email.has_attachment,
            email.attachment_count,
            len(body),
            len(subject),
            has_html_hint,
        ], dtype=float)

    def _combine_text(self, email: "EmailSample") -> str:
        """Build the single text document fed to the TF-IDF vectorizer."""
        sender_domain = self.sender_domain(email.sender)
        urls = self._resolve_urls(email)
        return (
            f"SUBJECT {email.subject or ''} BODY {email.body or ''} "
            f"SENDER_DOMAIN {sender_domain} URLS {' '.join(urls)}"
        )

    def _stack(self, tfidf_matrix, emails: List["EmailSample"]):
        """Append the numeric feature columns to ``tfidf_matrix`` (CSR result)."""
        numeric = np.vstack([self.numeric_features(e) for e in emails])
        return hstack([tfidf_matrix, csr_matrix(numeric)], format="csr")

    def fit_transform(self, emails: List["EmailSample"]):
        """Fit the TF-IDF vocabulary on ``emails`` and return their feature matrix."""
        texts = [self._combine_text(e) for e in emails]
        tfidf_matrix = self.tfidf.fit_transform(texts)
        features = self._stack(tfidf_matrix, emails)
        self.fitted = True
        return features

    def transform(self, emails: List["EmailSample"]):
        """Return the feature matrix for ``emails`` using the fitted vocabulary.

        Raises:
            RuntimeError: if the extractor has not been fitted or loaded.
        """
        if not self.fitted:
            raise RuntimeError("Feature extractor is not fitted.")
        texts = [self._combine_text(e) for e in emails]
        return self._stack(self.tfidf.transform(texts), emails)


# =========================
# 4. 模型服务
# =========================
class PhishingDetector:
    """XGBoost phishing classifier with an action policy and persistence."""

    def __init__(self):
        self.extractor = EmailFeatureExtractor()
        self.model = XGBClassifier(
            n_estimators=250,
            max_depth=6,
            learning_rate=0.08,
            subsample=0.9,
            colsample_bytree=0.9,
            eval_metric="logloss",
            random_state=42,
            n_jobs=4,
        )
        self.is_trained = False

    def train(self, emails: List["EmailSample"]):
        """Train on a stratified 80/20 split of ``emails`` and report metrics.

        The feature extractor is fitted on the training part only, so the
        held-out rows do not influence the TF-IDF vocabulary or IDF weights
        that the evaluation is measured against.

        Returns:
            dict with ``"classification_report"`` (str) and ``"roc_auc"`` (float).
        """
        y = np.array([e.label for e in emails])

        # Split positions rather than a feature matrix so the extractor can be
        # fitted after the split.
        train_idx, test_idx = train_test_split(
            np.arange(len(emails)), test_size=0.2, random_state=42, stratify=y
        )
        X_train = self.extractor.fit_transform([emails[i] for i in train_idx])
        X_test = self.extractor.transform([emails[i] for i in test_idx])
        y_train, y_test = y[train_idx], y[test_idx]

        self.model.fit(X_train, y_train)
        self.is_trained = True

        pred = self.model.predict(X_test)
        prob = self.model.predict_proba(X_test)[:, 1]

        # Compute each metric once and reuse it for printing and the result.
        report = classification_report(y_test, pred, digits=4)
        roc_auc = float(roc_auc_score(y_test, prob))

        print("=== 模型评估 ===")
        print(report)
        print("ROC-AUC:", round(roc_auc, 4))

        return {
            "classification_report": report,
            "roc_auc": roc_auc,
        }

    def predict_one(self, email: "EmailSample") -> Dict[str, Any]:
        """Score one email.

        Returns:
            dict with ``is_phishing`` (probability >= 0.5),
            ``phishing_probability`` (rounded to 4 places), ``action`` and
            ``reasons``.

        Raises:
            RuntimeError: if no model has been trained or loaded.
        """
        if not self.is_trained:
            raise RuntimeError("Model is not trained or loaded.")

        X = self.extractor.transform([email])
        proba = float(self.model.predict_proba(X)[0, 1])

        return {
            "is_phishing": proba >= 0.5,
            "phishing_probability": round(proba, 4),
            "action": self.decide_action(proba),
            "reasons": self.explain(email),
        }

    def decide_action(self, proba: float) -> str:
        """Map a phishing probability to ``block``/``quarantine``/``flag``/``allow``."""
        if proba >= 0.85:
            return "block"
        if proba >= 0.60:
            return "quarantine"
        if proba >= 0.40:
            return "flag"
        return "allow"

    def explain(self, email: "EmailSample") -> List[str]:
        """Return rule-based, human-readable reasons for one email.

        These rules are independent of the model; they do not explain the
        predicted probability itself.
        """
        reasons = []
        text = f"{email.subject or ''} {email.body or ''}".lower()
        sender_domain = self.extractor.sender_domain(email.sender)
        urls = email.urls if email.urls else self.extractor.extract_urls(text)

        if any(w in text for w in ["verify", "password", "reset", "urgent", "bank", "login"]):
            reasons.append("邮件正文或标题包含高风险诱导词")
        if sender_domain not in TRUSTED_DOMAINS:
            reasons.append("发件人域名不在可信名单中")
        if len(urls) > 0:
            reasons.append("邮件中包含链接,存在跳转风险")
            if any(self.extractor.url_domain(u) != sender_domain for u in urls):
                reasons.append("链接域名与发件人域名不一致")
        if email.has_attachment:
            reasons.append("邮件带有附件,需要进一步审查")
        if not reasons:
            reasons.append("未发现明显高风险特征")
        return reasons

    def save(self):
        """Persist the model, the TF-IDF vectorizer and a small metadata file.

        Raises:
            RuntimeError: if the model has not been trained or loaded.
        """
        if not self.is_trained:
            raise RuntimeError("Nothing to save. Train the model first.")
        # The directory is created at import time, but may have been removed since.
        os.makedirs(MODEL_DIR, exist_ok=True)
        joblib.dump(self.model, MODEL_PATH)
        joblib.dump(self.extractor.tfidf, TFIDF_PATH)
        with open(FEATURE_META_PATH, "w", encoding="utf-8") as f:
            json.dump({"fitted": True}, f, ensure_ascii=False, indent=2)

    def load(self):
        """Load the model and vectorizer written by ``save``.

        SECURITY: joblib files are pickles; loading one executes arbitrary
        code. Only load artifacts produced by a trusted ``save`` call.
        """
        self.model = joblib.load(MODEL_PATH)
        self.extractor.tfidf = joblib.load(TFIDF_PATH)
        self.extractor.fitted = True
        self.is_trained = True


# =========================
# 5. 示例数据
# =========================
def build_demo_dataset() -> List[EmailSample]:
    """Return the demo training set: 8 seed emails repeated 60 times (480 samples).

    Samples are emitted round by round in seed order, and every sample owns
    its own ``urls`` list.
    """
    # (sender, subject, body, urls, has_attachment, attachment_count, label)
    seeds = [
        (
            "security@paypa1-support.com",
            "Urgent: Verify your account immediately",
            "Your account will be suspended. Click here to verify your password: http://paypa1-login-check.com",
            ["http://paypa1-login-check.com"],
            0, 0, 1,
        ),
        (
            "it-support@company-alerts.net",
            "Password reset required",
            "We detected unusual login activity. Reset your password now at http://secure-reset-login.net",
            ["http://secure-reset-login.net"],
            0, 0, 1,
        ),
        (
            "ceo.office@external-consult.net",
            "Confidential invoice attached",
            "Please review the attached invoice and process payment today.",
            [],
            1, 1, 1,
        ),
        (
            "notifications@github.com",
            "Your pull request has been reviewed",
            "A reviewer left comments on your pull request. Visit https://github.com to review.",
            ["https://github.com"],
            0, 0, 0,
        ),
        (
            "no-reply@microsoft.com",
            "Security info updated",
            "Your security information was updated successfully. If this wasn't you, visit https://microsoft.com/security",
            ["https://microsoft.com/security"],
            0, 0, 0,
        ),
        (
            "hr@company.com",
            "Interview schedule confirmation",
            "Please confirm your interview slot for next Monday.",
            [],
            0, 0, 0,
        ),
        (
            "service@amaz0n-billing.com",
            "Payment failed - urgent action needed",
            "Your payment failed. Update your account now: http://amaz0n-billing-check.com",
            ["http://amaz0n-billing-check.com"],
            0, 0, 1,
        ),
        (
            "newsletter@openai.com",
            "Product update newsletter",
            "Read about new product updates on https://openai.com/blog",
            ["https://openai.com/blog"],
            0, 0, 0,
        ),
    ]

    # Repeat the seeds so the demo has enough rows to train and split on.
    samples = []
    for _ in range(60):
        samples.extend(
            EmailSample(
                sender=sender,
                subject=subject,
                body=body,
                urls=list(urls),  # fresh list per sample
                has_attachment=has_attachment,
                attachment_count=attachment_count,
                label=label,
            )
            for sender, subject, body, urls, has_attachment, attachment_count, label in seeds
        )
    return samples


# =========================
# 6. 训练入口
# =========================
def train_and_save_demo_model():
    """Train a detector on the demo dataset, persist it and return the metrics."""
    detector = PhishingDetector()
    metrics = detector.train(build_demo_dataset())
    detector.save()
    print("模型已保存到 artifacts/ 目录")
    return metrics


# =========================
# 7. 拦截服务 API
# =========================
app = Flask(__name__)
service_detector = PhishingDetector()

# Load persisted artifacts at import time, but only when both files exist.
if all(os.path.exists(path) for path in (MODEL_PATH, TFIDF_PATH)):
    service_detector.load()


@app.route("/health", methods=["GET"])
def health():
    """Report service status and whether a trained model is available."""
    status = {"status": "ok", "model_loaded": service_detector.is_trained}
    return jsonify(status)


@app.route("/predict", methods=["POST"])
def predict():
    """Score the email described by the JSON request body.

    Expected keys (all optional): ``sender``, ``subject``, ``body``, ``urls``
    (list), ``has_attachment``, ``attachment_count``. A JSON ``null`` is
    treated like a missing key.

    Returns the ``predict_one`` result as JSON, a 400 response for malformed
    input, or a 500 response when no model is loaded.
    """
    if not service_detector.is_trained:
        return jsonify({"error": "model not loaded"}), 500

    # silent=True yields None instead of raising on an unparsable body, so all
    # malformed input is answered with the same JSON error shape.
    payload = request.get_json(force=True, silent=True)
    if not isinstance(payload, dict):
        return jsonify({"error": "request body must be a JSON object"}), 400

    urls = payload.get("urls") or []
    if not isinstance(urls, list):
        return jsonify({"error": "urls must be a list of strings"}), 400

    try:
        has_attachment = int(payload.get("has_attachment") or 0)
        attachment_count = int(payload.get("attachment_count") or 0)
    except (TypeError, ValueError):
        return jsonify(
            {"error": "has_attachment and attachment_count must be integers"}
        ), 400

    email = EmailSample(
        sender=str(payload.get("sender") or ""),
        subject=str(payload.get("subject") or ""),
        body=str(payload.get("body") or ""),
        urls=[str(u) for u in urls],
        has_attachment=has_attachment,
        attachment_count=attachment_count,
        label=0,
    )

    return jsonify(service_detector.predict_one(email))


# =========================
# 8. 模拟邮件网关拦截流程
# =========================
def intercept_email(email_data: Dict[str, Any]) -> Dict[str, Any]:
    """Run one email through the detector and describe the gateway operation.

    Returns a dict with the email's sender and subject, the raw detector
    decision, and the operation text for the decided action.

    Raises:
        RuntimeError: if the service model has not been loaded.
    """
    if not service_detector.is_trained:
        raise RuntimeError("Service model is not loaded.")

    # Operation text for each action the detector can return.
    operations = {
        "allow": "允许投递",
        "flag": "添加风险标签后投递",
        "quarantine": "隔离到安全区等待人工复核",
        "block": "直接拦截并告警",
    }

    field = email_data.get
    email = EmailSample(
        sender=field("sender", ""),
        subject=field("subject", ""),
        body=field("body", ""),
        urls=field("urls", []),
        has_attachment=int(field("has_attachment", 0)),
        attachment_count=int(field("attachment_count", 0)),
        label=0,
    )

    decision = service_detector.predict_one(email)
    summary = {"sender": email.sender, "subject": email.subject}

    return {
        "email": summary,
        "decision": decision,
        "operation": operations[decision["action"]],
    }


# =========================
# 9. 本地测试
# =========================
if __name__ == "__main__":
    # Train a demo model on first run, when either artifact file is missing.
    artifacts_ready = os.path.exists(MODEL_PATH) and os.path.exists(TFIDF_PATH)
    if not artifacts_ready:
        train_and_save_demo_model()
        service_detector.load()

    # Smoke-test the interception flow with one phishing-style email.
    test_email = dict(
        sender="alert@micr0soft-reset.net",
        subject="Urgent account verification",
        body="Your account is at risk. Verify now using http://micr0soft-login-check.net",
        urls=["http://micr0soft-login-check.net"],
        has_attachment=0,
        attachment_count=0,
    )

    print("=== 本地拦截测试 ===")
    outcome = intercept_email(test_email)
    print(json.dumps(outcome, ensure_ascii=False, indent=2))

    # Start the API, listening on all interfaces at port 5000.
    app.run(host="0.0.0.0", port=5000, debug=False)

上面是一个可运行的 Python 版"基于 XGBoost 的钓鱼邮件智能识别与拦截系统"示例代码。它包含这些核心能力:

  1. 邮件特征提取

    • 发件人域名
    • 标题/正文文本
    • URL 数量,以及链接域名是否与发件人域名一致、是否在可信名单中
    • 可疑关键词
    • 大写比例、数字数量、附件数量等
  2. XGBoost 分类模型

    • 使用 TF-IDF + 数值特征 组合建模
    • 输出钓鱼概率
  3. 智能拦截策略

    • allow:放行
    • flag:风险标记
    • quarantine:隔离
    • block:直接拦截
  4. API 服务

    • POST /predict 实时判定邮件风险
    • GET /health 检查服务状态

运行前安装依赖

bash 复制代码
pip install xgboost scikit-learn pandas numpy scipy flask joblib

启动方式

bash 复制代码
python your_file.py

首次启动会:

  • 自动生成演示训练集
  • 训练 XGBoost 模型
  • 保存模型到 artifacts/
  • 启动 Flask API,端口 5000

调用示例

bash 复制代码
curl -X POST http://127.0.0.1:5000/predict \
-H "Content-Type: application/json" \
-d '{
  "sender": "alert@micr0soft-reset.net",
  "subject": "Urgent account verification",
  "body": "Your account is at risk. Verify now using http://micr0soft-login-check.net",
  "urls": ["http://micr0soft-login-check.net"],
  "has_attachment": 0,
  "attachment_count": 0
}'

返回示例:

json 复制代码
{
  "is_phishing": true,
  "phishing_probability": 0.93,
  "action": "block",
  "reasons": [
    "邮件正文或标题包含高风险诱导词",
    "发件人域名不在可信名单中",
    "邮件中包含链接,存在跳转风险",
    "链接域名与发件人域名不一致"
  ]
}
相关推荐
用户83562907805111 小时前
Python 操作 PDF 附件:添加、查看与管理指南
后端·python
宇宙之一粟18 小时前
乐企版式文件生成平台
java·后端·python
学测绘的小杨1 天前
CompassFusion:一个从 GNSS 到 GNSS/INS 组合导航的独立工程包
python
zzzzzz3102 天前
当产品经理说这个很简单:我用Python自动化处理奇葩需求的实战指南
python·pycharm·产品经理
雪隐2 天前
个人电脑玩AI-06让5060 Ti给你打工——不光能画画,Qwen3-TTS还能学人说话,连我老板都信了!
人工智能·后端·python
兵慌码乱2 天前
面向桌面端的资产管理系统分层架构设计与核心模块实现
python·系统架构·sqlite·pyqt5·数据库设计·桌面应用开发·mvc架构
hboot2 天前
AI工程师第三课 - 机器学习基础
python·scikit-learn·kaggle
顾林海2 天前
Agent入门阶段-编程基础-Python:流程控制
python·agent·ai编程
呱呱复呱呱3 天前
Django CBV 源码解读:一个请求是怎么找到你的 get() 方法的
python·django
曲幽3 天前
刚部署的 LibreTranslate 频频翻车?我掏出了 20 年前的 StarDict 词典,用 FastAPI 搭了个本地词典翻译 API
python·fastapi·web·translate·goldendict·libretranslate·stardict·pystardict