python第51天

1.读取数据

使用CIFAR-10图像数据

python 复制代码
import torch
from torchvision import datasets, transforms

# 定义图像预处理流程
image_transform = transforms.Compose([
    transforms.ToTensor(),  # 将PIL图像转换为张量
    transforms.Normalize(mean=(0.5, 0.5, 0.5),  # RGB三通道均值
                         std=(0.5, 0.5, 0.5))   # RGB三通道标准差
])

# 获取训练数据集
trainset = datasets.CIFAR10(
    root='./data',  # 数据集存储路径
    train=True,     # 使用训练集
    transform=image_transform,
    download=True   # 如果本地不存在则下载
)

# 获取测试数据集
testset = datasets.CIFAR10(
    root='./data',
    train=False,    # 使用测试集
    transform=image_transform,
    download=True
)

# 配置数据加载器
train_loader = torch.utils.data.DataLoader(
    dataset=trainset,
    batch_size=128,    # 每批样本数量
    shuffle=True       # 训练时打乱顺序
)

test_loader = torch.utils.data.DataLoader(
    dataset=testset,
    batch_size=128,
    shuffle=False      # 测试时保持原始顺序
)

2.模型建立

(1)建立CNN模型

python 复制代码
import torch
import torch.nn as nn
 
class SimpleCNN(nn.Module):
    def __init__(self):
        super(SimpleCNN, self).__init__()
        self.conv1 = nn.Conv2d(3, 16, 3, padding=1)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(16, 32, 3, padding=1)
        self.fc1 = nn.Linear(32 * 8 * 8, 256)
        self.fc2 = nn.Linear(256, 10)
        self.relu = nn.ReLU()
 
    def forward(self, x):
        x = self.pool(self.relu(self.conv1(x)))  # 16x16x16
        x = self.pool(self.relu(self.conv2(x)))  # 32x8x8
        x = x.view(-1, 32 * 8 * 8)
        x = self.relu(self.fc1(x))
        x = self.fc2(x)
        return x
 

@浙大疏锦行

相关推荐
kngines11 分钟前
【Node.js从 0 到 1:入门实战与项目驱动】1.1 什么是 Node.js?(定义、运行环境、与浏览器 JavaScript 的区别)
开发语言·javascript·node.js
java1234_小锋12 分钟前
【NLP舆情分析】基于python微博舆情分析可视化系统(flask+pandas+echarts) 视频教程 - 词云图-微博评论词云图实现
python·自然语言处理·flask·nlp·nlp舆情分析
codists21 分钟前
《AI-Assisted Programming》读后感
python
爱欲无极39 分钟前
基于Flask的微博话题多标签情感分析系统设计
后端·python·flask
F_D_Z1 小时前
conda issue
python·github·conda·issue
大阳1231 小时前
数据结构2.(双向链表,循环链表及内核链表)
c语言·开发语言·数据结构·学习·算法·链表·嵌入式
Wangsk1331 小时前
用 Python 批量处理 Excel:从重复值清洗到数据可视化
python·信息可视化·excel·pandas
ChipCamp1 小时前
Chisel芯片开发入门系列 -- 18. CPU芯片开发和解释8(流水线架构的代码级理解)
开发语言·青少年编程·fpga开发·scala·dsp开发·risc-v·chisel
越来越无动于衷2 小时前
智慧社区(八)——社区人脸识别出入管理系统设计与实现
java·开发语言·spring boot·python·mysql
正义的大古2 小时前
OpenLayers 详细开发指南 - 第八部分 - GeoJSON 转换与处理工具
开发语言·前端·javascript