Pytorch 8

这节课是讲mini_batch数据下载的

python 复制代码
from torch.utils.data import Dataset
from torch.utils.data import DataLoader

第一个类是抽象类,只能继承

第二个可以直接用

python 复制代码
class DiabetesDataset(Dataset):
    def __init__(self, filepath):
        xy = np.loadtxt(filepath, delimiter=',', dtype=np.float32)
        self.len = xy.shape[0] # shape(多少行,多少列)
        self.x_data = torch.from_numpy(xy[:, :-1])
        self.y_data = torch.from_numpy(xy[:, [-1]])
 
    def __getitem__(self, index):
        return self.x_data[index], self.y_data[index]
 
    def __len__(self):
        return self.len

定义这个类要做两件事,第一件就是让他能下标调用,第二件事可以返回长度

数据下载

python 复制代码
dataset = DiabetesDataset('diabetes.csv')
train_loader = DataLoader(dataset=dataset, batch_size=32, shuffle=True, num_workers=0) 
# 接受四个参数,第一个是接受的,第二个是mini_batch的大小,第三是是否随机,第四是分为几个线程来下载数据

神经网络

python 复制代码
if __name__ == '__main__':
    for epoch in range(100):
        for i, data in enumerate(train_loader, 0): # train_loader 是先shuffle后mini_batch
            inputs, labels = data
            y_pred = model(inputs)
            loss = criterion(y_pred, labels)
            print(epoch, i, loss.item())
 
            optimizer.zero_grad()
            loss.backward()
 
            optimizer.step()

放在if里面是因为在windows系统里面会出错,i是下标

相关推荐
半青年40 分钟前
单例模式:全局唯一性在软件设计中的艺术实践
java·c++·python·单例模式
鸿蒙布道师44 分钟前
百度Create大会深度解读:AI Agent与多模态模型如何重塑未来?
人工智能·深度学习·神经网络·机器学习·百度·自然语言处理·dubbo
睿途低空新程1 小时前
面向城市治理的AI集群空域融合模型
人工智能·经验分享·其他·无人机
LaughingZhu1 小时前
PH热榜 | 2025-04-26
前端·数据库·人工智能·mysql·开源
海盗儿2 小时前
吴恩达深度学习作业之风格转移Neural Style Transfer (pytorch)
人工智能·计算机视觉
fen_fen2 小时前
Miniconda Windows10版本下载和安装
python
kyle~2 小时前
深度学习---Pytorch概览
人工智能·pytorch·python·深度学习
说私域2 小时前
开源AI智能名片链动2+1模式S2B2C商城小程序源码赋能下的社交电商创业者技能跃迁与价值重构
人工智能·小程序·重构·开源·零售
一点.点3 小时前
自动驾驶(ADAS)领域常用数据集介绍
人工智能·深度学习·机器学习·自动驾驶
智驱力人工智能3 小时前
夏季道路安全的AI革命:节省人力、提升效率
人工智能·安全·边缘计算·视觉算法·视觉分析·智能巡航·人工智能云计算