机器学习中的回归与分类模型:线性回归、逻辑回归与多分类

在机器学习领域,回归和分类是两类重要的任务,它们各自有着不同的应用场景和模型构建方式。本文将详细介绍线性回归、逻辑回归以及多分类任务的相关内容,包括数据预处理、模型定义、损失函数的选择以及评估指标的计算。

一、线性回归(连续值预测)

线性回归是一种用于预测连续数值型目标变量的模型。在使用线性回归模型时,以下是一些关键步骤:

  1. 数据预处理

    对特征矩阵 X 和目标变量 y 都需要进行标准化或归一化处理。这样做的目的是为了使数据具有相同的尺度,从而提高模型训练的稳定性和效率。如果在生成特征数据时使用了时间窗口(此时 x 是三维的),需要对 x 进行形状重塑,即 x = x.reshape(-1, x.shape[1] * x.shape[2]);若没有使用时间窗口,则无需进行这一步骤。同时,目标变量 y 需要转换为二维数组,使用 y = y.reshape(-1, 1)

  2. 模型定义

    使用 torch 构建线性回归模型,定义如下:

    复制代码
    model = torch.nn.Linear(in_features=x.shape[1], out_features=1)

    其中 in_features 是输入特征的数量,out_features 设置为 1,因为我们预测的是单个连续值。

损失函数

线性回归通常使用均方误差损失函数(Mean Squared Error Loss,MSELoss),定义如下:

复制代码
import torch.nn as nn
loss_fn = nn.MSELoss()

二、逻辑回归(二分类任务)

逻辑回归是一种用于解决二分类问题的模型,它通过将线性回归的输出经过一个 Sigmoid 函数映射到 [0, 1] 区间,从而得到样本属于正类的概率。

  1. 数据预处理

    对特征矩阵 X 进行标准化或归一化处理。目标变量 y 需要转换为二维的 torch.Tensor 类型,并且数据类型为 float,使用 torch.tensor(y, dtype=torch.float).reshape(-1, 1)

  2. 模型定义

    逻辑回归模型由一个线性层和一个 Sigmoid 激活函数组成,定义如下:

    python 复制代码
    model = torch.nn.Sequential(
        torch.nn.Linear(in_features=x.shape[1], out_features=1),
        torch.nn.Sigmoid()
    )

    Sigmoid 函数将线性层的输出转换为概率值。

  3. 损失函数
    对于二分类问题,常用的损失函数是二元交叉熵损失函数(Binary Cross Entropy Loss,BCELoss),定义如下:

    python 复制代码
    loss_fn = nn.BCELoss()
  4. 评估指标
    在模型评估时,将模型的输出 h 经过阈值处理(通常阈值为 0.5)转换为类别标签,即 h = (h > 0.5).int()。然后计算准确率,公式为 acc = (h == y).float().mean()

三、多分类任务

多分类任务是指将样本分为多个不同的类别。

数据预处理

同样对特征矩阵 X 进行标准化或归一化处理。

模型定义

多分类模型通常由多个线性层和激活函数组成,以下是一个简单的示例:

python 复制代码
model = torch.nn.Sequential(
    torch.nn.Linear(in_features=x.shape[1], out_features=128),
    torch.nn.ReLU(),
    torch.nn.Linear(in_features=128, out_features=hot_dim)
)

这里 hot_dim 是类别数量,ReLU 作为激活函数用于引入非线性。

损失函数

多分类问题常用的损失函数是交叉熵损失函数(Cross Entropy Loss),定义如下:

python 复制代码
loss_fn = nn.CrossEntropyLoss()

评估指标

在模型评估时,将模型的输出 h 通过 argmax 函数找到每一行中最大值的索引,即预测的类别标签,h = h.argmax(-1)。然后计算准确率,公式为 acc = (h == y).float().mean()

通过以上对线性回归、逻辑回归和多分类任务的介绍,我们可以根据不同的问题类型选择合适的模型和处理方法,从而更好地解决实际问题。希望本文对你理解和应用这些模型有所帮助。

相关推荐
数字化转型20251 小时前
基于六大产品线+三项核心工作
程序人生·机器学习
汽车仪器仪表相关领域1 小时前
经典指针+瞬态追踪:MTX-A模拟废气温度(EGT)计 改装/赛车/柴油车排气温度监测实战全解
大数据·功能测试·算法·机器学习·可用性测试
HyperAI超神经1 小时前
软银/英伟达/红杉资本/贝佐斯等参投,机器人初创公司Skild AI融资14亿美元,打造通用基础模型
人工智能·深度学习·机器学习·机器人·ai编程
民乐团扒谱机1 小时前
机器学习 第二弹 和AI斗智斗勇 机器学习核心知识点全解析(GBDT/XGBoost/LightGBM/随机森林+调参方法)
算法·决策树·机器学习
charlie1145141912 小时前
机器学习概论:一门教计算机如何“不确定地正确”的学问
人工智能·笔记·机器学习·工程实践
Lun3866buzha2 小时前
【YOLO11-seg-RFCBAMConv】传送带状态检测与分类改进实现【含Python源码】
python·分类·数据挖掘
Echo_NGC22372 小时前
【联邦学习完全指南】Part 5:安全攻防与隐私保护
人工智能·深度学习·神经网络·安全·机器学习·联邦学习
清铎2 小时前
项目_华为杯’数模研赛复盘_第二问
深度学习·算法·机器学习
Allen_LVyingbo3 小时前
面向70B多模态医疗大模型预训练的工程落地(医疗大模型预训练扩展包)
人工智能·python·分类·知识图谱·健康医疗·迁移学习
杨_晨4 小时前
大模型微调训练FAQ - Loss与准确率关系
人工智能·经验分享·笔记·深度学习·机器学习·ai