CIFAR-10分类数据集预测模型搭建

最近一个偶然的机会,老师盯上了CIFAR-10数据集,既然如此,就借此锻炼一下自己代码能力吧

数据集介绍:

数据集组成:

CIFAR-10数据集由10个类别的 60000 张 32x32 彩色图像组成,每个类别 6000 张图像。有50000张训练图像和10000张测试图像。通过下载链接(www.cs.toronto.edu/~kriz/cifar...www.cs.toronto.edu/~kriz/cifar...)下载并解压可以看到如图1所示的文件目录。data_batch_1-5和test_batch分别代表了有五个训练批次和一个测试批次,每个批次有1万张图像,测试批次包含从每个类别中随机选择的 1000 张图像。训练批次以随机顺序包含剩余图像,但某些训练批次可能包含来自一个类别的图像多于另一类别的图像。其中,训练批次恰好包含每个类别的 5000 张图像。

图1 CIFAR-10文件目录图

数据集布局:

data:一个 10000x3072 的 uint8s numpy 数组。数组的每一行存储一个 32x32 彩色图像。前 1024 个条目包含红色通道值,接下来的 1024 个条目包含绿色通道值,最后 1024 个条目包含蓝色通道值。图像按行优先顺序存储,因此数组的前 32 个条目是图像第一行的红色通道值。

labels:0-9 范围内 10000 个数字的列表。索引 i 处的数字表示数组数据中第 i 个图像的标签。

batches.meta: 包含一个 10 元素列表,为上述标签数组中的数字标签提供有意义的名称。例如,label_names[0] =="飞机",label_names[1] =="汽车"等。

注:使用贝叶斯超参数优化能有效地找到参数的良好设置。并且,上面的介绍全都没用,直接调用dataloader便能满足你数据读取的需求!

模型搭建:

step1 数据读取:

ini 复制代码
transform_train = T.Compose([
    T.RandomHorizontalFlip(),
    T.RandomGrayscale(),
    T.ToTensor(),
    T.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
transform_test = T.Compose([     
    T.ToTensor(),
    T.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
​
# 下载并加载CIFAR-10训练集
traindata = torchvision.datasets.CIFAR10(root='./data', train=True,
                                        download=True, transform=transform_train)
trainloader = torch.utils.data.DataLoader(traindata, batch_size=64,
                                          shuffle=True, num_workers=0)#num_workers告诉DataLoader实例要使用多少个子进程进行数据加载(和CPU有关,和GPU无关)
​
# 下载并加载CIFAR-10测试集
testset = torchvision.datasets.CIFAR10(root='./data', train=False,
                                       download=True, transform=transform_test)
testloader = torch.utils.data.DataLoader(testset, batch_size=64,
                                         shuffle=False, num_workers=0)
​
classes = ('plane', 'car', 'bird', 'cat',
           'deer', 'dog', 'frog', 'horse', 'ship', 'truck')
  1. 定义训练转换(transform_train)

    • T.RandomHorizontalFlip():随机水平翻转图像。
    • T.RandomGrayscale():随机将图像转换为灰度图。
    • T.ToTensor():将PIL图像或Numpy数组转换为torch.FloatTensor类型,并除以255.0来将数值缩放到[0, 1]区间。
    • T.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)):标准化图像,使每个通道的均值为0.5,标准差为0.5。
  2. 定义测试转换(transform_test)

  3. 加载CIFAR-10训练集

    • 使用torchvision.datasets.CIFAR10函数下载并加载CIFAR-10训练集。
    • root='./data':指定数据存储的根目录。
    • train=True:指定加载训练集。
    • download=True:如果数据集不在指定的root目录中,将自动下载。
    • transform=transform_train:应用训练时定义的转换。
  4. 创建训练数据加载器(trainloader)

    • torch.utils.data.DataLoader:创建一个数据加载器,用于批量加载数据。
    • batch_size=64:每个批次的样本数量。
    • shuffle=True:在每个epoch开始时是否打乱数据。
    • num_workers=0:加载数据时使用的子进程数量,0表示不使用多进程加载。
  5. 加载CIFAR-10测试集

    • 与加载训练集类似,但使用train=False来加载测试集,并应用transform_test转换。
  6. 创建测试数据加载器(testloader)

    • 参数设置与训练数据加载器类似,shuffle=False表示测试集不需要打乱。
  7. 定义类别(classes)

    • 定义CIFAR-10数据集中的10个类别名称,这些名称将用于后续的分类任务。

step2 为了训练数据集准备模型:

ini 复制代码
device = torch.device(
    'cuda') if torch.cuda.is_available() else torch.device('cpu')
torch.cuda.empty_cache()#没有gpu注释掉,为了清除cuda缓存使用
​
model = torchvision.models.resnet50()
pretrained_dict = model.state_dict()
num_classes = len(classes)
​
# 定义优化器
# optimizer = torch.optim.SGD(model.parameters(), lr=0.005, momentum=0.9)
optimizer = torch.optim.Adam(model.parameters(), 0.01)
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=4, gamma=0.85)  # 学习率调度器
criterion = nn.CrossEntropyLoss().cuda()
​
model.to(device)
​
epochs = 1
​
  1. 设置设备

    • device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu'):这一行检查是否有可用的GPU,如果有,则使用CUDA,否则回退到使用CPU。
  2. 清空CUDA缓存

    • torch.cuda.empty_cache():清空CUDA缓存,释放未使用的内存。
  3. 加载预训练模型

    • model = torchvision.models.resnet50():加载预训练的ResNet-50模型。
  4. 获取模型参数

    • pretrained_dict = model.state_dict():获取模型的参数字典。
  5. 设置类别数

    • num_classes = len(classes):根据之前定义的类别列表classes的长度设置类别的数量。
  6. 定义优化器

    • optimizer = torch.optim.Adam(model.parameters(), 0.01):使用Adam优化器,学习率设置为0.01。
  7. 设置学习率调度器

    • scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=4, gamma=0.85):定义了一个学习率调度器,每当经过4个epoch,学习率会乘以0.85。
  8. 设置损失函数并移动到设备

    • criterion = nn.CrossEntropyLoss().cuda():使用交叉熵损失函数,并将其移动到GPU上(如果可用)。
  9. 将模型移动到设备

    • model.to(device):将模型移动到之前定义的设备上。
  10. 设置训练周期

    • epochs = 1:设置训练周期为1个epoch。

step3 数据集预测并把结果写入到.csv文件中:

ini 复制代码
import csv
correct,total = 0,0
predicted_list =[]
for j,data in enumerate(testloader):
    inputs,labels = data
    inputs,labels = inputs.to(device),labels.to(device)
    #前向传播
    outputs = model(inputs)
    _, predicted = torch.max(outputs.data,1)    
    total =total+labels.size(0)
    #print(predicted == labels)
    correct = correct +(predicted == labels).sum().item()
    predicted_list += predicted.cpu().tolist()
with open('output/data.csv', 'w', newline='') as file:
    writer = csv.writer(file)
    writer.writerow(['Predicted'])  # 写入表头
    for prediction in predicted_list:
        writer.writerow([prediction])  # 写入预测结果 
print('准确率:{:.4f}%'.format(100.0*correct/total))
  1. 初始化计数器和列表

    • correct, total = 0, 0:初始化正确预测数和总预测数为0。
    • predicted_list = []:初始化一个空列表,用于存储预测结果。
  2. 遍历测试数据加载器

    • for j, data in enumerate(testloader):遍历测试数据加载器testloader中的每个数据批次。
  3. 准备数据和标签

    • inputs, labels = data:获取批次中的数据和标签。
    • inputs, labels = inputs.to(device), labels.to(device):将数据和标签移动到之前定义的设备上。
  4. 前向传播

    • outputs = model(inputs):通过模型进行前向传播,获取模型的输出。
  5. 获取预测结果

    • _, predicted = torch.max(outputs.data, 1):从模型输出中获取预测结果,predicted是一个包含最大概率索引的张量。

到此,你就搭建好了一个简单的能分类CIFAR-10的模型,有木有一丝丝成就感呢

总结一下,这段代码及数据集很大一部分是为了部分同学因电脑性能较差而无法运行大量数据,转而使用小模型、小数据集学习使用。但是麻雀虽小,五脏俱全,你仍然可以从这学到很多关于模型训练的知识,在我的角度看,是比按教程学习如何使用别人现有的项目要好得多。另外,别忘了把数据做贝叶斯超参数优化看看能提升多少哦。

相关推荐
limingade2 小时前
手机实时提取SIM卡打电话的信令和声音-新的篇章(一、可行的方案探讨)
物联网·算法·智能手机·数据分析·信息与通信
AI大模型知识分享2 小时前
Prompt最佳实践|如何用参考文本让ChatGPT答案更精准?
人工智能·深度学习·机器学习·chatgpt·prompt·gpt-3
小言从不摸鱼4 小时前
【AI大模型】ChatGPT模型原理介绍(下)
人工智能·python·深度学习·机器学习·自然语言处理·chatgpt
jiao000015 小时前
数据结构——队列
c语言·数据结构·算法
迷迭所归处6 小时前
C++ —— 关于vector
开发语言·c++·算法
leon6256 小时前
优化算法(一)—遗传算法(Genetic Algorithm)附MATLAB程序
开发语言·算法·matlab
CV工程师小林6 小时前
【算法】BFS 系列之边权为 1 的最短路问题
数据结构·c++·算法·leetcode·宽度优先
Navigator_Z7 小时前
数据结构C //线性表(链表)ADT结构及相关函数
c语言·数据结构·算法·链表
Aic山鱼7 小时前
【如何高效学习数据结构:构建编程的坚实基石】
数据结构·学习·算法
天玑y7 小时前
算法设计与分析(背包问题
c++·经验分享·笔记·学习·算法·leetcode·蓝桥杯