Python训练第四十三天

DAY 43 复习日

作业:

kaggle找到一个图像数据集,用cnn网络进行训练并且用grad-cam做可视化

进阶:并拆分成多个文件

复制代码
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms, models
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import numpy as np
import torch.nn.functional as F
 
# 设置随机种子确保结果可复现
torch.manual_seed(42)
np.random.seed(42)
 
# 设置中文字体支持
plt.rcParams["font.family"] = ["SimHei"]
plt.rcParams['axes.unicode_minus'] = False  # 解决负号显示问题
 
# 检查GPU是否可用
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"使用设备: {device}")
 
# 1. 数据预处理
# 训练集:使用多种数据增强方法提高模型泛化能力
train_transform = transforms.Compose([
    # 新增:调整图像大小为统一尺寸
    transforms.Resize((32, 32)),  # 确保所有图像都是32x32像素
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
    transforms.RandomRotation(15),
    transforms.ToTensor(),
    transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))])
 
# 测试集:仅进行必要的标准化,保持数据原始特性
test_transform = transforms.Compose([
    # 新增:调整图像大小为统一尺寸
    transforms.Resize((32, 32)),  # 确保所有图像都是32x32像素
    transforms.ToTensor(),
    transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
])
 
# 定义数据集根目录
root = r'C:\Users\vijay\Desktop\1'
 
train_dataset = datasets.ImageFolder(
    root=root + '/train',  # 指向 train 子文件夹
    transform=train_transform
)
test_dataset = datasets.ImageFolder(
    root=root + '/test',  # 指向 test 子文件夹
    transform=test_transform
)
 
# 打印类别信息,确认数据加载正确
print(f"训练集类别: {train_dataset.classes}")
print(f"测试集类别: {test_dataset.classes}")
 
# 3. 创建数据加载器
batch_size = 64
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

@浙大疏锦行

相关推荐
jieyu111910 分钟前
Python 实战:Web 漏洞 Python POC 代码及原理详解(1)
python·web安全
AnalogElectronic33 分钟前
vue3 实现贪吃蛇手机版01
开发语言·javascript·ecmascript
QQLOVEYY39 分钟前
Python和PyCharm的安装教程
python·pycharm
Momentary_SixthSense1 小时前
rust笔记
开发语言·笔记·rust
多多*1 小时前
Spring Bean的生命周期 第二次思考
java·开发语言·rpc
大飞pkz1 小时前
【算法】排序算法汇总1
开发语言·数据结构·算法·c#·排序算法
想名字好难啊竟然不止我一个1 小时前
清除 Pip 缓存, 释放磁盘空间
python·缓存·pip
Eiceblue1 小时前
Python 快速提取扫描件 PDF 中的文本:OCR 实操教程
vscode·python·ocr·1024程序员节
APIshop1 小时前
淘宝/天猫 API 接口深度解析:商品详情获取与按图搜索商品(拍立淘)实战指南
python·1024程序员节
WangYan20222 小时前
ArcGIS Pro与Python下空间数据采集与管理——涵盖矢量、栅格、GPS、点云、多维数据与遥感云平台等
python·arcgis pro·空间数据采集与管理