LeNet(pytorch实现

LeNet

本文编写了一个简单易懂的LeNet网络,并在F-MNIST数据集上进行测试,允许使用GPU计算

python 复制代码
在这里插入代码片
import torch
from torch import nn, optim 
import d2lzh_pytorch as d2l

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# f-mnist 数据集是28*28的
class LeNet(nn.Module):
    def __init__(self):
        super(LeNet, self).__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(1, 6, 5),  # 输出通道,输出通道,核大小
            nn.Sigmoid(),
            nn.MaxPool2d(2, 2),  # 高宽减半
            nn.Conv2d(6, 16, 5),
            nn.Sigmoid(),
            nn.MaxPool2d(2, 2)
        )

        self.fc = nn.Sequential(
            d2l.FlattenLayer(),
            nn.Linear(16*4*4, 120),
            nn.Sigmoid(),
            nn.Linear(120, 84),
            nn.Sigmoid(),
            nn.Linear(84, 10)
        )

    def forward(self, img):
        feature = self.conv(img)
        output = self.fc(feature)
        return output
net = LeNet()

# 数据集
batch_size = 256
train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size)

# 评估测试集,支持GPU
def evaluate_acc(data_iter, net, device = None):
    if device is None and isinstance(net, nn.Module):
        device = list(net.parameters())[0].device  # 看参数的gpu还是cpu
    acc_sum, n = 0.0, 0
    with torch.no_grad():
        for X,y in data_iter:
            if isinstance(net, nn.Module):  # 这个可加可不加
                net.eval()  # 评估模式
                acc_sum += (net(X.to(device)).argmax(dim=1) == y.to(device)).float().sum().cpu().item()
                net.train()  # 转回训练模式
                n += y.shape[0]
                
    return acc_sum / n

def train(net, train_iter, test_iter, optimizer, device, epochs):
    net = net.to(device)
    loss = nn.CrossEntropyLoss()
    for epoch in range(epochs):
        train_l_sum, train_acc_sum, n = 0.0, 0.0, 0
        for X,y in train_iter:
            X = X.to(device)
            y = y.to(device)
            y_hat = net(X)
            l = loss(y_hat, y)
            optimizer.zero_grad()
            l.backward()
            optimizer.step()
            train_l_sum += l.cpu().item()
            train_acc_sum += (y_hat.argmax(dim=1) == y).sum().cpu().item()
            n += y.shape[0]
        test_acc = evaluate_acc(test_iter, net)
        print('epoch %d, loss %.4f, train_acc %.4f, test_acc %.4f'%(epoch + 1, train_l_sum, train_acc_sum/n, test_acc))

lr, epochs = 0.001, 5
optimizer = torch.optim.Adam(net.parameters(), lr=lr)
train(net, train_iter, test_iter, optimizer, device, epochs)
相关推荐
VIP_CQCRE1 分钟前
Nano Banana API 来了:不到半价享官方同款品质,仅需约 ¥0.10/张!
人工智能
珠海西格7 分钟前
光伏电站全景感知体系:数据采集与设备状态监测技术
大数据·运维·服务器·数据库·人工智能
产品经理邹继强8 分钟前
VTC产品与创新篇④:产品战略全景图——从“造物者”到“生态设计师”
人工智能·产品经理
Deepoch9 分钟前
自然交互+精准感知!Deepoc具身模型开发板让清洁机器人告别“盲扫”
人工智能·科技·机器人·半导体·清洁机器人·具身模型·deepoc
yuezhilangniao11 分钟前
从对话大脑到万能助手:企业级AI助理五层AI架构实战指南-AI开发架构AI体系理性分层篇
人工智能·架构
玄同76522 分钟前
LangChain 1.0 模型接口:多厂商集成与统一调用
开发语言·人工智能·python·langchain·知识图谱·rag·智能体
acai_polo22 分钟前
如何在国内合规、稳定地使用GPT/Claude/Gemini API?中转服务全解析
人工智能·gpt·ai·语言模型·ai作画
北京青翼科技26 分钟前
【PCIe732】青翼PCIe采集卡-优质光纤卡- PCIe接口-万兆光纤卡
图像处理·人工智能·fpga开发·智能硬件·嵌入式实时数据库
喵手36 分钟前
Python爬虫实战:构建招聘会数据采集系统 - requests+lxml 实战企业名单爬取与智能分析!
爬虫·python·爬虫实战·requests·lxml·零基础python爬虫教学·招聘会数据采集
星幻元宇VR41 分钟前
5D动感影院,科技与沉浸式体验的完美融合
人工智能·科技·虚拟现实