使用PyTorch实现逻辑回归:从训练到模型保存与加载

  1. 引入必要的库

首先,需要引入必要的库。PyTorch用于构建和训练模型,pandas和numpy用于数据处理,matplotlib用于结果的可视化。

python 复制代码
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt

2. 加载自定义数据集

有一个CSV文件custom_dataset.csv,其中包含特征(自变量)和标签(因变量)。使用pandas来加载数据,并进行预处理。

python 复制代码
# 加载自定义数据集
data = pd.read_csv('custom_dataset.csv')

# 假设数据集中有多列特征和一个二分类标签
X = data.iloc[:, :-1].values.astype(np.float32)  # 特征
y = data.iloc[:, -1].values.astype(np.float32)   # 标签

# 将标签转换为0和1
y = np.where(y == 'positive', 1, 0)

3. 创建数据集和数据加载器

使用PyTorch的TensorDatasetDataLoader来创建数据集和数据加载器。

python 复制代码
# 创建数据集和数据加载器
dataset = TensorDataset(torch.tensor(X), torch.tensor(y))
train_loader = DataLoader(dataset, batch_size=32, shuffle=True)

4. 定义逻辑回归模型

使用PyTorch的nn.Module来定义逻辑回归模型。

python 复制代码
class LogisticRegression(nn.Module):
    def __init__(self, input_dim):
        super(LogisticRegression, self).__init__()
        self.linear = nn.Linear(input_dim, 1)
    
    def forward(self, x):
        outputs = torch.sigmoid(self.linear(x))
        return outputs

# 初始化模型
input_dim = X.shape[1]
model = LogisticRegression(input_dim)

5. 训练模型

定义损失函数和优化器,然后训练模型。

python 复制代码
# 定义损失函数和优化器
criterion = nn.BCELoss()
optimizer = optim.SGD(model.parameters(), lr=0.01)

# 训练模型
num_epochs = 100
for epoch in range(num_epochs):
    for inputs, labels in train_loader:
        # 前向传播
        outputs = model(inputs)
        loss = criterion(outputs.flatten(), labels)
        
        # 反向传播和优化
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    
    if (epoch+1) % 10 == 0:
        print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {loss.item():.4f}')

6. 保存模型

训练完成后,可以使用PyTorch的torch.save函数来保存模型。

python 复制代码
# 保存模型
torch.save(model.state_dict(), 'logistic_regression_model.pth')

7. 加载模型并进行预测

在需要时,可以使用torch.load函数加载模型,并进行预测。

python 复制代码
# 加载模型
model = LogisticRegression(input_dim)
model.load_state_dict(torch.load('logistic_regression_model.pth'))
model.eval()

# 进行预测
with torch.no_grad():
    sample_inputs = torch.tensor(X[:5]).float()  # 示例输入
    predictions = model(sample_inputs)
    predicted_labels = (predictions.flatten() > 0.5).int()

print("Predicted Labels:", predicted_labels.numpy())
相关推荐
Coder_Boy_1 小时前
技术发展的核心规律是「加法打底,减法优化,重构平衡」
人工智能·spring boot·spring·重构
会飞的老朱3 小时前
医药集团数智化转型,智能综合管理平台激活集团管理新效能
大数据·人工智能·oa协同办公
聆风吟º5 小时前
CANN runtime 实战指南:异构计算场景中运行时组件的部署、调优与扩展技巧
人工智能·神经网络·cann·异构计算
Codebee7 小时前
能力中心 (Agent SkillCenter):开启AI技能管理新时代
人工智能
聆风吟º7 小时前
CANN runtime 全链路拆解:AI 异构计算运行时的任务管理与功能适配技术路径
人工智能·深度学习·神经网络·cann
uesowys7 小时前
Apache Spark算法开发指导-One-vs-Rest classifier
人工智能·算法·spark
AI_56787 小时前
AWS EC2新手入门:6步带你从零启动实例
大数据·数据库·人工智能·机器学习·aws
User_芊芊君子7 小时前
CANN大模型推理加速引擎ascend-transformer-boost深度解析:毫秒级响应的Transformer优化方案
人工智能·深度学习·transformer
智驱力人工智能8 小时前
小区高空抛物AI实时预警方案 筑牢社区头顶安全的实践 高空抛物检测 高空抛物监控安装教程 高空抛物误报率优化方案 高空抛物监控案例分享
人工智能·深度学习·opencv·算法·安全·yolo·边缘计算
qq_160144878 小时前
亲测!2026年零基础学AI的入门干货,新手照做就能上手
人工智能