网络模型训练完整代码

存个代码

具体看这位博主的网络模型训练完整套路 写的比较清晰

python 复制代码
import torchvision, torch
from torch import nn
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
 
train_data = torchvision.datasets.CIFAR10(root="../data", train=True, transform=torchvision.transforms.ToTensor(),
                                          download=True)
test_data = torchvision.datasets.CIFAR10(root="../data", train=False, transform=torchvision.transforms.ToTensor(),
                                         download=True)
 
train_data_size = len(train_data)
test_data_size = len(test_data)
print("训练数据集的长度为:{}".format(train_data_size))
print("测试数据集的长度为:{}".format(test_data_size))
 
train_dataloader = DataLoader(train_data, batch_size=64)
test_dataloader = DataLoader(test_data, batch_size=64)

# 这一块替换为要训练的网络模型
''' 
class Mydata(nn.Module):
    def __init__(self):
        super().__init__()
        self.model = nn.Sequential(
            nn.Conv2d(3, 32, 5, 1, 2),
            nn.MaxPool2d(2),
            nn.Conv2d(32, 32, 5, 1, 2),
            nn.MaxPool2d(2),
            nn.Conv2d(32, 64, 5, 1, 2),
            nn.MaxPool2d(2),
            nn.Flatten(),
            nn.Linear(64 * 4 * 4, 64),
            nn.Linear(64, 10)
        )
 
    def forward(self, x):
        x = self.model(x)
        return x
'''
mydata = Mydata()
loss_fn = nn.CrossEntropyLoss()
learning_rate = 1e-2
optimizer = torch.optim.SGD(mydata.parameters(), lr=learning_rate)  
 
total_train_step = 0
total_test_step = 0
epoch = 10
writer = SummaryWriter("logs")
 
for i in range(epoch):
    print("------------第 {} 轮训练开始------------".format(i + 1))
 
    mydata.train()  
    for data in train_dataloader:
        imgs, targets = data
        outputs = mydata(imgs)
        loss = loss_fn(outputs, targets)
 
        optimizer.zero_grad()  
        loss.backward()
        optimizer.step()
 
        total_train_step = total_train_step + 1
        if total_train_step % 100 == 0:
            print("训练次数:{}, Loss: {}".format(total_train_step, loss.item()))  
            writer.add_scalar("train_loss", loss.item(), total_train_step)  
 
    mydata.eval()  
    total_test_loss = 0
    total_accuracy = 0
    with torch.no_grad():  
        for data in test_dataloader:  
            imgs, targets = data
            outputs = mydata(imgs)  
            loss = loss_fn(outputs, targets)
            total_test_loss = total_test_loss + loss.item()  
            accuracy = (outputs.argmax(1) == targets).sum()  
            total_accuracy = total_accuracy + accuracy
 
    print("整体测试集上的Loss: {}".format(total_test_loss))
    print("整体测试集上的正确率: {}".format(total_accuracy / test_data_size))  
    writer.add_scalar("test_loss", total_test_loss, total_test_step)
    writer.add_scalar("test_accuracy", total_accuracy / test_data_size, total_test_step)
    total_test_step = total_test_step + 1
 
    torch.save(mydata, "mydata_{}.pth".format(i)) 
    print("模型已保存")
 
writer.close()
相关推荐
机器之心14 小时前
英伟达发射了首个太空AI服务器,H100已上天
人工智能·openai
西柚小萌新14 小时前
【深入浅出PyTorch】--8.1.PyTorch生态--torchvision
人工智能·pytorch·python
m0_6501082415 小时前
【论文精读】迈向更好的指标:从T2VScore看文本到视频生成的新评测范式
人工智能·论文精读·评估指标·文本到视频生成·t2vscore·tvge数据集·视频质量评估
十子木15 小时前
C++ 类似pytorch的库,工具包,或者机器学习的生态
c++·pytorch·机器学习
算家计算15 小时前
一张白纸,无限画布:SkyReels刚刚重新定义了AI视频创作
人工智能·aigc·资讯
Kandiy1802539818715 小时前
PHY6252国产蓝牙低成本透传芯片BLE5.2智能灯控智能家居
人工智能·物联网·智能家居·射频工程
机器之心15 小时前
字节Seed团队发布循环语言模型Ouro,在预训练阶段直接「思考」,Bengio组参与
人工智能·openai
新智元16 小时前
AI 教父 Hinton 末日警告!你必须失业,AI 万亿泡沫豪赌才能「赢」
人工智能·openai
新智元16 小时前
CUDA 再见了!寒武纪亮出软件全家桶
人工智能·openai
oe101916 小时前
好文与笔记分享 A Survey of Context Engineering for Large Language Models(下)
人工智能·笔记·语言模型·agent