机器学习中的回归与分类模型:线性回归、逻辑回归与多分类

在机器学习领域,回归和分类是两类重要的任务,它们各自有着不同的应用场景和模型构建方式。本文将详细介绍线性回归、逻辑回归以及多分类任务的相关内容,包括数据预处理、模型定义、损失函数的选择以及评估指标的计算。

一、线性回归(连续值预测)

线性回归是一种用于预测连续数值型目标变量的模型。在使用线性回归模型时,以下是一些关键步骤:

  1. 数据预处理

    对特征矩阵 X 和目标变量 y 都需要进行标准化或归一化处理。这样做的目的是为了使数据具有相同的尺度,从而提高模型训练的稳定性和效率。如果在生成特征数据时使用了时间窗口(此时 x 是三维的),需要对 x 进行形状重塑,即 x = x.reshape(-1, x.shape[1] * x.shape[2]);若没有使用时间窗口,则无需进行这一步骤。同时,目标变量 y 需要转换为二维数组,使用 y = y.reshape(-1, 1)

  2. 模型定义

    使用 torch 构建线性回归模型,定义如下:

    复制代码
    model = torch.nn.Linear(in_features=x.shape[1], out_features=1)

    其中 in_features 是输入特征的数量,out_features 设置为 1,因为我们预测的是单个连续值。

损失函数

线性回归通常使用均方误差损失函数(Mean Squared Error Loss,MSELoss),定义如下:

复制代码
import torch.nn as nn
loss_fn = nn.MSELoss()

二、逻辑回归(二分类任务)

逻辑回归是一种用于解决二分类问题的模型,它通过将线性回归的输出经过一个 Sigmoid 函数映射到 [0, 1] 区间,从而得到样本属于正类的概率。

  1. 数据预处理

    对特征矩阵 X 进行标准化或归一化处理。目标变量 y 需要转换为二维的 torch.Tensor 类型,并且数据类型为 float,使用 torch.tensor(y, dtype=torch.float).reshape(-1, 1)

  2. 模型定义

    逻辑回归模型由一个线性层和一个 Sigmoid 激活函数组成,定义如下:

    python 复制代码
    model = torch.nn.Sequential(
        torch.nn.Linear(in_features=x.shape[1], out_features=1),
        torch.nn.Sigmoid()
    )

    Sigmoid 函数将线性层的输出转换为概率值。

  3. 损失函数
    对于二分类问题,常用的损失函数是二元交叉熵损失函数(Binary Cross Entropy Loss,BCELoss),定义如下:

    python 复制代码
    loss_fn = nn.BCELoss()
  4. 评估指标
    在模型评估时,将模型的输出 h 经过阈值处理(通常阈值为 0.5)转换为类别标签,即 h = (h > 0.5).int()。然后计算准确率,公式为 acc = (h == y).float().mean()

三、多分类任务

多分类任务是指将样本分为多个不同的类别。

数据预处理

同样对特征矩阵 X 进行标准化或归一化处理。

模型定义

多分类模型通常由多个线性层和激活函数组成,以下是一个简单的示例:

python 复制代码
model = torch.nn.Sequential(
    torch.nn.Linear(in_features=x.shape[1], out_features=128),
    torch.nn.ReLU(),
    torch.nn.Linear(in_features=128, out_features=hot_dim)
)

这里 hot_dim 是类别数量,ReLU 作为激活函数用于引入非线性。

损失函数

多分类问题常用的损失函数是交叉熵损失函数(Cross Entropy Loss),定义如下:

python 复制代码
loss_fn = nn.CrossEntropyLoss()

评估指标

在模型评估时,将模型的输出 h 通过 argmax 函数找到每一行中最大值的索引,即预测的类别标签,h = h.argmax(-1)。然后计算准确率,公式为 acc = (h == y).float().mean()

通过以上对线性回归、逻辑回归和多分类任务的介绍,我们可以根据不同的问题类型选择合适的模型和处理方法,从而更好地解决实际问题。希望本文对你理解和应用这些模型有所帮助。

相关推荐
二川bro17 小时前
AutoML自动化机器学习:Python实战指南
python·机器学习·自动化
大千AI助手20 小时前
概率单位回归(Probit Regression)详解
人工智能·机器学习·数据挖掘·回归·大千ai助手·概率单位回归·probit回归
我不是QI21 小时前
周志华《机器学习—西瓜书》二
人工智能·安全·机器学习
luoganttcc1 天前
RoboTron-Drive:自动驾驶领域的全能多模态大模型
人工智能·机器学习·自动驾驶
Ai173163915791 天前
2025.11.28国产AI计算卡参数信息汇总
服务器·图像处理·人工智能·神经网络·机器学习·视觉检测·transformer
青云交1 天前
Java 大视界 -- Java 大数据机器学习模型在电商评论情感分析与产品口碑优化中的应用
机器学习·自然语言处理·lstm·情感分析·java 大数据·电商评论·产品口碑
m0_372257021 天前
ID3 算法为什么可以用来优化决策树
算法·决策树·机器学习
Together_CZ1 天前
Cambrian-S: Towards Spatial Supersensing in Video——迈向视频中的空间超感知
人工智能·机器学习·音视频·spatial·cambrian-s·迈向视频中的空间超感知·supersensing
鼎道开发者联盟1 天前
智能原生操作系统畅想:人智共生新时代的基石
人工智能·机器学习·自然语言处理
lisw051 天前
6G频段与5G频段有何不同?
人工智能·机器学习