机器学习中的回归与分类模型:线性回归、逻辑回归与多分类

在机器学习领域,回归和分类是两类重要的任务,它们各自有着不同的应用场景和模型构建方式。本文将详细介绍线性回归、逻辑回归以及多分类任务的相关内容,包括数据预处理、模型定义、损失函数的选择以及评估指标的计算。

一、线性回归(连续值预测)

线性回归是一种用于预测连续数值型目标变量的模型。在使用线性回归模型时,以下是一些关键步骤:

  1. 数据预处理

    对特征矩阵 X 和目标变量 y 都需要进行标准化或归一化处理。这样做的目的是为了使数据具有相同的尺度,从而提高模型训练的稳定性和效率。如果在生成特征数据时使用了时间窗口(此时 x 是三维的),需要对 x 进行形状重塑,即 x = x.reshape(-1, x.shape[1] * x.shape[2]);若没有使用时间窗口,则无需进行这一步骤。同时,目标变量 y 需要转换为二维数组,使用 y = y.reshape(-1, 1)

  2. 模型定义

    使用 torch 构建线性回归模型,定义如下:

    复制代码
    model = torch.nn.Linear(in_features=x.shape[1], out_features=1)

    其中 in_features 是输入特征的数量,out_features 设置为 1,因为我们预测的是单个连续值。

损失函数

线性回归通常使用均方误差损失函数(Mean Squared Error Loss,MSELoss),定义如下:

复制代码
import torch.nn as nn
loss_fn = nn.MSELoss()

二、逻辑回归(二分类任务)

逻辑回归是一种用于解决二分类问题的模型,它通过将线性回归的输出经过一个 Sigmoid 函数映射到 [0, 1] 区间,从而得到样本属于正类的概率。

  1. 数据预处理

    对特征矩阵 X 进行标准化或归一化处理。目标变量 y 需要转换为二维的 torch.Tensor 类型,并且数据类型为 float,使用 torch.tensor(y, dtype=torch.float).reshape(-1, 1)

  2. 模型定义

    逻辑回归模型由一个线性层和一个 Sigmoid 激活函数组成,定义如下:

    python 复制代码
    model = torch.nn.Sequential(
        torch.nn.Linear(in_features=x.shape[1], out_features=1),
        torch.nn.Sigmoid()
    )

    Sigmoid 函数将线性层的输出转换为概率值。

  3. 损失函数
    对于二分类问题,常用的损失函数是二元交叉熵损失函数(Binary Cross Entropy Loss,BCELoss),定义如下:

    python 复制代码
    loss_fn = nn.BCELoss()
  4. 评估指标
    在模型评估时,将模型的输出 h 经过阈值处理(通常阈值为 0.5)转换为类别标签,即 h = (h > 0.5).int()。然后计算准确率,公式为 acc = (h == y).float().mean()

三、多分类任务

多分类任务是指将样本分为多个不同的类别。

数据预处理

同样对特征矩阵 X 进行标准化或归一化处理。

模型定义

多分类模型通常由多个线性层和激活函数组成,以下是一个简单的示例:

python 复制代码
model = torch.nn.Sequential(
    torch.nn.Linear(in_features=x.shape[1], out_features=128),
    torch.nn.ReLU(),
    torch.nn.Linear(in_features=128, out_features=hot_dim)
)

这里 hot_dim 是类别数量,ReLU 作为激活函数用于引入非线性。

损失函数

多分类问题常用的损失函数是交叉熵损失函数(Cross Entropy Loss),定义如下:

python 复制代码
loss_fn = nn.CrossEntropyLoss()

评估指标

在模型评估时,将模型的输出 h 通过 argmax 函数找到每一行中最大值的索引,即预测的类别标签,h = h.argmax(-1)。然后计算准确率,公式为 acc = (h == y).float().mean()

通过以上对线性回归、逻辑回归和多分类任务的介绍,我们可以根据不同的问题类型选择合适的模型和处理方法,从而更好地解决实际问题。希望本文对你理解和应用这些模型有所帮助。

相关推荐
AI_56783 小时前
AWS EC2新手入门:6步带你从零启动实例
大数据·数据库·人工智能·机器学习·aws
Liue612312314 小时前
YOLO11-C3k2-MBRConv3改进提升金属表面缺陷检测与分类性能_焊接裂纹气孔飞溅物焊接线识别
人工智能·分类·数据挖掘
小鸡吃米…6 小时前
机器学习的商业化变现
人工智能·机器学习
Lun3866buzha6 小时前
农业害虫检测_YOLO11-C3k2-EMSC模型实现与分类识别_1
人工智能·分类·数据挖掘
木非哲8 小时前
机器学习--随机森林--从一棵树的直觉到一片林的哲学
人工智能·随机森林·机器学习
A尘埃10 小时前
保险公司车险理赔欺诈检测(随机森林)
算法·随机森林·机器学习
小瑞瑞acd14 小时前
【小瑞瑞精讲】卷积神经网络(CNN):从入门到精通,计算机如何“看”懂世界?
人工智能·python·深度学习·神经网络·机器学习
民乐团扒谱机14 小时前
【微实验】机器学习之集成学习 GBDT和XGBoost 附 matlab仿真代码 复制即可运行
人工智能·机器学习·matlab·集成学习·xgboost·gbdt·梯度提升树
机器学习之心15 小时前
TCN-Transformer-BiGRU组合模型回归+SHAP分析+新数据预测+多输出!深度学习可解释分析
深度学习·回归·transformer·shap分析
Σίσυφος190015 小时前
PCL法向量估计 之 RANSAC 平面估计法向量
算法·机器学习·平面