百度飞浆ResNet50大模型微调实现十二种猫图像分类

12种猫分类比赛传送门

要求很简单,给train和test集,训练模型实现图像分类。

这里使用的是残差连接模型,这个平台有预训练好的模型,可以直接拿来主义。

训练十几个迭代,每个批次60左右,准确率达到90%以上

一、导入库,解压文件

python 复制代码
import os
import zipfile
import random
import json
import cv2
import numpy as np
from PIL import Image

import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
import paddle
import paddle.nn as nn
from paddle.io import Dataset,DataLoader
from paddle.nn import \
                    Layer, \
                    Conv2D, Linear, \
                    Embedding, MaxPool2D, \
                    BatchNorm2D, ReLU
                    
import paddle.vision.transforms as transforms
from paddle.vision.models import resnet50
from paddle.metric import Accuracy

train_parameters = {
    "input_size": [3, 224, 224],                     # 输入图片的shape
    "class_dim": 12,                                 # 分类数
    "src_path":"data/data10954/cat_12_train.zip",   # 原始数据集路径
    "src_test_path":"data/data10954/cat_12_test.zip",   # 原始数据集路径
    "target_path":"/home/aistudio/data/dataset",     # 要解压的路径 
    "train_list_path": "./train.txt",                # train_data.txt路径
    "eval_list_path": "./eval.txt",                  # eval_data.txt路径
    "label_dict":{},                                 # 标签字典
    "readme_path": "/home/aistudio/data/readme.json",# readme.json路径
    "num_epochs":6,                                 # 训练轮数
    "train_batch_size": 16,                          # 批次的大小
    "learning_strategy": {                           # 优化函数相关的配置
        "lr": 0.0005                                  # 超参数学习率
    } 
}


scr_path=train_parameters['src_path']
target_path=train_parameters['target_path']
src_test_path=train_parameters["src_test_path"]
z = zipfile.ZipFile(scr_path, 'r')
z.extractall(path=target_path)
z = zipfile.ZipFile(src_test_path, 'r')
z.extractall(path=target_path)
z.close()
for imgpath in os.listdir(target_path + '/cat_12_train'):
    src = os.path.join(target_path + '/cat_12_train/', imgpath)
    img = Image.open(src)
    if img.mode != 'RGB':
        img = img.convert('RGB')
        img.save(src)

for imgpath in os.listdir(target_path + '/cat_12_test'):
    src = os.path.join(target_path + '/cat_12_test/', imgpath)
    img = Image.open(src)
    if img.mode != 'RGB':
        img = img.convert('RGB')
        img.save(src)

解压后将所有图像变为RGB图像

二、加载训练集,进行预处理、数据增强、格式变换

python 复制代码
transform = transforms.Compose([
    transforms.Resize(size=224),
    transforms.ColorJitter(0.2, 0.2, 0.2, 0.2),
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(15),
    transforms.RandomResizedCrop(size=224, scale=(0.8, 1.0)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

x_train,x_eval,y_train=[],[],[]#获取训练图像和标签、测试图像和标签
contents=[]
with open('data/data10954/train_list.txt')as f:
    contents=f.read().split('\n')

for item in contents:
    if item=='':
        continue
    path='data/dataset/'+item.split('\t')[0]
    data=np.array(Image.open(path).convert('RGB'))


    data=np.array(transform(data))
    x_train.append(data)
    y_train.append(int(item.split('\t')[-1]))

contetns=os.listdir('data/dataset/cat_12_test')
for item in contetns:
    path='data/dataset/cat_12_test/'+item
    data=np.array(Image.open(path).convert('RGB'))
    data=np.array(transform(data))

    x_eval.append(data)

重点是transforms变换的预处理

三、划分训练集和测试集

python 复制代码
x_train=np.array(x_train)

y_train=np.array(y_train)

x_eval=np.array(x_eval)



x_train,x_test,y_train,y_test=train_test_split(x_train,y_train,test_size=0.2,random_state=42,stratify=y_train)

x_train=paddle.to_tensor(x_train,dtype='float32')
y_train=paddle.to_tensor(y_train,dtype='int64')
x_test=paddle.to_tensor(x_test,dtype='float32')
y_test=paddle.to_tensor(y_test,dtype='int64')
x_eval=paddle.to_tensor(x_eval,dtype='float32')

这是必要的,可以随时利用测试集查看准确率

四、加载预训练模型,选择损失函数和优化器

python 复制代码
learning_rate=0.001
epochs =5  # 迭代轮数
batch_size = 50  # 批次大小
weight_decay=1e-5
num_class=12

cnn=resnet50(pretrained=True)
checkpoint=paddle.load('checkpoint.pdparams')

for param in cnn.parameters():
    param.requires_grad=False
cnn.fc = nn.Linear(2048, num_class)
cnn.set_dict(checkpoint['cnn_state_dict'])
criterion=nn.CrossEntropyLoss()
optimizer = paddle.optimizer.Adam(learning_rate=learning_rate, parameters=cnn.fc.parameters(),weight_decay=weight_decay)

第一次训练把加载模型注释掉即可,优化器包含最后一层全连接的参数

五、模型训练

python 复制代码
if x_train.shape[3]==3:
    x_train=paddle.transpose(x_train,perm=(0,3,1,2))

dataset = paddle.io.TensorDataset([x_train, y_train])
data_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
for epoch in range(epochs):

    for batch_data, batch_labels in data_loader:
        outputs = cnn(batch_data)
        loss = criterion(outputs, batch_labels)
        print(epoch)
        loss.backward()
        optimizer.step()
        optimizer.clear_grad()

    print(f"Epoch [{epoch+1}/{epochs}], Loss: {loss.numpy()[0]}")#保存参数
paddle.save({
    'cnn_state_dict': cnn.state_dict(),

}, 'checkpoint.pdparams')

使用批处理,这个很重要,不然平台分分钟炸了

六、测试集准确率

python 复制代码
num_class=12
batch_size=64
cnn=resnet50(pretrained=True)
checkpoint=paddle.load('checkpoint.pdparams')

for param in cnn.parameters():
    param.requires_grad=False
cnn.fc = nn.Linear(2048, num_class)
cnn.set_dict(checkpoint['cnn_state_dict'])

cnn.eval()

if x_test.shape[3]==3:
        x_test=paddle.transpose(x_test,perm=(0,3,1,2))
dataset = paddle.io.TensorDataset([x_test, y_test])
data_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

with paddle.no_grad():
    score=0
    for batch_data, batch_labels in data_loader:
    
        predictions = cnn(batch_data)
        predicted_probabilities = paddle.nn.functional.softmax(predictions, axis=1)
        predicted_labels = paddle.argmax(predicted_probabilities, axis=1) 
        print(predicted_labels)
        
        for i in range(len(predicted_labels)):
            if predicted_labels[i].numpy()==batch_labels[i]:
                score+=1
    print(score/len(y_test))

设置eval模式,使用批处理测试准确率

相关推荐
涛涛讲AI1 小时前
扣子平台音频功能:让声音也能“智能”起来
人工智能·音视频·工作流·智能体·ai智能体·ai应用
霍格沃兹测试开发学社测试人社区1 小时前
人工智能在音频、视觉、多模态领域的应用
软件测试·人工智能·测试开发·自动化·音视频
herosunly1 小时前
2024:人工智能大模型的璀璨年代
人工智能·大模型·年度总结·博客之星
PaLu-LI2 小时前
ORB-SLAM2源码学习:Initializer.cc(13): Initializer::ReconstructF用F矩阵恢复R,t及三维点
c++·人工智能·学习·线性代数·ubuntu·计算机视觉·矩阵
呆呆珝2 小时前
RKNN_C++版本-YOLOV5
c++·人工智能·嵌入式硬件·yolo
笔触狂放2 小时前
第一章 语音识别概述
人工智能·python·机器学习·语音识别
ZzYH222 小时前
文献阅读 250125-Accurate predictions on small data with a tabular foundation model
人工智能·笔记·深度学习·机器学习
格林威2 小时前
BroadCom-RDMA博通网卡如何进行驱动安装和设置使得对应网口具有RDMA功能以适配RDMA相机
人工智能·数码相机·opencv·计算机视觉·c#
迪小莫学AI2 小时前
【力扣每日一题】LeetCode 2412: 完成所有交易的初始最少钱数
算法·leetcode·职场和发展
程序员阿龙2 小时前
【精选】基于数据挖掘的招聘信息分析与市场需求预测系统 职位分析、求职者趋势分析 职位匹配、人才趋势、市场需求分析数据挖掘技术 职位需求分析、人才市场趋势预测
人工智能·数据挖掘·数据分析与可视化·数据挖掘技术·人才市场预测·招聘信息分析·在线招聘平台