Python训练第四十三天

DAY 43 复习日

作业:

kaggle找到一个图像数据集,用cnn网络进行训练并且用grad-cam做可视化

进阶:并拆分成多个文件

复制代码
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms, models
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import numpy as np
import torch.nn.functional as F
 
# 设置随机种子确保结果可复现
torch.manual_seed(42)
np.random.seed(42)
 
# 设置中文字体支持
plt.rcParams["font.family"] = ["SimHei"]
plt.rcParams['axes.unicode_minus'] = False  # 解决负号显示问题
 
# 检查GPU是否可用
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"使用设备: {device}")
 
# 1. 数据预处理
# 训练集:使用多种数据增强方法提高模型泛化能力
train_transform = transforms.Compose([
    # 新增:调整图像大小为统一尺寸
    transforms.Resize((32, 32)),  # 确保所有图像都是32x32像素
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
    transforms.RandomRotation(15),
    transforms.ToTensor(),
    transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))])
 
# 测试集:仅进行必要的标准化,保持数据原始特性
test_transform = transforms.Compose([
    # 新增:调整图像大小为统一尺寸
    transforms.Resize((32, 32)),  # 确保所有图像都是32x32像素
    transforms.ToTensor(),
    transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
])
 
# 定义数据集根目录
root = r'C:\Users\vijay\Desktop\1'
 
train_dataset = datasets.ImageFolder(
    root=root + '/train',  # 指向 train 子文件夹
    transform=train_transform
)
test_dataset = datasets.ImageFolder(
    root=root + '/test',  # 指向 test 子文件夹
    transform=test_transform
)
 
# 打印类别信息,确认数据加载正确
print(f"训练集类别: {train_dataset.classes}")
print(f"测试集类别: {test_dataset.classes}")
 
# 3. 创建数据加载器
batch_size = 64
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

@浙大疏锦行

相关推荐
Edward.W13 小时前
Python uv:新一代Python包管理工具,彻底改变开发体验
开发语言·python·uv
小熊officer13 小时前
Python字符串
开发语言·数据库·python
月疯13 小时前
各种信号的模拟(ECG信号、质谱图、EEG信号),方便U-net训练
开发语言·python
荒诞硬汉13 小时前
JavaBean相关补充
java·开发语言
提笔忘字的帝国13 小时前
【教程】macOS 如何完全卸载 Java 开发环境
java·开发语言·macos
flysh0514 小时前
C# 架构设计:接口 vs 抽象类的深度选型指南
开发语言·c#
2501_9418824814 小时前
从灰度发布到流量切分的互联网工程语法控制与多语言实现实践思路随笔分享
java·开发语言
bkspiderx14 小时前
C++中的volatile:从原理到实践的全面解析
开发语言·c++·volatile
小鸡吃米…14 小时前
机器学习中的回归分析
人工智能·python·机器学习·回归
沛沛老爹14 小时前
Java泛型擦除:原理、实践与应对策略
java·开发语言·人工智能·企业开发·发展趋势·技术原理