《Hands-On Machine Learning with Scikit-Learn, Keras & TensorFlow》第一章读书笔记

第一部分:理论篇

1. 什么是机器学习?

核心定义

机器学习是让计算机从数据中学习的科学,而无需显式编程。

经典定义

  • Arthur Samuel (1959):

"让计算机无需明确编程就具备学习能力"

  • Tom Mitchell 的工程定义:

"如果一个程序通过经验 E 在某任务 T 上获得性能改善 P,则称其学习"


2. 为什么使用机器学习?

传统编程方法(以垃圾邮件检测为例):

  • 研究垃圾邮件特征
  • 编写规则检测这些特征
  • 测试并重复改进规则

缺点:

❗ 规则复杂度指数级增长

❗ 维护成本高昂

❗ 无法适应新型垃圾邮件

机器学习方法:

✅ 自动从样本中学习特征模式

✅ 动态适应新型垃圾邮件

✅ 维护和扩展成本显著降低


3. 机器学习的应用类型

适用场景:

  1. 传统方法难以制定明确规则的问题
  2. 传统解决方案过于复杂的问题
  3. 需适应环境变化的动态问题
  4. 从复杂数据中提取洞察

典型应用:

  • 图像分类
  • 语音识别
  • 垃圾邮件检测
  • 客户细分
  • 预测分析
  • 异常检测

4. 机器学习系统的类型

按训练监督划分

类型 特征 典型任务 示例
监督学习 训练数据含标签 分类/回归 垃圾邮件分类
无监督学习 训练数据无标签 聚类/降维 客户分群
半监督学习 部分数据有标签 混合任务 照片分类服务
强化学习 基于奖惩机制 动态决策 游戏AI/机器人行走

按学习方式划分:

  • 批量学习:全量数据训练 ➔ 离线更新
  • 在线学习:流式数据训练 ➔ 实时适应

按工作方式划分:

  • 基于实例学习:记忆样本 ➔ 相似度匹配预测
  • 基于模型学习:构建模型 ➔ 参数化预测

5. 机器学习的主要挑战

数据挑战

  1. 数据量不足
  2. 样本代表性缺失
  3. 低质量数据(噪声/错误)
  4. 无关特征干扰

算法挑战

  • 过拟合:模型过度记忆训练数据细节
  • 欠拟合:模型未能捕捉数据基本规律

6. 测试和验证

核心原则

  • 数据划分:训练集(70-80%)/验证集(10-15%)/测试集(10-15%)
  • 交叉验证:K折交叉验证提升评估可靠性
  • 超参数调优:必须使用独立验证集

7. 实践建议

✅ 构建端到端流程:数据收集→预处理→建模→评估→部署

✅ 先尝试简单模型(如线性回归)建立基线

✅ 数据质量 > 算法复杂度

✅ 选择与业务目标匹配的评估指标(如F1分数/ROC-AUC)

严防数据泄露:确保训练数据不包含测试集信息


第二部分:习题篇

基础概念题

Q1. 如何定义机器学习?

机器学习是关于构建能够从数据中学习的系统。学习的含义是:在某个任务上,根据某种性能度量标准,不断提升表现。


Q2. 机器学习在哪些场景中表现突出?

✅ 四类典型场景:

  1. 无明确算法解决方案的问题
  2. 替代人工调优的复杂规则系统
  3. 需动态适应环境变化的系统
  4. 从数据中挖掘隐含规律(数据挖掘)

Q3. 什么是带标签的训练集?

包含每个样本的期望输出(即标签)的训练数据集。例如:邮件数据集中的每封邮件被标注为"垃圾邮件"或"非垃圾邮件"。


学习类型辨析

Q4. 监督学习的典型任务是什么?

两类核心任务:

  • 回归(预测连续值,如房价预测)
  • 分类(预测离散类别,如图像识别)

Q5. 无监督学习的常见任务有哪些?

四类典型应用:

  1. 聚类(客户分群)
  2. 可视化(高维数据降维展示)
  3. 降维(特征压缩)
  4. 关联规则学习(购物篮分析)

Q6. 如何选择机器人行走地形的学习算法?

强化学习为首选方案:

  • 通过奖惩机制学习动态决策
  • 天然适配环境交互场景

⚠️ 注:监督学习需预定义行为标签,半监督学习需部分标注,二者均不直接适配该场景。


Q7. 客户分群应使用哪种算法?

分情况讨论:

  • 无预定义群体 → 聚类算法(无监督学习,如K-Means)
  • 已知群体类别 → 分类算法(监督学习,需标注数据)

系统特性与挑战

Q8. 在线学习系统的核心特征是什么?

⚡ 关键特性:

  • 增量学习:持续处理数据流
  • 实时适应:动态更新模型参数
  • 内存高效:无需存储历史全量数据

典型场景:金融实时风控、新闻推荐系统


Q9. 什么是外部存储(Out-of-Core)学习?

定义:

处理超出内存容量的超大规模数据时,将数据分批次加载到内存中进行增量训练的技术。

实现方式:

  • 数据分块(Mini-Batch)
  • 结合在线学习策略

Q10. 基于实例学习的原理是什么?

工作机制:

  1. 存储全部训练样本
  2. 新数据输入时,计算与存储样本的相似度(如欧氏距离)
  3. 根据最相似样本的标签进行预测

典型算法:K-近邻(KNN)


模型与参数解析

Q11. 模型参数 vs 超参数的区别?

特性 模型参数 超参数
定义 模型内部的权重(如线性回归斜率) 控制训练过程的配置参数
学习方式 通过训练数据自动优化 人工设定或自动调优(如网格搜索)
示例 神经网络权重、SVM支持向量 学习率、正则化系数、树的最大深度

Q12. 基于模型的学习如何运作?

三阶段过程:

  1. 目标:寻找最优模型参数,最小化损失函数(含正则化项)
  2. 训练策略:梯度下降、随机优化等
  3. 预测方式:将新数据输入参数化模型函数(如 )

实战挑战与解决方案

Q13. 机器学习的四大核心挑战?

⚠️ 关键瓶颈:

  1. 数据不足 → 模型难以捕捉规律
  2. 数据质量差 → 噪声干扰模型学习
  3. 样本不具代表性 → 泛化能力低下
  4. 特征信息量不足 → 模型性能天花板受限

Q14. 过拟合的识别与解决方案

问题识别

  • 训练集准确率高 ➔ 测试集准确率显著下降

三类解决方案

  1. 数据增强:收集更多数据/数据扩增(如图像旋转、添加噪声)
  2. 模型简化:减少网络层数、降低多项式次数、增加正则化
  3. 数据清洗:剔除异常样本、修复标签错误

Q15. 测试集的核心作用与误用风险

核心作用

  • 作为"未知数据"的代理,评估模型泛化性能

误用风险

  • 若用测试集调参 → 模型间接学习测试集分布 → 泛化误差评估虚高
  • 后果:线上部署后性能显著低于预期

Q16. 验证集 vs 训练-开发集的用途

数据集类型 核心用途 典型使用场景
验证集 模型选择与超参数调优 比较不同算法/参数组合效果
训练-开发集 检测训练数据与验证数据的分布偏移 数据分布不一致时的诊断工具

Q17. 数据泄露的预防策略

关键原则:

  • 严格隔离训练集、验证集、测试集
  • 避免在预处理阶段使用全量数据(如归一化需仅基于训练集统计量)
  • 时序数据需按时间顺序划分(禁止随机划分)

Q18. 如何诊断训练数据与验证数据的分布偏移?

训练-开发集(Train-Dev Set)的使用方法:

  1. 数据划分:
  • 训练集(70%) → 模型训练
  • 训练-开发集(10%) → 从训练集划分,用于检测过拟合
  • 验证集(20%) → 保持独立分布
  1. 诊断逻辑:
  • 若模型在 训练集 表现好,但在 训练-开发集 表现差 → 过拟合
  • 若模型在 训练-开发集 表现好,但在 验证集 表现差 → 数据分布偏移
  1. 解决方案:
  • 调整训练数据,使其更接近真实场景分布
  • 使用数据增强技术模拟验证集特征

Q19. 为什么不能直接用测试集调参?

⚠️ 核心风险:

  • 测试集污染:超参数调优本质上是"学习"测试集分布的过程
  • 性能虚高:模型会过度适配测试集特性,导致:
  • 线上部署效果显著低于测试集指标
  • 失去对真实未知数据的泛化能力

正确流程:

  1. 原始数据 → 划分训练集、验证集、测试集
  2. 用训练集训练模型
  3. 用验证集调参/选择模型
  4. 用测试集仅做最终评估(且仅限一次!)

第三部分:实践篇

「链接」

环境设置

验证Python版本要求(3.7及以上)

复制代码
import sys

assert sys.version_info >= (3, 7)  # 检查Python版本是否满足要求

验证Scikit-Learn版本要求(≥1.0.1)

复制代码
from packaging import version
import sklearn

assert version.parse(sklearn.__version__) >= version.parse("1.0.1")  # 检查sklearn版本是否达标

设置matplotlib默认字体大小,优化图表显示效果

复制代码
import matplotlib.pyplot as plt

plt.rc('font', size=12)            # 全局字体大小
plt.rc('axes', labelsize=14, titlesize=14)  # 坐标轴标签和标题大小
plt.rc('legend', fontsize=12)       # 图例字体大小
plt.rc('xtick', labelsize=10)       # X轴刻度标签大小
plt.rc('ytick', labelsize=10)       # Y轴刻度标签大小

设置随机种子保证结果可复现

复制代码
import numpy as np

np.random.seed(42)  # 固定随机数生成器的种子

代码示例1-1

导入必要库并加载数据集,展示GDP与生活满意度的关系

复制代码
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from sklearn.linear_model import LinearRegression

# 下载并准备数据
data_root = "https://github.com/ageron/data/raw/main/"
lifesat = pd.read_csv(data_root + "lifesat/lifesat.csv")
X = lifesat[["GDP per capita (USD)"]].values  # 提取GDP特征
y = lifesat[["Life satisfaction"]].values     # 提取目标变量

# 可视化数据
lifesat.plot(kind='scatter', grid=True,
             x="GDP per capita (USD)", y="Life satisfaction")
plt.axis([23_500, 62_500, 4, 9])  # 设置坐标轴范围
plt.show()

# 选择线性回归模型
model = LinearRegression()

# 训练模型
model.fit(X, y)

# 预测塞浦路斯数据
X_new = [[37_655.2]]  # 塞浦路斯2020年人均GDP
print(model.predict(X_new)) # 输出预测值 [[6.30165767]]

输出结果:

复制代码
[[6.30165767]]

将线性回归模型替换为K近邻回归(k=3)

复制代码
from sklearn.neighbors import KNeighborsRegressor

# 创建KNN回归器实例
model = KNeighborsRegressor(n_neighbors=3)

# 训练模型
model.fit(X, y)

# 进行预测
print(model.predict(X_new)) # 输出预测值 [[6.33333333]]

输出结果:

复制代码
[[6.33333333]]

数据与图表生成

以下是生成lifesat.csv数据集的代码。

创建图像保存函数

复制代码
from pathlib import Path

# 图像保存路径设置
IMAGES_PATH = Path() / "images" / "fundamentals"
IMAGES_PATH.mkdir(parents=True, exist_ok=True)  # 创建多级目录

def save_fig(fig_id, tight_layout=True, fig_extension="png", resolution=300):
    path = IMAGES_PATH / f"{fig_id}.{fig_extension}"
    if tight_layout:
        plt.tight_layout()  # 自动调整子图间距
    plt.savefig(path, format=fig_extension, dpi=resolution)  # 保存图像

加载和处理生活满意度数据

自动下载原始数据文件

复制代码
import urllib.request

datapath = Path() / "datasets" / "lifesat"
datapath.mkdir(parents=True, exist_ok=True)  # 创建数据存储目录

data_root = "https://github.com/ageron/data/raw/main/"
for filename in ("oecd_bli.csv", "gdp_per_capita.csv"):
    if not (datapath / filename).is_file():
        print("下载中", filename)
        url = data_root + "lifesat/" + filename
        urllib.request.urlretrieve(url, datapath / filename)  # 下载缺失文件

输出:

复制代码
下载中 oecd_bli.csv
下载中 gdp_per_capita.csv

加载预处理后的GDP数据(仅保留2020年)

复制代码
gdp_year = 2020
gdppc_col = "GDP per capita (USD)"
lifesat_col = "Life satisfaction"

gdp_per_capita = gdp_per_capita[gdp_per_capita["Year"] == gdp_year]  # 筛选年份
gdp_per_capita = gdp_per_capita.drop(["Code", "Year"], axis=1)       # 删除冗余列
gdp_per_capita.columns = ["Country", gdppc_col]                      # 重命名列
gdp_per_capita.set_index("Country", inplace=True)                    # 设置国家为索引

gdp_per_capita.head()  # 显示前五行数据

处理OECD BLI数据(提取生活满意度指标)

复制代码
oecd_bli = oecd_bli[oecd_bli["INEQUALITY"]=="TOT"]       # 筛选总评数据
oecd_bli = oecd_bli.pivot(index="Country", columns="Indicator", values="Value")  # 数据透视

oecd_bli.head()  # 显示预处理后的数据结构

合并OECD生活满意度数据与GDP数据,创建完整数据集

复制代码
full_country_stats = pd.merge(left=oecd_bli, right=gdp_per_capita,
                              left_index=True, right_index=True)  # 按国家索引进行左连接
full_country_stats.sort_values(by=gdppc_col, inplace=True)        # 按GDP排序
full_country_stats = full_country_stats[[gdppc_col, lifesat_col]] # 保留关键列

full_country_stats.head()  # 显示合并后的前五行数据

设置GDP过滤范围,创建演示用子集(用于避免过拟合示例)

复制代码
min_gdp = 23_500
max_gdp = 62_500

country_stats = full_country_stats[(full_country_stats[gdppc_col] >= min_gdp) &
                                   (full_country_stats[gdppc_col] <= max_gdp)]  # GDP区间过滤
country_stats.head()  # 显示筛选后的数据

保存处理后的数据集

复制代码
country_stats.to_csv(datapath / "lifesat.csv")        # 保存筛选数据集
full_country_stats.to_csv(datapath / "lifesat_full.csv")  # 保存完整数据集

绘制带国家标注的散点图

复制代码
country_stats.plot(kind='scatter', figsize=(5, 3), grid=True,
                   x=gdppc_col, y=lifesat_col)  # 创建基础散点图

# 定义各国标注位置
position_text = {
    "Turkey": (29_500, 4.2),
    "Hungary": (28_000, 6.9),
    "France": (40_000, 5),
    "New Zealand": (28_000, 8.2),
    "Australia": (50_000, 5.5),
    "United States": (59_000, 5.3),
    "Denmark": (46_000, 8.5)
}

# 添加国家标注和指示箭头
for country, pos_text in position_text.items():
    pos_data_x = country_stats[gdppc_col].loc[country]
    pos_data_y = country_stats[lifesat_col].loc[country]
    country = "U.S." if country == "United States" else country  # 简化显示名称
    plt.annotate(country, xy=(pos_data_x, pos_data_y),
                 xytext=pos_text, fontsize=12,
                 arrowprops=dict(facecolor='black', width=0.5,
                                 shrink=0.08, headwidth=5))  # 添加箭头标注
    plt.plot(pos_data_x, pos_data_y, "ro")  # 标记数据点

plt.axis([min_gdp, max_gdp, min_life_sat, max_life_sat])  # 固定坐标范围
save_fig('money_happy_scatterplot')  # 保存图像
plt.show()

提取高亮国家数据并按GDP排序

复制代码
highlighted_countries = country_stats.loc[list(position_text.keys())]
highlighted_countries[[gdppc_col, lifesat_col]].sort_values(by=gdppc_col)  # GDP升序排列

绘制不同参数组合的线性模型对比图

复制代码
country_stats.plot(kind='scatter', figsize=(5, 3), grid=True,
                   x=gdppc_col, y=lifesat_col)  # 基础散点图

X = np.linspace(min_gdp, max_gdp, 1000)  # 生成连续GDP值

# 绘制三组不同参数的直线
w1, w2 = 4.2, 0
plt.plot(X, w1 + w2 * 1e-5 * X, "r")  # 红色水平线
plt.text(40_000, 4.9, fr"$\theta_0 = {w1}$", color="r")
plt.text(40_000, 4.4, fr"$\theta_1 = {w2}$", color="r")

w1, w2 = 10, -9
plt.plot(X, w1 + w2 * 1e-5 * X, "g")  # 绿色下降趋势线
plt.text(26_000, 8.5, fr"$\theta_0 = {w1}$", color="g")
plt.text(26_000, 8.0, fr"$\theta_1 = {w2} \times 10^{{-5}}$", color="g")

w1, w2 = 3, 8
plt.plot(X, w1 + w2 * 1e-5 * X, "b")  # 蓝色上升趋势线
plt.text(48_000, 8.5, fr"$\theta_0 = {w1}$", color="b")
plt.text(48_000, 8.0, fr"$\theta_1 = {w2} \times 10^{{-5}}$", color="b")

plt.axis([min_gdp, max_gdp, min_life_sat, max_life_sat])
save_fig('tweaking_model_params_plot')
plt.show()

训练线性回归模型并输出参数

复制代码
from sklearn import linear_model

X_sample = country_stats[[gdppc_col]].values  # 特征矩阵
y_sample = country_stats[[lifesat_col]].values  # 目标向量

lin1 = linear_model.LinearRegression()  # 创建线性回归器
lin1.fit(X_sample, y_sample)  # 训练模型

t0, t1 = lin1.intercept_[0], lin1.coef_[0][0]  # 获取截距和斜率
print(f"θ0={t0:.2f}, θ1={t1:.2e}")  # 格式化输出参数

输出结果:

复制代码
θ0=3.75, θ1=6.78e-05

绘制最佳拟合线可视化

复制代码
country_stats.plot(kind='scatter', figsize=(5, 3), grid=True,
                   x=gdppc_col, y=lifesat_col)  # 基础散点图

X = np.linspace(min_gdp, max_gdp, 1000)
plt.plot(X, t0 + t1 * X, "b")  # 绘制回归线

# 添加参数标注
plt.text(max_gdp - 20_000, min_life_sat + 1.9,
         fr"$\theta_0 = {t0:.2f}$", color="b")
plt.text(max_gdp - 20_000, min_life_sat + 1.3,
         fr"$\theta_1 = {t1 * 1e5:.2f} \times 10^{{-5}}$", color="b")

plt.axis([min_gdp, max_gdp, min_life_sat, max_life_sat])
save_fig('best_fit_model_plot')
plt.show()

获取塞浦路斯的GDP数据

复制代码
cyprus_gdp_per_capita = gdp_per_capita[gdppc_col].loc["Cyprus"]  # 通过国家名索引
cyprus_gdp_per_capita  # 显示值

使用训练好的模型进行预测

复制代码
cyprus_predicted_life_satisfaction = lin1.predict([[cyprus_gdp_per_capita]])[0, 0]
cyprus_predicted_life_satisfaction  # 显示预测结果

可视化预测结果

复制代码
country_stats.plot(kind='scatter', figsize=(5, 3), grid=True,
                   x=gdppc_col, y=lifesat_col)  # 基础散点图

X = np.linspace(min_gdp, max_gdp, 1000)
plt.plot(X, t0 + t1 * X, "b")  # 绘制回归线

# 调整参数标注位置
plt.text(min_gdp + 22_000, max_life_sat - 1.1,
         fr"$\theta_0 = {t0:.2f}$", color="b")
plt.text(min_gdp + 22_000, max_life_sat - 0.6,
         fr"$\theta_1 = {t1 * 1e5:.2f} \times 10^{{-5}}$", color="b")

# 添加预测线标注
plt.plot([cyprus_gdp_per_capita, cyprus_gdp_per_capita],
         [min_life_sat, cyprus_predicted_life_satisfaction], "r--")  # 红色虚线
plt.text(cyprus_gdp_per_capita + 1000, 5.0,
         fr"Prediction = {cyprus_predicted_life_satisfaction:.2f}", color="r")
plt.plot(cyprus_gdp_per_capita, cyprus_predicted_life_satisfaction, "ro")  # 红点标记

plt.axis([min_gdp, max_gdp, min_life_sat, max_life_sat])
plt.show()

提取GDP区间外的异常数据点

复制代码
missing_data = full_country_stats[(full_country_stats[gdppc_col] < min_gdp) |
                                  (full_country_stats[gdppc_col] > max_gdp)]  # 筛选超出GDP范围的数据
missing_data  # 显示异常数据点

定义缺失国家的标注位置坐标

复制代码
position_text_missing_countries = {
    "South Africa": (20_000, 4.2),    # 南非标注位置
    "Colombia": (6_000, 8.2),         # 哥伦比亚
    "Brazil": (18_000, 7.8),          # 巴西
    "Mexico": (24_000, 7.4),          # 墨西哥
    "Chile": (30_000, 7.0),           # 智利
    "Norway": (51_000, 6.2),          # 挪威
    "Switzerland": (62_000, 5.7),     # 瑞士
    "Ireland": (81_000, 5.2),         # 爱尔兰
    "Luxembourg": (92_000, 4.7),      # 卢森堡
}

绘制完整数据分布及异常点标注

复制代码
full_country_stats.plot(kind='scatter', figsize=(8, 3),
                        x=gdppc_col, y=lifesat_col, grid=True)  # 全数据散点图

# 添加异常点标注
for country, pos_text in position_text_missing_countries.items():
    pos_data_x, pos_data_y = missing_data.loc[country]  # 获取实际坐标
    plt.annotate(country, xy=(pos_data_x, pos_data_y),
                 xytext=pos_text, fontsize=12,
                 arrowprops=dict(facecolor='black', width=0.5,
                                 shrink=0.08, headwidth=5))  # 带箭头文本标注
    plt.plot(pos_data_x, pos_data_y, "rs")  # 红色方块标记

# 绘制局部数据拟合线
X = np.linspace(0, 115_000, 1000)
plt.plot(X, t0 + t1 * X, "b:")  # 蓝色虚线表示局部数据模型

# 使用完整数据训练新模型
lin_reg_full = linear_model.LinearRegression()
Xfull = np.c_[full_country_stats[gdppc_col]]  # 二维数组转换
yfull = np.c_[full_country_stats[lifesat_col]]
lin_reg_full.fit(Xfull, yfull)  # 全数据训练

# 绘制全数据拟合线
t0full, t1full = lin_reg_full.intercept_[0], lin_reg_full.coef_[0][0]
plt.plot(X, t0full + t1full * X, "k")  # 黑色实线表示全数据模型

plt.axis([0, 115_000, min_life_sat, max_life_sat])  # 扩展坐标范围
save_fig('representative_training_data_scatterplot')
plt.show()

构建过拟合多项式回归模型

复制代码
from sklearn import preprocessing
from sklearn import pipeline

full_country_stats.plot(kind='scatter', figsize=(8, 3),
                        x=gdppc_col, y=lifesat_col, grid=True)  # 全数据散点图

# 创建多项式回归流水线
poly = preprocessing.PolynomialFeatures(degree=10, include_bias=False)  # 10次多项式
scaler = preprocessing.StandardScaler()        # 标准化
lin_reg2 = linear_model.LinearRegression()     # 线性回归

pipeline_reg = pipeline.Pipeline([
    ('poly', poly),    # 特征多项式扩展
    ('scal', scaler),  # 数据标准化
    ('lin', lin_reg2)])  # 线性回归
pipeline_reg.fit(Xfull, yfull)  # 训练高阶模型
curve = pipeline_reg.predict(X[:, np.newaxis])  # 生成预测曲线
plt.plot(X, curve)  # 绘制过拟合曲线

plt.axis([0, 115_000, min_life_sat, max_life_sat])
save_fig('overfitting_model_plot')  # 保存过拟合示例图
plt.show()

筛选国家名称含'W'字母的国家数据

复制代码
w_countries = [c for c in full_country_stats.index if "W" in c.upper()]  # 列表推导式筛选
full_country_stats.loc[w_countries][lifesat_col]  # 显示生活满意度

获取所有含'W'国家的GDP数据

复制代码
all_w_countries = [c for c in gdp_per_capita.index if "W" in c.upper()]
gdp_per_capita.loc[all_w_countries].sort_values(by=gdppc_col)  # GDP升序排列

对比不同回归模型效果

复制代码
# 创建复合图表
country_stats.plot(kind='scatter', x=gdppc_col, y=lifesat_col, figsize=(8, 3))
missing_data.plot(kind='scatter', x=gdppc_col, y=lifesat_col,
                  marker="s", color="r", grid=True, ax=plt.gca())  # 红色方块标记异常点

# 绘制三种回归线
X = np.linspace(0, 115_000, 1000)
plt.plot(X, t0 + t1*X, "b:", label="基于部分数据的线性模型")  # 局部数据模型
plt.plot(X, t0full + t1full * X, "k-", label="全数据线性模型")  # 全数据模型

# 训练岭回归模型
ridge = linear_model.Ridge(alpha=10**9.5)  # 设置正则化强度
X_sample = country_stats[[gdppc_col]]
y_sample = country_stats[[lifesat_col]]
ridge.fit(X_sample, y_sample)  # 训练正则化模型
t0ridge, t1ridge = ridge.intercept_[0], ridge.coef_[0][0]
plt.plot(X, t0ridge + t1ridge * X, "b--", 
         label="部分数据正则化线性模型")  # 蓝色虚线

plt.legend(loc="lower right")  # 右下角图例
plt.axis([0, 115_000, min_life_sat, max_life_sat])
save_fig('ridge_model_plot')  # 保存岭回归对比图
plt.show()

DeepSeek 代码分析

1. 核心功能

  • 任务类型:单变量回归任务(GDP per capita ➔ Life satisfaction)
  • 模型架构:对比线性回归/K近邻/多项式回归/岭回归等多种方法
  • 特色:通过可视化展示模型差异,演示过拟合现象

2. 关键组件

A[环境配置] --> B[数据获取]--> C[数据预处理]--> D[特征工程]--> E[模型训练]--> F[结果可视化]--> G[模型对比]--> H[过拟合分析]

3. 数据流

原始数据 → 筛选年份 → 合并数据集 → 特征标准化 → 模型输入 → 预测输出 → 可视化展示

4. 创新点

  • 采用对比可视化技术:同时展示多个模型的拟合效果
  • 创新标注系统:动态生成带箭头的国家标注注释
  • 混合可视化策略:将原始数据点与模型预测曲线叠加展示
  • 过拟合演示:通过极端多项式回归直观展示过拟合现象

5. 改进建议 ▲ 数据预处理阶段:

  • 缺失值处理未显式体现(假设数据完整)
  • 特征标准化仅应用于部分模型(建议统一预处理流程)
  • 数据拆分未采用标准train-test split(可能引入数据泄露)

▲ 模型评估:

  • 缺乏量化评估指标(如RMSE/R²分数)
  • 未使用交叉验证评估模型稳定性
  • 超参数选择(如KNN的k值)未说明依据

▲ 代码结构:

  • 可封装重复可视化操作为函数
  • 建议使用sklearn Pipeline整合预处理步骤
  • 路径处理可改用pathlib的现代写法

6. 学习要点 重点关注的实现细节:

  • 数据透视操作:pivot()方法重构数据格式
  • 坐标轴标注技巧:fr" θ 0 = t 0 : . 2 f \theta_0 = {t0:.2f} θ0=t0:.2f"的LaTeX公式渲染
  • 混合可视化技术:基础散点图+预测曲线+标注元素的叠加方式
  • 正则化应用:Ridge回归的参数设置(alpha=10**9.5)对模型的影响

批判性思考补充 ▲ 数据代表性风险:

  • 筛选GDP范围时可能引入样本偏差(min_gdp=23_500)
  • 异常值处理采用简单截断法,可能丢失重要信息
  • 国家标注位置手工设定,缺乏自动化定位机制

▲ 模型可解释性:

  • 多项式回归的degree=10选择缺乏理论依据
  • 未进行特征重要性分析
  • 不同模型预测差异未进行统计显著性检验

代码亮点总结

  • 环境配置完整:版本检查+随机种子+可视化预设
  • 数据故事化呈现:通过渐进式可视化构建分析逻辑
  • 模型对比直观:多模型拟合曲线同图对比
  • 可复现设计:自动下载数据+路径创建+随机种子

建议重点关注代码中数据流转换的实现方式(特别是pandas操作)和可视化组件的构建逻辑,这是实现分析型机器学习项目的典型范式。

第四部分:思考篇

1. 预备知识要求

  1. Python编程基础
  2. 科学计算库:NumPy、Pandas、Matplotlib
  3. 基础数学概念:
  • 线性代数(向量、矩阵运算)
  • 微积分(理解神经网络训练原理)
  • 基础概率论
  • 基础统计学

2. 学习建议

  1. 不要过早跳入深度学习
  2. 建议先掌握机器学习基础
  3. 大多数问题可以用更简单的技术解决(如随机森林)
  4. 深度学习最适合:
  • 图像识别
  • 语音识别
  • 自然语言处理
  • 需要大量数据和计算资源的问题
相关推荐
hrrrrb34 分钟前
【机器学习】监督学习
人工智能·学习·机器学习
长桥夜波2 小时前
【第十九周】机器学习笔记08
人工智能·笔记·机器学习
深蓝岛3 小时前
LSTM与CNN融合建模的创新技术路径
论文阅读·人工智能·深度学习·机器学习·lstm
lzptouch4 小时前
蚁群(Ant Colony Optimization, ACO)算法
人工智能·算法·机器学习
Clain5 小时前
Ollama、LM Studio只是模型工具,这款工具比他俩更全面
人工智能·机器学习·llm
双翌视觉7 小时前
机器视觉的液晶电视OCA全贴合应用
人工智能·数码相机·机器学习·1024程序员节
青云交15 小时前
Java 大视界 -- Java 大数据在智能农业温室环境调控与作物生长模型构建中的应用
java·机器学习·传感器技术·数据处理·作物生长模型·智能农业·温室环境调控
浣熊-论文指导15 小时前
聚类与Transformer融合的六大创新方向
论文阅读·深度学习·机器学习·transformer·聚类
B站_计算机毕业设计之家19 小时前
预测算法:股票数据分析预测系统 股票预测 股价预测 Arima预测算法(时间序列预测算法) Flask 框架 大数据(源码)✅
python·算法·机器学习·数据分析·flask·股票·预测
GG向前冲19 小时前
【大数据】Spark MLlib 机器学习流水线搭建
大数据·机器学习·spark-ml