自定义数据集 使用pytorch框架实现逻辑回归并保存模型,然后保存模型后再加载模型进行预测,对预测结果计算精确度和召回率及F1分数

python 复制代码
import torch
import numpy as np
import torch.nn as nn
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score

data = [[-0.5, 7.7],
        [1.8, 98.5],
        [0.9, 57.8],
        [0.4, 39.2],
        [-1.4, -15.7],
        [-1.4, -37.3],
        [-1.8, -49.1],
        [1.5, 75.6],
        [0.4, 34.0],
        [0.8, 62.3]]

# 将数据转为 numpy 数组
data = np.array(data)

# 提取 x_data 和 y_data
x_data = data[:, 0]  # 取第一列作为输入特征
y_data = data[:, 1]  # 取第二列作为目标标签

# 将数据转换为 PyTorch 张量
x_train = torch.tensor(x_data, dtype=torch.float32)  # 输入特征
y_train = torch.tensor(y_data, dtype=torch.float32)  # 目标标签

# 使用 TensorDataset 来创建一个数据集
from torch.utils.data import DataLoader, TensorDataset

dataset = TensorDataset(x_train, y_train)  # 使用训练数据创建数据集
dataloader = DataLoader(dataset, batch_size=2, shuffle=True)  # 将数据集转换为 DataLoader,批大小为 2,且每个 epoch 都会随机打乱数据

# 定义损失函数:均方误差损失 (MSELoss)
criterion = nn.MSELoss()


# 定义线性回归模型
class LinearModel(nn.Module):
    def __init__(self):
        super(LinearModel, self).__init__()
        # 使用一个线性层,输入为1维,输出为1维
        self.layers = nn.Linear(1, 1)

    def forward(self, x):
        # 直接返回线性层的输出
        return self.layers(x)


model = LinearModel()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
epoches = 500
for n in range(1, epoches + 1):
    epoch_loss = 0
    # 以前都是所有数据一块训练,现在是按照批次进行训练
    for batch_x, batch_y in dataloader:
        # 现在x_train 相当于10个样本,但是现在维度,添加一个维度
        # 10x1   变成样本 x 维度形式
        y_prd = model(batch_x.unsqueeze(1))
        # 计算损失
        # y_prd在前面,y_true 是后面
        batch_loss = criterion(y_prd.squeeze(1), batch_y)
        # 梯度更新
        # 清空之前存储在优化器中的梯度
        optimizer.zero_grad()
        # 损失函数对模型参数的梯度
        batch_loss.backward()
        # 根据优化算法更新参数
        optimizer.step()
        # 计算一下epoch的损失
        epoch_loss = epoch_loss + batch_loss

        # 5、显示频率设置

    # 计算一下epoch的平均损失
    avg_loss = epoch_loss / (len(dataloader))
    # 不先画图
    if n % 10 == 0 or n == 1:
        print(f"epoches:{n},loss:{avg_loss}")
        torch.save(model.state_dict(), 'model.pth')

model.load_state_dict(torch.load("model.pth"))
# 评估模型
# 评估模型一定要加下面这句话
model.eval()
# 定义数据
x_test = torch.tensor([[1.8]], dtype=torch.float32)
# 添加上下文不需要计算梯度
with torch.no_grad():
    y_pred = model(x_test)

threshold = 50  # 设定阈值
y_pred_class = int(y_pred.item() > threshold)

# 输出预测结果
print(f"预测值 : {y_pred.item():.4f}")
print(f"预测类 : {y_pred_class}")

# 假设真实标签也是 1 或 0,我们用一个假的真实标签来计算评估指标(你可以根据实际情况替换)
y_true_class = 1 if y_data[1] > threshold else 0  # 假设我们预测的是第二个样本

# 计算精确度、召回率和 F1 分数
accuracy = accuracy_score([y_true_class], [y_pred_class])
precision = precision_score([y_true_class], [y_pred_class])
recall = recall_score([y_true_class], [y_pred_class])
f1 = f1_score([y_true_class], [y_pred_class])

# 输出分类评估指标
print(f"precision : {precision:.4f}")
print(f"recall : {recall:.4f}")
print(f"f1 : {f1:.4f}")
相关推荐
步辞43 分钟前
Go语言怎么用channel做信号通知_Go语言channel信号模式教程【完整】
jvm·数据库·python
Ulyanov44 分钟前
《PySide6 GUI开发指南:QML核心与实践》 第一篇:GUI新纪元——QML与PySide6生态系统全景
开发语言·python·qt·qml·雷达电子对抗
曲幽1 小时前
FastAPI + SQLAlchemy 2.0 通用CRUD操作手册 —— 从同步到异步,一次讲透
python·fastapi·web·async·sqlalchemy·session·crud·sync·with
Dxy12393102161 小时前
Python 如何使用 XPath 定位元素:从入门到实战
python
用户8356290780511 小时前
Python 设置 PowerPoint 文档属性与页面参数
后端·python
weixin_424999361 小时前
mysql行级锁失效的原因排查_检查查询条件与执行计划
jvm·数据库·python
yaoxin5211231 小时前
389. Java IO API - 获取文件名
java·开发语言·python
Polar__Star1 小时前
uni-app怎么实现App端一键换肤 uni-app全局样式动态切换【实战】
jvm·数据库·python
用户8356290780512 小时前
使用 Python 自动管理 PowerPoint 幻灯片分节的方法
后端·python
奇牙3 小时前
DeepSeek V4 Agent 开发实战:用 deepseek-v4-pro 搭建多步骤工作流(2026 完整代码)
python