使用 PyTorch 实现逻辑回归:从数据到模型保存与加载

在机器学习中,逻辑回归是一种经典的分类算法,广泛应用于二分类问题。本文将通过一个简单的示例,展示如何使用 PyTorch 框架实现逻辑回归模型,从数据准备到模型训练、保存和加载,最后进行预测。

1. 数据准备

逻辑回归的核心是通过学习数据中的特征与标签之间的关系来进行分类。在本示例中,我们手动创建了一个简单的二维数据集,包含两类数据点。第一类数据点的标签为 0,第二类数据点的标签为 1。

python 复制代码
class1_points = np.array([[1.9, 1.2],
                          [1.5, 2.1],
                          [1.9, 0.5],
                          [1.5, 0.9],
                          [0.9, 1.2],
                          [1.1, 1.7],
                          [1.4, 1.1]])

class2_points = np.array([[3.2, 3.2],
                          [3.7, 2.9],
                          [3.2, 2.6],
                          [1.7, 3.3],
                          [3.4, 2.6],
                          [4.1, 2.3],
                          [3.0, 2.9]])

我们将这两类数据点的特征和标签分别提取出来,并将它们合并到一个数据集中。特征数据 X 是一个二维数组,每一行表示一个数据点的两个特征值;标签数据 y 是一个一维数组,表示每个数据点对应的类别标签。

python 复制代码
# 提取两类特征,输入特征维度为2
x1_data = np.concatenate((class1_points[:, 0], class2_points[:, 0]), axis=0)
x2_data = np.concatenate((class1_points[:, 1], class2_points[:, 1]), axis=0)
label = np.concatenate((np.zeros(len(class1_points)), np.ones(len(class2_points))), axis=0)

# 将数据转换为 PyTorch 张量
X = torch.tensor(np.column_stack((x1_data, x2_data)), dtype=torch.float32)
y = torch.tensor(label, dtype=torch.float32).view(-1, 1)

2. 定义逻辑回归模型

逻辑回归模型的核心是一个线性变换后接一个 Sigmoid 激活函数。Sigmoid 函数可以将输出值映射到 (0, 1) 区间,从而表示为类别 1 的概率。我们使用 PyTorch 的 nn.Module 来定义模型。

python 复制代码
# 定义逻辑回归模型
class LogisticRegression(nn.Module):
    def __init__(self):
        super(LogisticRegression, self).__init__()
        self.linear = nn.Linear(2, 1)  # 输入特征维度为 2,输出为 1

    def forward(self, x):
        return torch.sigmoid(self.linear(x))

3. 模型训练

训练模型需要定义损失函数和优化器。对于二分类问题,通常使用二分类交叉熵损失函数(BCELoss)。优化器我们选择随机梯度下降(SGD),学习率为 0.01。

python 复制代码
# 初始化模型、损失函数和优化器
model = LogisticRegression()
criterion = nn.BCELoss()  # 二分类交叉熵损失
optimizer = optim.SGD(model.parameters(), lr=0.01)

接下来,我们进行模型训练。训练过程包括前向传播、计算损失、反向传播和参数更新。我们训练 5000 个 epoch,并每 100 个 epoch 输出一次损失值,以观察模型的训练情况。

python 复制代码
# 训练模型
epochs = 5000
for epoch in range(epochs):
    model.train()
    optimizer.zero_grad()
    outputs = model(X)
    loss = criterion(outputs, y)
    loss.backward()
    optimizer.step()

    if (epoch + 1) % 100 == 0:
        print(f'Epoch [{epoch + 1}/{epochs}], Loss: {loss.item():.4f}')

4. 模型保存与加载

训练完成后,我们将模型的参数保存到文件中,以便后续使用。PyTorch 提供了 torch.save 方法来保存模型参数。

python 复制代码
torch.save(model.state_dict(), 'logistic_regression_model.pth')
print("模型已保存")

随后,我们加载保存的模型参数,并将其应用于新的模型实例中。加载模型时,我们使用 torch.load 方法,并指定 map_location 参数以确保模型可以在不同设备上加载。

python 复制代码
loaded_model = LogisticRegression()
loaded_model.load_state_dict(torch.load('logistic_regression_model.pth',
                                        map_location=torch.device('cpu'),
                                        weights_only=True))
loaded_model.eval()

5. 模型预测

最后,我们使用加载的模型对原始数据进行预测。预测时,我们使用 torch.no_grad() 上下文管理器来禁用梯度计算,以提高计算效率并减少内存占用。

python 复制代码
with torch.no_grad():
    predictions = loaded_model(X)
    predicted_labels = (predictions > 0.5).float()

print("实际结果:", y.numpy().flatten())
print("预测结果:", predicted_labels.numpy().flatten())

通过比较实际结果和预测结果,我们可以直观地观察模型的分类效果。

总结

本文通过一个简单的二维数据集,展示了如何使用 PyTorch 实现逻辑回归模型,包括数据准备、模型定义、训练、保存、加载和预测的完整流程。逻辑回归是一种简单而有效的分类算法,适用于许多实际问题。通过本文的示例,希望大家可以快速掌握 PyTorch 的基本操作,并为进一步学习更复杂的深度学习模型打下基础。


完整代码

python 复制代码
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim

"""自定义数据集 使用pytorch框架实现逻辑回归并保存模型,然后保存模型后再加载模型进行预测"""



# 提取特征和标签
class1_points = np.array([[1.9, 1.2],
                          [1.5, 2.1],
                          [1.9, 0.5],
                          [1.5, 0.9],
                          [0.9, 1.2],
                          [1.1, 1.7],
                          [1.4, 1.1]])

class2_points = np.array([[3.2, 3.2],
                          [3.7, 2.9],
                          [3.2, 2.6],
                          [1.7, 3.3],
                          [3.4, 2.6],
                          [4.1, 2.3],
                          [3.0, 2.9]])

# 提取两类特征,输入特征维度为2
x1_data = np.concatenate((class1_points[:, 0], class2_points[:, 0]), axis=0)
x2_data = np.concatenate((class1_points[:, 1], class2_points[:, 1]), axis=0)
label = np.concatenate((np.zeros(len(class1_points)), np.ones(len(class2_points))), axis=0)

# 将数据转换为 PyTorch 张量
X = torch.tensor(np.column_stack((x1_data, x2_data)), dtype=torch.float32)
y = torch.tensor(label, dtype=torch.float32).view(-1, 1)

# 定义逻辑回归模型
class LogisticRegression(nn.Module):
    def __init__(self):
        super(LogisticRegression, self).__init__()
        self.linear = nn.Linear(2, 1)  # 输入特征维度为 2,输出为 1

    def forward(self, x):
        return torch.sigmoid(self.linear(x))

# 初始化模型、损失函数和优化器
model = LogisticRegression()
criterion = nn.BCELoss()  # 二分类交叉熵损失
optimizer = optim.SGD(model.parameters(), lr=0.01)

# 训练模型
epochs = 5000
for epoch in range(epochs):
    model.train()
    optimizer.zero_grad()
    outputs = model(X)
    loss = criterion(outputs, y)
    loss.backward()
    optimizer.step()

    if (epoch + 1) % 100 == 0:
        print(f'Epoch [{epoch + 1}/{epochs}], Loss: {loss.item():.4f}')

# 保存模型
torch.save(model.state_dict(), 'logistic_regression_model.pth')
print("模型已保存")

# 加载模型
loaded_model = LogisticRegression()
loaded_model.load_state_dict(torch.load('logistic_regression_model.pth',
                                        map_location=torch.device('cpu'),
                                        weights_only=True))
loaded_model.eval()

# 进行预测
with torch.no_grad():
    predictions = loaded_model(X)
    predicted_labels = (predictions > 0.5).float()

 # 展示预测结果和实际结果
print("实际结果:", y.numpy().flatten())
print("预测结果:", predicted_labels.numpy().flatten())
相关推荐
Roc_z740 分钟前
Facebook 元宇宙与全球文化交流的新趋势
人工智能·智能合约·facebook
AIQL4 小时前
Deepseek的RL算法GRPO解读
人工智能·算法·机器学习·deepseek·grpo算法
xiaomu_3475 小时前
2024-2025自动驾驶技术演进与产业破局的深度实践——一名自动驾驶算法工程师的年度技术总结与行业洞察
linux·人工智能·自动驾驶
終不似少年遊*6 小时前
NLP自然语言处理通识
人工智能·python·自然语言处理
courniche6 小时前
神经网络的通俗介绍
人工智能·神经网络·算法
wangzaojun6 小时前
睡眠时间影响因素K-Means可视化分析+XGBoost预测
人工智能·机器学习·kmeans
weixin_307779136 小时前
TensorFlow 2基本功能和示例代码
人工智能·深度学习·tensorflow
青松@FasterAI6 小时前
【ChatGPT】意义空间与语义运动定律 —— AI 世界的神秘法则
人工智能·深度学习·chatgpt
FF-Studio7 小时前
【DeepSeek】LLM强化学习GRPO Trainer详解
人工智能·机器学习
Icomi_8 小时前
【PyTorch】5.张量索引操作
人工智能·pytorch·python·深度学习·神经网络·机器学习·计算机视觉