Python训练第四十三天

DAY 43 复习日

作业:

kaggle找到一个图像数据集,用cnn网络进行训练并且用grad-cam做可视化

进阶:并拆分成多个文件

复制代码
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms, models
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import numpy as np
import torch.nn.functional as F
 
# 设置随机种子确保结果可复现
torch.manual_seed(42)
np.random.seed(42)
 
# 设置中文字体支持
plt.rcParams["font.family"] = ["SimHei"]
plt.rcParams['axes.unicode_minus'] = False  # 解决负号显示问题
 
# 检查GPU是否可用
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"使用设备: {device}")
 
# 1. 数据预处理
# 训练集:使用多种数据增强方法提高模型泛化能力
train_transform = transforms.Compose([
    # 新增:调整图像大小为统一尺寸
    transforms.Resize((32, 32)),  # 确保所有图像都是32x32像素
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
    transforms.RandomRotation(15),
    transforms.ToTensor(),
    transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))])
 
# 测试集:仅进行必要的标准化,保持数据原始特性
test_transform = transforms.Compose([
    # 新增:调整图像大小为统一尺寸
    transforms.Resize((32, 32)),  # 确保所有图像都是32x32像素
    transforms.ToTensor(),
    transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
])
 
# 定义数据集根目录
root = r'C:\Users\vijay\Desktop\1'
 
train_dataset = datasets.ImageFolder(
    root=root + '/train',  # 指向 train 子文件夹
    transform=train_transform
)
test_dataset = datasets.ImageFolder(
    root=root + '/test',  # 指向 test 子文件夹
    transform=test_transform
)
 
# 打印类别信息,确认数据加载正确
print(f"训练集类别: {train_dataset.classes}")
print(f"测试集类别: {test_dataset.classes}")
 
# 3. 创建数据加载器
batch_size = 64
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

@浙大疏锦行

相关推荐
^Rocky1 分钟前
JavaScript性能优化实战
开发语言·javascript·性能优化
心本无晴.13 分钟前
Python进程,线程
python·进程
ponnylv19 分钟前
深入剖析Spring Boot启动流程
java·开发语言·spring boot·spring
萧邀人25 分钟前
第一课、Cocos Creator 3.8 安装与配置
开发语言
Jayden_Ruan1 小时前
C++逆向输出一个字符串(三)
开发语言·c++·算法
不吃鱼的羊1 小时前
启动文件Startup_vle.c
c语言·开发语言
VBA63372 小时前
VBA之Word应用第四章第二节:段落集合Paragraphs对象(二)
开发语言
点云SLAM2 小时前
C++ 常见面试题汇总
java·开发语言·c++·算法·面试·内存管理
java1234_小锋3 小时前
Scikit-learn Python机器学习 - 特征降维 压缩数据 - 特征提取 - 主成分分析 (PCA)
python·机器学习·scikit-learn
xiaowu0803 小时前
策略模式-不同的鸭子的案例
开发语言·c#·策略模式