机器学习入门(二)——逻辑回归 (Logistic Regression)

机器学习入门(二)------逻辑回归 (Logistic Regression)

第一部分:核心难点

这部分是逻辑回归的灵魂,也是面试和考试的重灾区,请务必彻底理解。

1. 核心转换:Sigmoid 函数

线性回归的输出是 z = w T x + b z = w^T x + b z=wTx+b,它的范围是 ( − ∞ , + ∞ ) (-\infty, +\infty) (−∞,+∞)。但分类问题(二分类)需要的输出是概率,范围必须在 [ 0 , 1 ] [0, 1] [0,1] 之间。怎么把直线"压"弯呢?

我们需要一个激活函数 ------Sigmoid函数

  • 公式 : g ( z ) = 1 1 + e − z g(z) = \frac{1}{1 + e^{-z}} g(z)=1+e−z1
  • 输入 : z z z (也就是线性回归的结果 w T x + b w^T x + b wTx+b)
  • 输出 :一个 0 到 1 之间的概率值。
    • 当 z z z 很大时, g ( z ) ≈ 1 g(z) \approx 1 g(z)≈1
    • 当 z z z 很小时, g ( z ) ≈ 0 g(z) \approx 0 g(z)≈0
    • 当 z = 0 z = 0 z=0 时, g ( z ) = 0.5 g(z) = 0.5 g(z)=0.5(这是默认的分类阈值)

划重点 :逻辑回归本质上就是 线性回归 + Sigmoid激活函数

2. 损失函数:对数似然损失 (Log Loss)

这是与线性回归最大的不同点!

为什么不能用均方误差 (MSE)?

在线性回归中我们用 MSE( ( y p r e d − y t r u e ) 2 (y_{pred} - y_{true})^2 (ypred−ytrue)2)。如果在逻辑回归中套用 Sigmoid 后再用 MSE,得到的损失函数图像是非凸的 (Non-convex),也就是像连绵起伏的山峦,有很多局部最低点,梯度下降很容易卡在半山腰,找不到全局最优解。

解决方案:对数似然损失

我们采用对数损失函数,公式如下:
C o s t ( h θ ( x ) , y ) = { − log ⁡ ( h θ ( x ) ) if y = 1 − log ⁡ ( 1 − h θ ( x ) ) if y = 0 Cost(h_\theta(x), y) = \begin{cases} -\log(h_\theta(x)) & \text{if } y = 1 \\ -\log(1 - h_\theta(x)) & \text{if } y = 0 \end{cases} Cost(hθ(x),y)={−log(hθ(x))−log(1−hθ(x))if y=1if y=0

合并后的完整公式(一定要眼熟):
J ( θ ) = − 1 m ∑ i = 1 m [ y ( i ) log ⁡ ( h θ ( x ( i ) ) ) + ( 1 − y ( i ) ) log ⁡ ( 1 − h θ ( x ( i ) ) ) ] J(\theta) = - \frac{1}{m} \sum_{i=1}^{m} [y^{(i)}\log(h_\theta(x^{(i)})) + (1-y^{(i)})\log(1-h_\theta(x^{(i)}))] J(θ)=−m1i=1∑m[y(i)log(hθ(x(i)))+(1−y(i))log(1−hθ(x(i)))]

  • 直观理解
    • 如果真实值是 1,预测概率越接近 1,损失越接近 0;如果预测成 0.1,惩罚巨大(损失趋向无穷)。
    • 如果真实值是 0,预测概率越接近 0,损失越接近 0。

3. 模型评估:混淆矩阵与精确率/召回率

这是实际工程中最关键的部分。因为在分类问题中(尤其是正负样本不平衡时),"准确率 (Accuracy)"往往会骗人。

A. 混淆矩阵 (Confusion Matrix)

预测为正 (Positive) 预测为负 (Negative)
真实为正 (True) TP (真正例) FN (假负例/漏报)
真实为负 (False) FP (假正例/误报) TN (真负例)

B. 核心指标 (背诵并理解场景)

  1. 精确率 (Precision) :查准率。预测为正的样本中,有多少是真的正?
    • 公式: T P / ( T P + F P ) TP / (TP + FP) TP/(TP+FP)
    • 场景:垃圾邮件拦截。你把正常邮件误判为垃圾邮件(FP)后果很严重,所以需要高精确率。
  2. 召回率 (Recall) :查全率。所有真的正样本中,你找出了多少?
    • 公式: T P / ( T P + F N ) TP / (TP + FN) TP/(TP+FN)
    • 场景癌症检测 。宁可误诊(FP),也不能漏诊(FN),漏掉一个病人就是人命关天。所以此时召回率最重要
  3. F1-Score :精确率和召回率的调和平均数。
    • 当需要两者兼顾时使用。

C. ROC曲线与AUC指标

  • ROC曲线:不管分类阈值怎么变,画出 TPR (召回率) 和 FPR (假正率) 的关系图。
  • AUC (Area Under Curve) :ROC曲线下的面积。
    • AUC = 0.5:瞎猜(随机)。
    • AUC = 1.0:完美分类器。
    • 优势:AUC 指标不怕样本不平衡(比如正样本只有1%,负样本99%时,AUC依然能客观评价模型好坏)。

第二部分:基础内容

这部分内容与线性回归逻辑相似,快速掌握即可。

1. 应用场景

  • 广告点击预测:用户会不会点击这个广告?(是/否)
  • 金融风控:这笔贷款会不会违约?(是/否)
  • 医疗诊断:是否患病?

2. 优化算法

依然使用梯度下降 (Gradient Descent)

尽管损失函数变了(从MSE变成了Log Loss),但求导后的梯度更新公式在形式上与线性回归惊人地相似(只是 h ( x ) h(x) h(x) 的定义不同了)。

3. API 调用 (Scikit-Learn)

python 复制代码
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import classification_report, roc_auc_score

# 1. 实例化 (solver参数虽然多,默认即可,主要关注C正则化力度)
# C越小,正则化越强(防过拟合);C越大,正则化越弱
estimator = LogisticRegression(C=1.0) 

# 2. 训练
estimator.fit(x_train, y_train)

# 3. 预测
y_predict = estimator.predict(x_test)

# 4. 评估 (一定要看精确率和召回率,不要只看score)
print(classification_report(y_test, y_predict))
print("AUC指标:", roc_auc_score(y_test, y_predict))

第三部分:线性回归 vs 逻辑回归 (对比记忆)

特性 线性回归 (Linear Regression) 逻辑回归 (Logistic Regression)
解决问题 回归(预测连续数字) 分类(预测类别/概率)
核心公式 y = w T x + b y = w^T x + b y=wTx+b y = sigmoid ( w T x + b ) y = \text{sigmoid}(w^T x + b) y=sigmoid(wTx+b)
输出范围 ( − ∞ , + ∞ ) (-\infty, +\infty) (−∞,+∞) [ 0 , 1 ] [0, 1] [0,1]
损失函数 均方误差 (MSE) 对数似然损失 (Log Loss)
评估指标 MSE, RMSE 精确率, 召回率, AUC
相关推荐
枫叶林FYL几秒前
【机器学习与智慧医疗】糖尿病视网膜病变视力丧失预测:贝叶斯估计与威布尔分布
大数据·人工智能·机器学习
清水白石0083 分钟前
从“点一下导出”到生产级任务队列:Python 异步导出系统设计全景解析
java·数据库·python
Godspeed Zhao8 分钟前
从零开始学AI17——SVM的数学支撑知识
算法·机器学习·支持向量机
香蕉鼠片10 分钟前
CUDA、PyTorch、Transformers、PEFT 全栈详解
人工智能·pytorch·python
MediaTea11 分钟前
PyTorch:张量与基础计算模块
人工智能·pytorch·python·深度学习·机器学习
浪子sunny11 分钟前
2026股票实时行情数据Skills技能分享
大数据·人工智能·python
ゆづき12 分钟前
假如编程语言们有外号
java·c语言·c++·python·学习·c#·生活
深度学习lover12 分钟前
<数据集>yolo 电线杆识别<目标检测>
人工智能·python·yolo·目标检测·计算机视觉·电线杆识别
阳明山水16 分钟前
LightGBM调优降MAPE至19%关键策略
人工智能·机器学习·微信·微信公众平台·微信开放平台
2301_8039346110 小时前
Go语言如何做网络爬虫_Go语言爬虫开发教程【指南】
jvm·数据库·python