Pytorch 8

这节课是讲mini_batch数据下载的

python 复制代码
from torch.utils.data import Dataset
from torch.utils.data import DataLoader

第一个类是抽象类,只能继承

第二个可以直接用

python 复制代码
class DiabetesDataset(Dataset):
    def __init__(self, filepath):
        xy = np.loadtxt(filepath, delimiter=',', dtype=np.float32)
        self.len = xy.shape[0] # shape(多少行,多少列)
        self.x_data = torch.from_numpy(xy[:, :-1])
        self.y_data = torch.from_numpy(xy[:, [-1]])
 
    def __getitem__(self, index):
        return self.x_data[index], self.y_data[index]
 
    def __len__(self):
        return self.len

定义这个类要做两件事,第一件就是让他能下标调用,第二件事可以返回长度

数据下载

python 复制代码
dataset = DiabetesDataset('diabetes.csv')
train_loader = DataLoader(dataset=dataset, batch_size=32, shuffle=True, num_workers=0) 
# 接受四个参数,第一个是接受的,第二个是mini_batch的大小,第三是是否随机,第四是分为几个线程来下载数据

神经网络

python 复制代码
if __name__ == '__main__':
    for epoch in range(100):
        for i, data in enumerate(train_loader, 0): # train_loader 是先shuffle后mini_batch
            inputs, labels = data
            y_pred = model(inputs)
            loss = criterion(y_pred, labels)
            print(epoch, i, loss.item())
 
            optimizer.zero_grad()
            loss.backward()
 
            optimizer.step()

放在if里面是因为在windows系统里面会出错,i是下标

相关推荐
谅望者2 分钟前
数据分析笔记08:Python编程基础-数据类型与变量
数据库·笔记·python·数据分析·概率论
CV实验室3 分钟前
2025 | 哈工大&鹏城实验室等提出 Cascade HQP-DETR:仅用合成数据实现SOTA目标检测,突破虚实鸿沟!
人工智能·目标检测·计算机视觉·哈工大
aitoolhub5 分钟前
培训ppt高效制作:稿定设计 + Prompt 工程 30 分钟出图指南
人工智能·prompt·aigc
oranglay5 分钟前
提示词(Prompt Engineering)核心思维
人工智能·prompt
mortimer7 分钟前
【实战复盘】 PySide6 + PyTorch 偶发性“假死”?由多线程转多进程
pytorch·python·pyqt
极速learner8 分钟前
【Prompt分享】自学英语教程的AI 提示语:流程、范例及可视化实现
人工智能·prompt·ai写作
清静诗意8 分钟前
Django REST Framework(DRF)RESTful 最完整版实战教程
python·django·restful·drf
大怪v15 分钟前
我TM被AI骗的自己PUA了自己😂 😂 !细思极恐~
人工智能·chatgpt·grok
studytosky34 分钟前
深度学习理论与实战:Pytorch基础入门
人工智能·pytorch·python·深度学习·机器学习
沫儿笙42 分钟前
安川YASKAWA焊接机器人电池拖盘焊接节气
人工智能·机器人