神经网络——CIFAR10小实战

1.引子


Sequential的使用:将网络结构放入其中即可,可以简化代码。

找了一个对CIFAR10进行分类的模型。

2.代码实战

python 复制代码
from torch import nn
from torch.nn import Conv2d, MaxPool2d, Flatten, Linear

class Tudui(nn.Module):
    def __init__(self):
        super(Tudui, self).__init__()
        self.conv1 = Conv2d(3, 32, 5, padding=2)
        self.maxpool1 = MaxPool2d(2)
        self.conv2 = Conv2d(32, 32, 5, padding=2)
        self.maxpool2 = MaxPool2d(2)
        self.conv3 = Conv2d(32, 64, 5, padding=2)
        self.maxpool3 = MaxPool2d(2)
        self.flatten = Flatten()
        self.linear1 = Linear(1024, 64)
        self.linear2 = Linear(64, 10)

    def forward(self, x):
        x = self.conv1(x)
        x = self.maxpool1(x)
        x = self.conv2(x)
        x = self.maxpool2(x)
        x = self.conv3(x)
        x = self.maxpool3(x)
        x = self.flatten(x)
        x = self.linear1(x)
        x = self.linear2(x)
        return x

tudui=Tudui()
print(tudui)

nn.Flatten()和torch.flatten()有相同的效果。

3.Sequential

python 复制代码
import torch
from torch import nn
from torch.nn import Conv2d, MaxPool2d, Flatten, Linear, Sequential


class Tudui(nn.Module):
    def __init__(self):
        super(Tudui, self).__init__()
        self.model1=Sequential(
            Conv2d(3,32,5,padding=2),
            MaxPool2d(2),
            Conv2d(32,32,5,padding=2),
            MaxPool2d(2),
            Conv2d(32,64,5,padding=2),
            MaxPool2d(2),
            Flatten(),
            Linear(1024,64),
            Linear(64,10)
        )

    def forward(self, x):
        x=self.model1(x)
        return x

tudui=Tudui()
print(tudui)
## 创建一个指定形状的 ones 张量
input=torch.ones((64,3,32,32))
output=tudui(input)
print(output.shape)

使用Sequential可以很大程度地简化代码。

4.利用TensorBoard进行数据可视化

使用SummaryWriter的add_graph()方法进行数据可视化。

python 复制代码
writer=SummaryWriter("logs_sqe")
writer.add_graph(tudui,input)
writer.close()

基本的网络搭建到此结束。

相关推荐
woshihonghonga6 分钟前
Deepseek在它擅长的AI数据处理领域还有是有低级错误【k折交叉验证中每折样本数计算】
人工智能·python·深度学习·机器学习
乌恩大侠8 分钟前
以 NVIDIA Sionna Research Kit 赋能 AI 原生 6G 科研
人工智能·usrp
三掌柜66621 分钟前
借助 Kiro:实现《晚间手机免打扰》应用,破解深夜刷屏困境
人工智能·aws
飞雁科技21 分钟前
CRM客户管理系统定制开发:如何精准满足企业需求并提升效率?
大数据·运维·人工智能·devops·驻场开发
飞雁科技24 分钟前
上位机软件定制开发技巧:如何打造专属工业解决方案?
大数据·人工智能·软件开发·devops·驻场开发
这张生成的图像能检测吗35 分钟前
SAMWISE:为文本驱动的视频分割注入SAM2的智慧
人工智能·图像分割·视频·时序
哥布林学者42 分钟前
吴恩达深度学习课程二: 改善深层神经网络 第一周:深度学习的实践 课后作业和代码实践
深度学习·ai
antonytyler1 小时前
机器学习实践项目(二)- 房价预测增强篇 - 特征工程一
人工智能·机器学习
N 年 后1 小时前
cursor和传统idea的区别是什么?
java·人工智能·intellij-idea
AI Echoes1 小时前
LangChain 使用语义路由选择不同的Prompt模板
人工智能·python·langchain·prompt·agent