Pytorch 8

这节课是讲mini_batch数据下载的

python 复制代码
from torch.utils.data import Dataset
from torch.utils.data import DataLoader

第一个类是抽象类,只能继承

第二个可以直接用

python 复制代码
class DiabetesDataset(Dataset):
    def __init__(self, filepath):
        xy = np.loadtxt(filepath, delimiter=',', dtype=np.float32)
        self.len = xy.shape[0] # shape(多少行,多少列)
        self.x_data = torch.from_numpy(xy[:, :-1])
        self.y_data = torch.from_numpy(xy[:, [-1]])
 
    def __getitem__(self, index):
        return self.x_data[index], self.y_data[index]
 
    def __len__(self):
        return self.len

定义这个类要做两件事,第一件就是让他能下标调用,第二件事可以返回长度

数据下载

python 复制代码
dataset = DiabetesDataset('diabetes.csv')
train_loader = DataLoader(dataset=dataset, batch_size=32, shuffle=True, num_workers=0) 
# 接受四个参数,第一个是接受的,第二个是mini_batch的大小,第三是是否随机,第四是分为几个线程来下载数据

神经网络

python 复制代码
if __name__ == '__main__':
    for epoch in range(100):
        for i, data in enumerate(train_loader, 0): # train_loader 是先shuffle后mini_batch
            inputs, labels = data
            y_pred = model(inputs)
            loss = criterion(y_pred, labels)
            print(epoch, i, loss.item())
 
            optimizer.zero_grad()
            loss.backward()
 
            optimizer.step()

放在if里面是因为在windows系统里面会出错,i是下标

相关推荐
闵孚龙3 分钟前
Claude Code 工具提示词全拆解:AI Agent、Prompt Engineering、工具调用、上下文工程、自动化编程的底层逻辑
人工智能·自动化·prompt
白鲸开源6 分钟前
杀疯了!SeaTunnel AI CLI 解锁数据集成新玩法
大数据·人工智能·github
輕華8 分钟前
uv工具详解——Python包与项目管理器完全指南
开发语言·python·uv
王_teacher9 分钟前
GRU (Gated Recurrent Unit,门控循环单元) 原理详解 并且手写GRU模型
人工智能·gru·llm·nlp
li星野9 分钟前
位运算 & 数学 & 高频进阶九题通关(Python + C++)
c++·python·学习·算法
AI医影跨模态组学9 分钟前
Cancer Letters(IF=10.1)中山大学附属第六医院等团队:基于治疗前MRI影像的RCMIX模型预测MRI定义的cT4期直肠癌T分期下降
人工智能·机器学习·论文·医学·医学影像·影像组学
技术落地手记13 分钟前
把AI塞进测试环节,我踩出了一条能用的路
人工智能·测试
2303_8212873816 分钟前
如何清洗SQL输入数据_使用框架内置的ORM处理数据交互
jvm·数据库·python
go不是csgo19 分钟前
s01 搭建第一个对话智能体
服务器·网络·python·ai
xixixi7777719 分钟前
AI的“账号”与“钱包”:AWS与Circle同日出手,AI正从工具进化
人工智能·安全·ai·大模型·云计算·aws