机器学习中的回归与分类模型:线性回归、逻辑回归与多分类

在机器学习领域,回归和分类是两类重要的任务,它们各自有着不同的应用场景和模型构建方式。本文将详细介绍线性回归、逻辑回归以及多分类任务的相关内容,包括数据预处理、模型定义、损失函数的选择以及评估指标的计算。

一、线性回归(连续值预测)

线性回归是一种用于预测连续数值型目标变量的模型。在使用线性回归模型时,以下是一些关键步骤:

  1. 数据预处理

    对特征矩阵 X 和目标变量 y 都需要进行标准化或归一化处理。这样做的目的是为了使数据具有相同的尺度,从而提高模型训练的稳定性和效率。如果在生成特征数据时使用了时间窗口(此时 x 是三维的),需要对 x 进行形状重塑,即 x = x.reshape(-1, x.shape[1] * x.shape[2]);若没有使用时间窗口,则无需进行这一步骤。同时,目标变量 y 需要转换为二维数组,使用 y = y.reshape(-1, 1)

  2. 模型定义

    使用 torch 构建线性回归模型,定义如下:

    复制代码
    model = torch.nn.Linear(in_features=x.shape[1], out_features=1)

    其中 in_features 是输入特征的数量,out_features 设置为 1,因为我们预测的是单个连续值。

损失函数

线性回归通常使用均方误差损失函数(Mean Squared Error Loss,MSELoss),定义如下:

复制代码
import torch.nn as nn
loss_fn = nn.MSELoss()

二、逻辑回归(二分类任务)

逻辑回归是一种用于解决二分类问题的模型,它通过将线性回归的输出经过一个 Sigmoid 函数映射到 [0, 1] 区间,从而得到样本属于正类的概率。

  1. 数据预处理

    对特征矩阵 X 进行标准化或归一化处理。目标变量 y 需要转换为二维的 torch.Tensor 类型,并且数据类型为 float,使用 torch.tensor(y, dtype=torch.float).reshape(-1, 1)

  2. 模型定义

    逻辑回归模型由一个线性层和一个 Sigmoid 激活函数组成,定义如下:

    python 复制代码
    model = torch.nn.Sequential(
        torch.nn.Linear(in_features=x.shape[1], out_features=1),
        torch.nn.Sigmoid()
    )

    Sigmoid 函数将线性层的输出转换为概率值。

  3. 损失函数
    对于二分类问题,常用的损失函数是二元交叉熵损失函数(Binary Cross Entropy Loss,BCELoss),定义如下:

    python 复制代码
    loss_fn = nn.BCELoss()
  4. 评估指标
    在模型评估时,将模型的输出 h 经过阈值处理(通常阈值为 0.5)转换为类别标签,即 h = (h > 0.5).int()。然后计算准确率,公式为 acc = (h == y).float().mean()

三、多分类任务

多分类任务是指将样本分为多个不同的类别。

数据预处理

同样对特征矩阵 X 进行标准化或归一化处理。

模型定义

多分类模型通常由多个线性层和激活函数组成,以下是一个简单的示例:

python 复制代码
model = torch.nn.Sequential(
    torch.nn.Linear(in_features=x.shape[1], out_features=128),
    torch.nn.ReLU(),
    torch.nn.Linear(in_features=128, out_features=hot_dim)
)

这里 hot_dim 是类别数量,ReLU 作为激活函数用于引入非线性。

损失函数

多分类问题常用的损失函数是交叉熵损失函数(Cross Entropy Loss),定义如下:

python 复制代码
loss_fn = nn.CrossEntropyLoss()

评估指标

在模型评估时,将模型的输出 h 通过 argmax 函数找到每一行中最大值的索引,即预测的类别标签,h = h.argmax(-1)。然后计算准确率,公式为 acc = (h == y).float().mean()

通过以上对线性回归、逻辑回归和多分类任务的介绍,我们可以根据不同的问题类型选择合适的模型和处理方法,从而更好地解决实际问题。希望本文对你理解和应用这些模型有所帮助。

相关推荐
acstdm5 小时前
DAY 48 CBAM注意力
人工智能·深度学习·机器学习
摸爬滚打李上进5 小时前
重生学AI第十六集:线性层nn.Linear
人工智能·pytorch·python·神经网络·机器学习
asyxchenchong8886 小时前
ChatGPT、DeepSeek等大语言模型助力高效办公、论文与项目撰写、数据分析、机器学习与深度学习建模
机器学习·语言模型·chatgpt
BFT白芙堂8 小时前
睿尔曼系列机器人——以创新驱动未来,重塑智能协作新生态(上)
人工智能·机器学习·机器人·协作机器人·复合机器人·睿尔曼机器人
羊小猪~~9 小时前
【NLP入门系列五】中文文本分类案例
人工智能·深度学习·考研·机器学习·自然语言处理·分类·数据挖掘
李师兄说大模型9 小时前
KDD 2025 | 地理定位中的群体智能:一个多智能体大型视觉语言模型协同框架
人工智能·深度学习·机器学习·语言模型·自然语言处理·大模型·deepseek
网安INF9 小时前
深层神经网络:原理与传播机制详解
人工智能·深度学习·神经网络·机器学习
超龄超能程序猿11 小时前
(1)机器学习小白入门 YOLOv:从概念到实践
人工智能·机器学习
.30-06Springfield19 小时前
人工智能概念之七:集成学习思想(Bagging、Boosting、Stacking)
人工智能·算法·机器学习·集成学习