价格分类(神经网络)

复制代码
# 1.导入依赖包
import time

import torch
import torch.nn as nn
import torch.optim as optim

from torch.utils.data import TensorDataset, DataLoader
from sklearn.model_selection import train_test_split

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

from torchsummary import summary


# 2.构建数据集
def create_dataset():
    # 2.1 读取数据集
    data = pd.read_csv('dataset/手机价格预测.csv')

    # 2.2 获取特征值和目标值,类型转化  特征(Float)  标签(Long)
    x, y = data.iloc[:, :-1], data.iloc[:, -1]
    x, y = x.astype(np.float32), y.astype(np.int64)

    # 2.3 数据集划分
    x_train, x_test, y_train, y_test = train_test_split(x, y, test_size=0.2,random_state=2)

    # 2.4 数据转Tensor
    train_dataset = TensorDataset(torch.from_numpy(x_train.values), torch.tensor(y_train.values))
    test_dataset = TensorDataset(torch.from_numpy(x_test.values), torch.tensor(y_test.values))

    return train_dataset, test_dataset, x_train.shape[1], len(np.unique(y))


# 3. 构建模型
class PhonePriceModel(nn.Module):
    def __init__(self, input_dim, output_dim):
        super(PhonePriceModel, self).__init__()
        self.linear1 = nn.Linear(input_dim, 256)
        self.linear2 = nn.Linear(256, 1024)
        self.fc = nn.Linear(1024, output_dim)

    def forward(self, x):
        x = torch.relu(self.linear1(x))
        x = torch.relu(self.linear2(x))
        output = self.fc(x)
        # output = torch.softmax(self.fc(x), dim=-1)

        return output


# 4.模型训练(225)
def train(model, train_dataset, num_epochs, batch_size):
    # 2 初始化参数  损失函数  优化器
    loss1 = nn.CrossEntropyLoss()
    # optimizer = optim.SGD(model.parameters(), lr=1e-4, momentum=0.9)
    optimizer = optim.Adam(model.parameters(), lr=1e-4, betas=(0.99, 0.99))

    start = time.time()

    # 2 2个遍历  epoch  dataloader
    for epoch in range(num_epochs):
        dataloader = DataLoader(train_dataset, shuffle=True, batch_size=batch_size)

        total_num = 0
        total_loss = 0.0
        for x, y in dataloader:
            # 5 前向传播  损失计算 梯度归零  反向传播 参数更新
            output = model(x)
            loss = loss1(output, y)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            total_num += 1  # 批次
            total_loss += loss.item()

        epoch += 1

        print(f'epoch:{epoch + 1:4d},loss:{total_loss / (total_num * epoch):.4f}, time:{time.time() - start:.2f}s')
    # 模型持久化
    torch.save(model.state_dict(), 'model/phone2.pth')


# 5.模型预测评估
def test(model, test_dataset, input_dim, output_dim):
    # 3.导入数据
    dataloader = DataLoader(test_dataset, batch_size=8, shuffle=False)

    correct = 0
    # 4.遍历数据
    for x, y in dataloader:
        # 4.1 前向传播
        output = model(x)
        print(output)
        # 4.2 获取输出结果(类别)
        y_pred = torch.argmax(output, dim=1)
        # print(y_pred)  # 预测错误
        # 4.3 计算准确率Acc
        correct += (y_pred == y).sum()
        print(correct.item())
    Acc = correct.item() / len(test_dataset)

    return Acc


if __name__ == '__main__':
    train_dataset, test_dataset, feature_num, label_num = create_dataset()
    # 1.实例化模型
    model = PhonePriceModel(feature_num, label_num)
    # 2.加载模型
    model.load_state_dict(torch.load('model/phone2.pth'))
    # 模型训练
    # train(model, train_dataset, num_epochs=50, batch_size=8)

    # 模型预测
    Acc = test(model, test_dataset, feature_num, label_num)
    print(f'Acc:{Acc:.5f}')
相关推荐
一个散步者的梦7 小时前
一键生成数据分析报告:Python的ydata-profiling模块(汉化)
python·数据挖掘·数据分析
生成论实验室12 小时前
周林东的生成论入门十讲 · 第八讲 生成的世界——物理学与生物学新视角
人工智能·科技·神经网络·信息与通信·几何学
小王毕业啦14 小时前
2000-2023年 地级市-公路运输相关数据
大数据·人工智能·数据挖掘·数据分析·数据统计·社科数据·实证数据
熊猫钓鱼>_>16 小时前
TensorFlow深度学习框架入门浅析
深度学习·神经网络·tensorflow·neo4j·张量·训练模型·评估模型
科学最TOP16 小时前
IJCAI25|如何平衡文本与时序信息的融合适配?
人工智能·深度学习·神经网络·机器学习·时间序列
yzx99101318 小时前
基于Flask+Vue.js的智能社区垃圾分类管理系统 - 三创赛参赛项目全栈开发指南
vue.js·分类·flask
私人珍藏库19 小时前
[吾爱大神原创工具] 照片视频整理工具 V1.0
windows·分类·工具·整理·照片·辅助
勤劳的进取家20 小时前
论文阅读:农业喷雾无人机避障技术综述
论文阅读·嵌入式硬件·神经网络·计算机视觉·无人机
geneculture20 小时前
融合全部讨论精华的融智学认知与实践总览图:掌握在复杂世界中锚定自我、有效行动、并参与塑造近未来的元能力
大数据·人工智能·数据挖掘·信息科学·融智学的重要应用·信智序位·全球软件定位系统
永霖光电_UVLED20 小时前
安森美与英诺赛科将合作推进氮化镓(GaN)功率器件的量产应用
人工智能·神经网络·生成对抗网络