人工智能之核心基础 机器学习 第十六章 模型优化

人工智能之核心基础 机器学习

第十六章 模型优化


文章目录

  • [人工智能之核心基础 机器学习](#人工智能之核心基础 机器学习)
    • [16.1 过拟合与欠拟合](#16.1 过拟合与欠拟合)
      • [🎯 定义(用"考试"比喻)](#🎯 定义(用“考试”比喻))
      • [🔍 表现与原因](#🔍 表现与原因)
    • [16.2 解决过拟合的方法](#16.2 解决过拟合的方法)
      • [✅ 五大核心策略](#✅ 五大核心策略)
        • [1. **正则化(Regularization)**](#1. 正则化(Regularization))
        • [2. **交叉验证(Cross-Validation)**](#2. 交叉验证(Cross-Validation))
        • [3. **早停(Early Stopping)**](#3. 早停(Early Stopping))
        • [4. **增加数据量**](#4. 增加数据量)
        • [5. **数据增强(Data Augmentation)**](#5. 数据增强(Data Augmentation))
    • [16.3 超参数调优](#16.3 超参数调优)
      • [🔧 三大主流方法](#🔧 三大主流方法)
      • [📜 调优流程(标准实践)](#📜 调优流程(标准实践))
    • [16.4 模型融合](#16.4 模型融合)
      • [🤝 为什么融合有效?](#🤝 为什么融合有效?)
      • [1. 简单融合](#1. 简单融合)
      • [2. 进阶融合:堆叠(Stacking)](#2. 进阶融合:堆叠(Stacking))
    • [16.5 半监督/自监督模型优化技巧](#16.5 半监督/自监督模型优化技巧)
      • [🔑 三大关键优化点](#🔑 三大关键优化点)
        • [1. **伪标签筛选策略**](#1. 伪标签筛选策略)
        • [2. **自监督前置任务调优**](#2. 自监督前置任务调优)
        • [3. **无标签数据利用率提升**](#3. 无标签数据利用率提升)
    • [16.6 模型选择策略](#16.6 模型选择策略)
      • [🧭 选型决策树](#🧭 选型决策树)
      • [📊 模型选择对照表](#📊 模型选择对照表)
      • [💡 半监督/自监督 vs 传统范式](#💡 半监督/自监督 vs 传统范式)
    • [🎯 本章总结:泛化能力提升 Checklist](#🎯 本章总结:泛化能力提升 Checklist)
  • 资料关注

16.1 过拟合与欠拟合

🎯 定义(用"考试"比喻)

状态 训练表现 测试表现 比喻
欠拟合 "课本都没看懂,考试自然不会"
理想状态 "真正学会了知识,举一反三"
过拟合 极好 "死记硬背考题答案,换题就不会"

🔍 表现与原因

问题 典型表现 根本原因
欠拟合 - 训练误差高- 模型太简单(如线性模型拟合非线性) - 模型容量不足- 特征太少/质量差
过拟合 - 训练误差≈0- 验证误差远高于训练误差- 模型复杂度高 - 数据量少- 模型太复杂- 噪声多

💡 可视化诊断

python 复制代码
import matplotlib.pyplot as plt

plt.plot(train_losses, label='训练损失')
plt.plot(val_losses, label='验证损失')
plt.legend()
if val_loss starts rising while train_loss keeps falling → 过拟合!

16.2 解决过拟合的方法

✅ 五大核心策略

1. 正则化(Regularization)
  • L1正则(Lasso) \\text{Loss} + \\lambda \\sum \|w_i\| → 自动特征选择
  • L2正则(Ridge) \\text{Loss} + \\lambda \\sum w_i\^2 → 权重衰减,防过大
python 复制代码
from sklearn.linear_model import LogisticRegression

# L2正则(默认)
lr_l2 = LogisticRegression(penalty='l2', C=1.0)  # C越小,正则越强

# L1正则
lr_l1 = LogisticRegression(penalty='l1', solver='liblinear', C=0.1)
2. 交叉验证(Cross-Validation)
  • 防止模型评估"运气好"
  • K折交叉验证:数据分K份,轮流做验证集
python 复制代码
from sklearn.model_selection import cross_val_score

scores = cross_val_score(model, X, y, cv=5)  # 5折
print(f"平均准确率: {scores.mean():.2%} ± {scores.std():.2%}")
3. 早停(Early Stopping)
  • 训练时监控验证损失,不再下降时停止
python 复制代码
# PyTorch示例
best_val_loss = float('inf')
patience = 5
counter = 0

for epoch in range(100):
    train(...)
    val_loss = validate(...)
    
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        counter = 0
        save_model()
    else:
        counter += 1
        if counter >= patience:
            break  # 早停
4. 增加数据量
  • 最根本的解决方法!但成本高
  • 半监督/自监督可缓解此问题
5. 数据增强(Data Augmentation)
  • 人工扩充数据多样性
  • 图像:旋转、裁剪、颜色抖动
  • 文本:同义词替换、随机删除
  • 表格:加噪声、SMOTE(少数类过采样)
python 复制代码
# 图像增强(用于训练)
from torchvision import transforms

train_transform = transforms.Compose([
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(10),
    transforms.ToTensor()
])

⚠️ 注意 :测试时不要数据增强!


16.3 超参数调优

🔧 三大主流方法

方法 原理 优点 缺点
网格搜索 遍历所有组合 全面 计算爆炸(维度灾难)
随机搜索 随机采样组合 高效,常优于网格 可能漏掉最优
贝叶斯优化 基于历史结果智能采样 最高效 实现复杂

📜 调优流程(标准实践)

python 复制代码
from sklearn.model_selection import RandomizedSearchCV
from scipy.stats import uniform, randint

# 定义参数空间
param_dist = {
    'n_estimators': randint(50, 200),
    'max_depth': [3, 5, 7, None],
    'learning_rate': uniform(0.01, 0.3)
}

# 随机搜索 + 交叉验证
from xgboost import XGBClassifier
model = XGBClassifier()

random_search = RandomizedSearchCV(
    model, param_dist, n_iter=50, 
    cv=5, scoring='accuracy', 
    random_state=42, n_jobs=-1
)

random_search.fit(X_train, y_train)

print("最佳参数:", random_search.best_params_)
print("最佳得分:", random_search.best_score_)

💡 进阶工具Optuna, Hyperopt(支持贝叶斯优化)


16.4 模型融合

🤝 为什么融合有效?

"三个臭皮匠,顶个诸葛亮" ------ 不同模型犯错方式不同,融合可互补!


1. 简单融合

投票法(分类)
  • 硬投票:多数表决
  • 软投票:平均预测概率
python 复制代码
from sklearn.ensemble import VotingClassifier
from sklearn.linear_model import LogisticRegression
from sklearn.ensemble import RandomForestClassifier
from sklearn.svm import SVC

clf1 = LogisticRegression()
clf2 = RandomForestClassifier()
clf3 = SVC(probability=True)

voting = VotingClassifier(
    estimators=[('lr', clf1), ('rf', clf2), ('svc', clf3)],
    voting='soft'  # 使用概率
)
voting.fit(X_train, y_train)
平均法(回归)
python 复制代码
pred = (model1.predict(X) + model2.predict(X) + model3.predict(X)) / 3

2. 进阶融合:堆叠(Stacking)

  • 基模型:多个不同模型(如LR、RF、SVM)
  • 元模型:学习如何组合基模型的输出
python 复制代码
from sklearn.ensemble import StackingClassifier

stacking = StackingClassifier(
    estimators=[('lr', lr), ('rf', rf), ('svc', svc)],
    final_estimator=LogisticRegression(),  # 元模型
    cv=5  # 用5折生成元特征
)
stacking.fit(X_train, y_train)

效果:通常比单一模型提升1~3%准确率!


16.5 半监督/自监督模型优化技巧

🔑 三大关键优化点

1. 伪标签筛选策略
  • 动态阈值:初期阈值高(如0.95),后期降低
  • 课程学习:先学简单样本(高置信度),再学难样本
  • 一致性正则:对同一无标签样本做两次增强,预测应一致
python 复制代码
# 动态阈值示例
initial_threshold = 0.95
final_threshold = 0.8
threshold = initial_threshold - (initial_threshold - final_threshold) * (epoch / max_epochs)
high_conf = proba.max(axis=1) > threshold
2. 自监督前置任务调优
  • 掩码比例:图像MAE常用75%,文本BERT用15%
  • 增强强度:对比学习中,增强太弱→学不到东西,太强→两个视角无关
  • 损失函数:SimSiam用余弦相似度,MAE用MSE
3. 无标签数据利用率提升
  • 分批加入:先用高质量无标签数据,再逐步扩大
  • 置信度加权:高置信伪标签权重高,低置信权重低
  • 对抗训练:让模型对输入扰动鲁棒

16.6 模型选择策略

🧭 选型决策树

大量全标注
少量标签+大量无标签
完全无标签
发现结构
预训练特征
有多少标签数据?
监督学习
半监督学习
目标任务是什么?
无监督学习
自监督学习


📊 模型选择对照表

场景 推荐模型 理由
表格数据 + 少量标签 XGBoost + 伪标签 树模型抗噪,伪标签简单有效
图像 + 无标签 MAE / SimSiam 预训练 视觉自监督SOTA
文本 + 少量标签 BERT微调 + 半监督 利用预训练语言知识
高维稀疏数据 PCA + Logistic Regression 降维去噪,线性模型稳定
非球形簇 + 无标签 DBSCAN 捕捉任意形状簇

💡 半监督/自监督 vs 传统范式

维度 监督学习 半监督 自监督
数据需求 大量标注 少量标注+大量无标注 仅无标注
开发成本 高(标注) 低(但需设计任务)
适用阶段 成熟业务 探索期/标注瓶颈 预训练/冷启动
典型产出 直接可用模型 改进版监督模型 通用特征提取器

实践
冷启动阶段 :自监督预训练
有少量标注后 :半监督微调
标注充足后:纯监督精调


🎯 本章总结:泛化能力提升 Checklist

问题 解决方案 工具
过拟合 正则化、早停、数据增强 sklearn, 手动实现
欠拟合 增加模型复杂度、特征工程 更深网络、新特征
超参数不佳 随机搜索、贝叶斯优化 RandomizedSearchCV, Optuna
单模型不稳定 模型融合 VotingClassifier, StackingClassifier
半监督效果差 优化伪标签策略、一致性正则 动态阈值、增强一致性
自监督特征弱 调整前置任务、增强策略 掩码比例、对比学习温度

📘 建议

  1. 先保证数据质量(第15章)
  2. 从简单模型开始(如逻辑回归)
  3. 用交叉验证评估
  4. 逐步引入复杂技术(正则→融合→半监督)

🌟 提醒
泛化能力不是靠一个神奇算法,而是系统性工程------数据、模型、训练、评估缺一不可!

资料关注

公众号:咚咚王

gitee:https://gitee.com/wy18585051844/ai_learning

《Python编程:从入门到实践》

《利用Python进行数据分析》

《算法导论中文第三版》

《概率论与数理统计(第四版) (盛骤) 》

《程序员的数学》

《线性代数应该这样学第3版》

《微积分和数学分析引论》

《(西瓜书)周志华-机器学习》

《TensorFlow机器学习实战指南》

《Sklearn与TensorFlow机器学习实用指南》

《模式识别(第四版)》

《深度学习 deep learning》伊恩·古德费洛著 花书

《Python深度学习第二版(中文版)【纯文本】 (登封大数据 (Francois Choliet)) (Z-Library)》

《深入浅出神经网络与深度学习+(迈克尔·尼尔森(Michael+Nielsen)》

《自然语言处理综论 第2版》

《Natural-Language-Processing-with-PyTorch》

《计算机视觉-算法与应用(中文版)》

《Learning OpenCV 4》

《AIGC:智能创作时代》杜雨+&+张孜铭

《AIGC原理与实践:零基础学大语言模型、扩散模型和多模态模型》

《从零构建大语言模型(中文版)》

《实战AI大模型》

《AI 3.0》

相关推荐
电商API_180079052472 小时前
1688商品详情采集API全解析:技术原理、实操指南与业务落地
大数据·前端·人工智能·网络爬虫
向上的车轮2 小时前
麦肯锡《智能体、机器人与我们:AI时代的技能协作》
人工智能·机器人
叫我:松哥2 小时前
基于Flask框架开发的二手房数据分析与推荐管理平台,集成大数据分析、机器学习预测和智能推荐技术
大数据·python·深度学习·机器学习·数据分析·flask
2501_945837432 小时前
数字经济的 “安全基石”—— 云服务器零信任架构如何筑牢数据安全防线
人工智能
2501_942191772 小时前
【深度学习应用】香蕉镰刀菌症状识别与分类:基于YOLO13-C3k2-MBRConv5模型的实现与分析
人工智能·深度学习·分类
Coder_Boy_2 小时前
基于SpringAI的在线考试系统-DDD(领域驱动设计)核心概念及落地架构全总结
java·大数据·人工智能·spring boot·架构·ddd·tdd
AI小怪兽2 小时前
YOLO26:面向实时目标检测的关键架构增强与性能基准测试
人工智能·yolo·目标检测·计算机视觉·目标跟踪·架构
知乎的哥廷根数学学派2 小时前
基于卷积特征提取和液态神经网络的航空发动机剩余使用寿命预测算法(python)
人工智能·pytorch·python·深度学习·神经网络·算法
高洁012 小时前
AIGC技术与进展(2)
人工智能·python·深度学习·机器学习·数据挖掘