基于飞浆resnet50的102分类

目录

1.数据预处理

2.数据导入

3.模型导入

4.批训练

[5. 输出结果](#5. 输出结果)

6.结果参考


1.数据预处理

python 复制代码
T=transforms.Compose([
    transforms.Resize((250,250)),
    transforms.RandomCrop(size=224),
    transforms.RandomHorizontalFlip(0.5),
    transforms.RandomRotation(degrees=15),
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.46010968,0.4837371,0.49916607],std=[0.25398722,0.25408414,0.25931123])
])

2.数据导入

python 复制代码
datas=[]
labels=[]
train_path='data/data146107/dataset/train.txt'
eval_path='data/data146107/dataset/test.txt'
base='data/data146107/dataset/images/'
contents=[]
with open(train_path,mode='r',encoding='utf-8') as f:
    contents=f.read().split('\n')
for content in contents:
    if content=='':continue
    img=content.split('\t')[0]
    label=content.split('\t')[1]
    data=np.array(T(cv2.imread(base+img)))
    datas.append(data)
    labels.append(int(label))
datas=np.array(datas)
labels=np.array(labels)

3.模型导入

python 复制代码
model=resnet50(pretrained=True,num_classes=102)
criterion=paddle.nn.CrossEntropyLoss()
optimizer=paddle.optimizer.Adam(learning_rate=0.0001,parameters=model.parameters(),weight_decay=0.001)

4.批训练

python 复制代码
epochs=30
batch_size=125
dataset=TensorDataset([datas,labels])
dataloader=DataLoader(dataset,shuffle=True,batch_size=batch_size)
total_loss=[]
for epoch in range(epochs):
    for batch_data,batch_label in dataloader:
        batch_data=paddle.to_tensor(batch_data,dtype='float32')
        batch_label=paddle.to_tensor(batch_label,dtype='int64')
        output=model(batch_data)
        loss=criterion(output,batch_label)
        print(epoch,loss.numpy()[0])
        total_loss.append(loss.numpy()[0])
        optimizer.clear_grad()
        loss.backward()
        optimizer.step()
paddle.save({'model':model.state_dict(),'optimizer':optimizer.state_dict()},'checkpoint.param')
plt.plot(range(len(total_loss)),total_loss)
plt.show()

5. 输出结果

python 复制代码
contents=[]
batch_size=64
with open('data/data146107/dataset/test.txt',mode='r',encoding='utf-8') as f:
    contents=f.read().split('\n')
evals=[]
imgs=[]
base='data/data146107/dataset/images/'
for content in contents:
    if content=='':continue
    img=content
    data=np.array(T(cv2.imread(base+img)))
    evals.append(data)
    imgs.append(img)
evals=np.array(evals)
imgs=np.array(imgs)
dataset=TensorDataset([evals,imgs])
dataloader=DataLoader(dataset,shuffle=True,batch_size=batch_size)
with open('result.txt',mode='w',encoding='utf-8'):
    pass
with paddle.no_grad():
    for batch_data,batch_img in dataloader:
        batch_data=paddle.to_tensor(batch_data,dtype='float32')
        output=model(batch_data)
        output=np.array(paddle.argmax(output,axis=1))
        with open('result.txt',mode='a',encoding='utf-8') as f:
            for img,ans in zip(batch_img,output):
                f.write(img+'\t'+str(ans)+'\n')

6.结果参考

loss收敛到0.001 ,准确率到达93%左右

相关推荐
泰迪智能科技1 小时前
分享|职业技术培训|数字技术应用工程师快问快答
人工智能
Dxy12393102163 小时前
如何给AI提问:让机器高效理解你的需求
人工智能
少林码僧3 小时前
2.31 机器学习神器项目实战:如何在真实项目中应用XGBoost等算法
人工智能·python·算法·机器学习·ai·数据挖掘
钱彬 (Qian Bin)3 小时前
项目实践15—全球证件智能识别系统(切换为Qwen3-VL-8B-Instruct图文多模态大模型)
人工智能·算法·机器学习·多模态·全球证件识别
没学上了3 小时前
CNNMNIST
人工智能·深度学习
宝贝儿好3 小时前
【强化学习】第六章:无模型控制:在轨MC控制、在轨时序差分学习(Sarsa)、离轨学习(Q-learning)
人工智能·python·深度学习·学习·机器学习·机器人
智驱力人工智能4 小时前
守护流动的规则 基于视觉分析的穿越导流线区检测技术工程实践 交通路口导流区穿越实时预警技术 智慧交通部署指南
人工智能·opencv·安全·目标检测·计算机视觉·cnn·边缘计算
AI产品备案4 小时前
生成式人工智能大模型备案制度与发展要求
人工智能·深度学习·大模型备案·算法备案·大模型登记
AC赳赳老秦4 小时前
DeepSeek 私有化部署避坑指南:敏感数据本地化处理与合规性检测详解
大数据·开发语言·数据库·人工智能·自动化·php·deepseek
wm10434 小时前
机器学习之线性回归
人工智能·机器学习·线性回归