自定义数据集 使用pytorch框架实现逻辑回归并保存模型,然后保存模型后再加载模型进行预测

python 复制代码
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms
import pandas as pd
import numpy as np
class CustomDataset(Dataset):
    def __init__(self, data, labels, transform=None):
        self.data = data
        self.labels = labels
        self.transform = transform

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        sample = self.data[idx]
        label = self.labels[idx]
        if self.transform:
            sample = self.transform(sample)
        return sample, label
# 示例数据
data = np.random.rand(100, 10)  # 100个样本,每个样本10个特征
labels = np.random.randint(0, 2, 100)  # 二分类标签

# 转换为torch的Tensor
data = torch.tensor(data, dtype=torch.float32)
labels = torch.tensor(labels, dtype=torch.long)

# 创建数据集和数据加载器
dataset = CustomDataset(data, labels)
dataloader = DataLoader(dataset, batch_size=16, shuffle=True)


class LogisticRegression(nn.Module):
    def __init__(self, input_dim):
        super(LogisticRegression, self).__init__()
        self.linear = nn.Linear(input_dim, 1)

    def forward(self, x):
        out = torch.sigmoid(self.linear(x))
        return out


# 参数设置
input_dim = data.shape[1]
num_epochs = 20
learning_rate = 0.01

# 初始化模型、损失函数和优化器
model = LogisticRegression(input_dim)
criterion = nn.BCELoss()
optimizer = optim.SGD(model.parameters(), lr=learning_rate)

# 训练模型
for epoch in range(num_epochs):
    for inputs, labels in dataloader:
        # 将标签转换为浮点型
        labels = labels.float().unsqueeze(1)

        # 前向传播
        outputs = model(inputs)
        loss = criterion(outputs, labels)

        # 反向传播和优化
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    print(f'Epoch [{epoch + 1}/{num_epochs}], Loss: {loss.item():.4f}')
# 保存模型权重和架构
torch.save(model.state_dict(), 'logistic_regression_model.pth')

# 加载模型
loaded_model = LogisticRegression(input_dim)
loaded_model.load_state_dict(torch.load('logistic_regression_model.pth'))
loaded_model.eval()  # 切换为评估模式

# 进行预测
with torch.no_grad():
    sample = torch.tensor(np.random.rand(1, 10), dtype=torch.float32)  # 一个新的样本
    prediction = loaded_model(sample)
    print(f'Prediction: {prediction.item()}')
相关推荐
愚公搬代码几秒前
【愚公系列】《AI短视频创作一本通》016-AI短视频的生成(AI短视频运镜方法)
人工智能·音视频
哈__几秒前
CANN内存管理与资源优化
人工智能·pytorch
极新1 分钟前
智启新篇,智创未来,“2026智造新IP:AI驱动品牌增长新周期”峰会暨北京电子商务协会第五届第三次会员代表大会成功举办
人工智能·网络协议·tcp/ip
island13143 分钟前
CANN GE(图引擎)深度解析:计算图优化管线、内存静态规划与异构任务的 Stream 调度机制
开发语言·人工智能·深度学习·神经网络
艾莉丝努力练剑3 分钟前
深度学习视觉任务:如何基于ops-cv定制图像预处理流程
人工智能·深度学习
禁默8 分钟前
大模型推理的“氮气加速系统”:全景解读 Ascend Transformer Boost (ATB)
人工智能·深度学习·transformer·cann
User_芊芊君子9 分钟前
CANN大模型加速核心ops-transformer全面解析:Transformer架构算子的高性能实现与优化
人工智能·深度学习·transformer
格林威9 分钟前
Baumer相机玻璃制品裂纹自动检测:提高透明材质检测精度的 6 个关键步骤,附 OpenCV+Halcon 实战代码!
人工智能·opencv·视觉检测·材质·工业相机·sdk开发·堡盟相机
点云SLAM11 分钟前
Concentrate 英文单词学习
人工智能·英文单词学习·雅思备考·concentrate·集中·浓缩 / 集中物
哈__11 分钟前
CANN轻量化开发实战:快速上手与多场景适配
人工智能