使用 PyTorch 实现逻辑回归:从数据到模型保存与加载

在机器学习中,逻辑回归是一种经典的分类算法,广泛应用于二分类问题。本文将通过一个简单的示例,展示如何使用 PyTorch 框架实现逻辑回归模型,从数据准备到模型训练、保存和加载,最后进行预测。

1. 数据准备

逻辑回归的核心是通过学习数据中的特征与标签之间的关系来进行分类。在本示例中,我们手动创建了一个简单的二维数据集,包含两类数据点。第一类数据点的标签为 0,第二类数据点的标签为 1。

python 复制代码
class1_points = np.array([[1.9, 1.2],
                          [1.5, 2.1],
                          [1.9, 0.5],
                          [1.5, 0.9],
                          [0.9, 1.2],
                          [1.1, 1.7],
                          [1.4, 1.1]])

class2_points = np.array([[3.2, 3.2],
                          [3.7, 2.9],
                          [3.2, 2.6],
                          [1.7, 3.3],
                          [3.4, 2.6],
                          [4.1, 2.3],
                          [3.0, 2.9]])

我们将这两类数据点的特征和标签分别提取出来,并将它们合并到一个数据集中。特征数据 X 是一个二维数组,每一行表示一个数据点的两个特征值;标签数据 y 是一个一维数组,表示每个数据点对应的类别标签。

python 复制代码
# 提取两类特征,输入特征维度为2
x1_data = np.concatenate((class1_points[:, 0], class2_points[:, 0]), axis=0)
x2_data = np.concatenate((class1_points[:, 1], class2_points[:, 1]), axis=0)
label = np.concatenate((np.zeros(len(class1_points)), np.ones(len(class2_points))), axis=0)

# 将数据转换为 PyTorch 张量
X = torch.tensor(np.column_stack((x1_data, x2_data)), dtype=torch.float32)
y = torch.tensor(label, dtype=torch.float32).view(-1, 1)

2. 定义逻辑回归模型

逻辑回归模型的核心是一个线性变换后接一个 Sigmoid 激活函数。Sigmoid 函数可以将输出值映射到 (0, 1) 区间,从而表示为类别 1 的概率。我们使用 PyTorch 的 nn.Module 来定义模型。

python 复制代码
# 定义逻辑回归模型
class LogisticRegression(nn.Module):
    def __init__(self):
        super(LogisticRegression, self).__init__()
        self.linear = nn.Linear(2, 1)  # 输入特征维度为 2,输出为 1

    def forward(self, x):
        return torch.sigmoid(self.linear(x))

3. 模型训练

训练模型需要定义损失函数和优化器。对于二分类问题,通常使用二分类交叉熵损失函数(BCELoss)。优化器我们选择随机梯度下降(SGD),学习率为 0.01。

python 复制代码
# 初始化模型、损失函数和优化器
model = LogisticRegression()
criterion = nn.BCELoss()  # 二分类交叉熵损失
optimizer = optim.SGD(model.parameters(), lr=0.01)

接下来,我们进行模型训练。训练过程包括前向传播、计算损失、反向传播和参数更新。我们训练 5000 个 epoch,并每 100 个 epoch 输出一次损失值,以观察模型的训练情况。

python 复制代码
# 训练模型
epochs = 5000
for epoch in range(epochs):
    model.train()
    optimizer.zero_grad()
    outputs = model(X)
    loss = criterion(outputs, y)
    loss.backward()
    optimizer.step()

    if (epoch + 1) % 100 == 0:
        print(f'Epoch [{epoch + 1}/{epochs}], Loss: {loss.item():.4f}')

4. 模型保存与加载

训练完成后,我们将模型的参数保存到文件中,以便后续使用。PyTorch 提供了 torch.save 方法来保存模型参数。

python 复制代码
torch.save(model.state_dict(), 'logistic_regression_model.pth')
print("模型已保存")

随后,我们加载保存的模型参数,并将其应用于新的模型实例中。加载模型时,我们使用 torch.load 方法,并指定 map_location 参数以确保模型可以在不同设备上加载。

python 复制代码
loaded_model = LogisticRegression()
loaded_model.load_state_dict(torch.load('logistic_regression_model.pth',
                                        map_location=torch.device('cpu'),
                                        weights_only=True))
loaded_model.eval()

5. 模型预测

最后,我们使用加载的模型对原始数据进行预测。预测时,我们使用 torch.no_grad() 上下文管理器来禁用梯度计算,以提高计算效率并减少内存占用。

python 复制代码
with torch.no_grad():
    predictions = loaded_model(X)
    predicted_labels = (predictions > 0.5).float()

print("实际结果:", y.numpy().flatten())
print("预测结果:", predicted_labels.numpy().flatten())

通过比较实际结果和预测结果,我们可以直观地观察模型的分类效果。

总结

本文通过一个简单的二维数据集,展示了如何使用 PyTorch 实现逻辑回归模型,包括数据准备、模型定义、训练、保存、加载和预测的完整流程。逻辑回归是一种简单而有效的分类算法,适用于许多实际问题。通过本文的示例,希望大家可以快速掌握 PyTorch 的基本操作,并为进一步学习更复杂的深度学习模型打下基础。


完整代码

python 复制代码
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim

"""自定义数据集 使用pytorch框架实现逻辑回归并保存模型,然后保存模型后再加载模型进行预测"""



# 提取特征和标签
class1_points = np.array([[1.9, 1.2],
                          [1.5, 2.1],
                          [1.9, 0.5],
                          [1.5, 0.9],
                          [0.9, 1.2],
                          [1.1, 1.7],
                          [1.4, 1.1]])

class2_points = np.array([[3.2, 3.2],
                          [3.7, 2.9],
                          [3.2, 2.6],
                          [1.7, 3.3],
                          [3.4, 2.6],
                          [4.1, 2.3],
                          [3.0, 2.9]])

# 提取两类特征,输入特征维度为2
x1_data = np.concatenate((class1_points[:, 0], class2_points[:, 0]), axis=0)
x2_data = np.concatenate((class1_points[:, 1], class2_points[:, 1]), axis=0)
label = np.concatenate((np.zeros(len(class1_points)), np.ones(len(class2_points))), axis=0)

# 将数据转换为 PyTorch 张量
X = torch.tensor(np.column_stack((x1_data, x2_data)), dtype=torch.float32)
y = torch.tensor(label, dtype=torch.float32).view(-1, 1)

# 定义逻辑回归模型
class LogisticRegression(nn.Module):
    def __init__(self):
        super(LogisticRegression, self).__init__()
        self.linear = nn.Linear(2, 1)  # 输入特征维度为 2,输出为 1

    def forward(self, x):
        return torch.sigmoid(self.linear(x))

# 初始化模型、损失函数和优化器
model = LogisticRegression()
criterion = nn.BCELoss()  # 二分类交叉熵损失
optimizer = optim.SGD(model.parameters(), lr=0.01)

# 训练模型
epochs = 5000
for epoch in range(epochs):
    model.train()
    optimizer.zero_grad()
    outputs = model(X)
    loss = criterion(outputs, y)
    loss.backward()
    optimizer.step()

    if (epoch + 1) % 100 == 0:
        print(f'Epoch [{epoch + 1}/{epochs}], Loss: {loss.item():.4f}')

# 保存模型
torch.save(model.state_dict(), 'logistic_regression_model.pth')
print("模型已保存")

# 加载模型
loaded_model = LogisticRegression()
loaded_model.load_state_dict(torch.load('logistic_regression_model.pth',
                                        map_location=torch.device('cpu'),
                                        weights_only=True))
loaded_model.eval()

# 进行预测
with torch.no_grad():
    predictions = loaded_model(X)
    predicted_labels = (predictions > 0.5).float()

 # 展示预测结果和实际结果
print("实际结果:", y.numpy().flatten())
print("预测结果:", predicted_labels.numpy().flatten())
相关推荐
hundaxxx1 小时前
自演化大语言模型的技术背景
人工智能
数智顾问1 小时前
【73页PPT】美的简单高效的管理逻辑(附下载方式)
大数据·人工智能·产品运营
love530love1 小时前
【保姆级教程】阿里 Wan2.1-T2V-14B 模型本地部署全流程:从环境配置到视频生成(附避坑指南)
人工智能·windows·python·开源·大模型·github·音视频
木头左1 小时前
结合机器学习的Backtrader跨市场交易策略研究
人工智能·机器学习·kotlin
Coovally AI模型快速验证2 小时前
3D目标跟踪重磅突破!TrackAny3D实现「类别无关」统一建模,多项SOTA达成!
人工智能·yolo·机器学习·3d·目标跟踪·无人机·cocos2d
研梦非凡2 小时前
CVPR 2025|基于粗略边界框监督的3D实例分割
人工智能·计算机网络·计算机视觉·3d
MiaoChuAI2 小时前
秒出PPT vs 豆包AI PPT:实测哪款更好用?
人工智能·powerpoint
fsnine2 小时前
深度学习——残差神经网路
人工智能·深度学习
和鲸社区3 小时前
《斯坦福CS336》作业1开源,从0手搓大模型|代码复现+免环境配置
人工智能·python·深度学习·计算机视觉·语言模型·自然语言处理·nlp
fanstuck3 小时前
2025 年高教社杯全国大学生数学建模竞赛C 题 NIPT 的时点选择与胎儿的异常判定详解(一)
人工智能·目标检测·数学建模·数据挖掘·aigc