机器学习中的回归与分类模型:线性回归、逻辑回归与多分类

在机器学习领域,回归和分类是两类重要的任务,它们各自有着不同的应用场景和模型构建方式。本文将详细介绍线性回归、逻辑回归以及多分类任务的相关内容,包括数据预处理、模型定义、损失函数的选择以及评估指标的计算。

一、线性回归(连续值预测)

线性回归是一种用于预测连续数值型目标变量的模型。在使用线性回归模型时,以下是一些关键步骤:

  1. 数据预处理

    对特征矩阵 X 和目标变量 y 都需要进行标准化或归一化处理。这样做的目的是为了使数据具有相同的尺度,从而提高模型训练的稳定性和效率。如果在生成特征数据时使用了时间窗口(此时 x 是三维的),需要对 x 进行形状重塑,即 x = x.reshape(-1, x.shape[1] * x.shape[2]);若没有使用时间窗口,则无需进行这一步骤。同时,目标变量 y 需要转换为二维数组,使用 y = y.reshape(-1, 1)

  2. 模型定义

    使用 torch 构建线性回归模型,定义如下:

    复制代码
    model = torch.nn.Linear(in_features=x.shape[1], out_features=1)

    其中 in_features 是输入特征的数量,out_features 设置为 1,因为我们预测的是单个连续值。

损失函数

线性回归通常使用均方误差损失函数(Mean Squared Error Loss,MSELoss),定义如下:

复制代码
import torch.nn as nn
loss_fn = nn.MSELoss()

二、逻辑回归(二分类任务)

逻辑回归是一种用于解决二分类问题的模型,它通过将线性回归的输出经过一个 Sigmoid 函数映射到 [0, 1] 区间,从而得到样本属于正类的概率。

  1. 数据预处理

    对特征矩阵 X 进行标准化或归一化处理。目标变量 y 需要转换为二维的 torch.Tensor 类型,并且数据类型为 float,使用 torch.tensor(y, dtype=torch.float).reshape(-1, 1)

  2. 模型定义

    逻辑回归模型由一个线性层和一个 Sigmoid 激活函数组成,定义如下:

    python 复制代码
    model = torch.nn.Sequential(
        torch.nn.Linear(in_features=x.shape[1], out_features=1),
        torch.nn.Sigmoid()
    )

    Sigmoid 函数将线性层的输出转换为概率值。

  3. 损失函数
    对于二分类问题,常用的损失函数是二元交叉熵损失函数(Binary Cross Entropy Loss,BCELoss),定义如下:

    python 复制代码
    loss_fn = nn.BCELoss()
  4. 评估指标
    在模型评估时,将模型的输出 h 经过阈值处理(通常阈值为 0.5)转换为类别标签,即 h = (h > 0.5).int()。然后计算准确率,公式为 acc = (h == y).float().mean()

三、多分类任务

多分类任务是指将样本分为多个不同的类别。

数据预处理

同样对特征矩阵 X 进行标准化或归一化处理。

模型定义

多分类模型通常由多个线性层和激活函数组成,以下是一个简单的示例:

python 复制代码
model = torch.nn.Sequential(
    torch.nn.Linear(in_features=x.shape[1], out_features=128),
    torch.nn.ReLU(),
    torch.nn.Linear(in_features=128, out_features=hot_dim)
)

这里 hot_dim 是类别数量,ReLU 作为激活函数用于引入非线性。

损失函数

多分类问题常用的损失函数是交叉熵损失函数(Cross Entropy Loss),定义如下:

python 复制代码
loss_fn = nn.CrossEntropyLoss()

评估指标

在模型评估时,将模型的输出 h 通过 argmax 函数找到每一行中最大值的索引,即预测的类别标签,h = h.argmax(-1)。然后计算准确率,公式为 acc = (h == y).float().mean()

通过以上对线性回归、逻辑回归和多分类任务的介绍,我们可以根据不同的问题类型选择合适的模型和处理方法,从而更好地解决实际问题。希望本文对你理解和应用这些模型有所帮助。

相关推荐
B站_计算机毕业设计之家4 分钟前
机器学习实战项目:Python+Flask 汽车销量分析可视化系统(requests爬车主之家+可视化 源码+文档)✅
人工智能·python·机器学习·数据分析·flask·汽车·可视化
lucky_syq3 小时前
解锁特征工程:机器学习的秘密武器
人工智能·机器学习
CM莫问3 小时前
推荐算法之粗排
深度学习·算法·机器学习·数据挖掘·排序算法·推荐算法·粗排
rengang664 小时前
10-支持向量机(SVM):讲解基于最大间隔原则的分类算法
人工智能·算法·机器学习·支持向量机
周杰伦_Jay4 小时前
【Git操作详解】Git进行版本控制与管理,包括分支,提交,合并,标签、远程仓库查看
大数据·ide·git·科技·分类·github
on_pluto_5 小时前
LLaMA: Open and Efficient Foundation Language Models 论文阅读
python·机器学习
antonytyler6 小时前
认识机器学习
机器学习
一车小面包6 小时前
对注意力机制的直观理解
人工智能·深度学习·机器学习
听风吹等浪起7 小时前
分类算法-逻辑回归
人工智能·算法·机器学习
成都犀牛8 小时前
强化学习(5)多智能体强化学习
人工智能·机器学习·强化学习