4.3.2 模型构建
上一节定义好已经了解了ResNet模型结构,本节直接使用飞桨高层API中的Resnet50进行图像分类实验。
In [7]
from paddle.vision.models import resnet50
model = resnet50()
W0714 20:32:55.131150 102 device_context.cc:447] Please NOTE: device: 0, GPU Compute Capability: 7.0, Driver API Version: 11.2, Runtime API Version: 10.1
W0714 20:32:55.136173 102 device_context.cc:465] device: 0, cuDNN Version: 7.6.
4.3.3 损失函数
飞桨高层API中都为大家提供了实现好交叉熵损失函数,代码如下所示。
In [8]
import paddle.nn.functional as F
loss_fn = F.cross_entropy
4.3.4 模型训练
使用交叉熵损失函数,并用SGD作为优化器来训练ResNet网络。
In [9]
# -*- coding: utf-8 -*-
# LeNet 识别眼疾图片
import os
import random
import paddle
import numpy as np
class Runner(object):
def __init__(self, model, optimizer, loss_fn):
self.model = model
self.optimizer = optimizer
self.loss_fn = loss_fn
# 记录全局最优指标
self.best_acc = 0
# 定义训练过程
def train_pm(self, train_datadir, val_datadir, **kwargs):
print('start training ... ')
self.model.train()
num_epochs = kwargs.get('num_epochs', 0)
csv_file = kwargs.get('csv_file', 0)
save_path = kwargs.get("save_path", "/home/aistudio/output/")
# 定义数据读取器,训练数据读取器
train_loader = data_loader(train_datadir, batch_size=10, mode='train')
for epoch in range(num_epochs):
for batch_id, data in enumerate(train_loader()):
x_data, y_data = data
img = paddle.to_tensor(x_data)
label = paddle.to_tensor(y_data)
# 运行模型前向计算,得到预测值
logits = model(img)
avg_loss = self.loss_fn(logits, label)
if batch_id % 20 == 0:
print("epoch: {}, batch_id: {}, loss is: {:.4f}".format(epoch, batch_id, float(avg_loss.numpy())))
# 反向传播,更新权重,清除梯度
avg_loss.backward()
self.optimizer.step()
self.optimizer.clear_grad()
acc = self.evaluate_pm(val_datadir, csv_file)
self.model.train()
if acc > self.best_acc:
self.save_model(save_path)
self.best_acc = acc
# 模型评估阶段,使用'paddle.no_grad()'控制不计算和存储梯度
@paddle.no_grad()
def evaluate_pm(self, val_datadir, csv_file):
self.model.eval()
accuracies = []
losses = []
# 验证数据读取器
valid_loader = valid_data_loader(val_datadir, csv_file)
for batch_id, data in enumerate(valid_loader()):
x_data, y_data = data
img = paddle.to_tensor(x_data)
label = paddle.to_tensor(y_data)
# 运行模型前向计算,得到预测值
logits = self.model(img)
# 多分类,使用softmax计算预测概率
pred = F.softmax(logits)
loss = self.loss_fn(pred, label)
acc = paddle.metric.accuracy(pred, label)
accuracies.append(acc.numpy())
losses.append(loss.numpy())
print("[validation] accuracy/loss: {:.4f}/{:.4f}".format(np.mean(accuracies), np.mean(losses)))
return np.mean(accuracies)
# 模型评估阶段,使用'paddle.no_grad()'控制不计算和存储梯度
@paddle.no_grad()
def predict_pm(self, x, **kwargs):
# 将模型设置为评估模式
self.model.eval()
# 运行模型前向计算,得到预测值
logits = self.model(x)
return logits
def save_model(self, save_path):
paddle.save(self.model.state_dict(), save_path + 'palm.pdparams')
paddle.save(self.optimizer.state_dict(), save_path + 'palm.pdopt')
def load_model(self, model_path):
model_state_dict = paddle.load(model_path)
self.model.set_state_dict(model_state_dict)
实例化Runner类,并传入训练配置,代码实现如下:
In [12]
# 开启0号GPU训练
use_gpu = True
paddle.device.set_device('gpu:0') if use_gpu else paddle.device.set_device('cpu')
# 定义优化器
# opt = paddle.optimizer.Momentum(learning_rate=0.001, momentum=0.9, parameters=model.parameters(), weight_decay=0.001)
opt = paddle.optimizer.SGD(learning_rate=0.001, parameters=model.parameters())
runner = Runner(model, opt, loss_fn)
用Runner在训练集上训练5个epoch,并保存准确率最高的模型作为最佳模型。
In [13]
import os
# 数据集路径
DATADIR = '/home/aistudio/work/palm/PALM-Training400/PALM-Training400'
DATADIR2 = '/home/aistudio/work/palm/PALM-Validation400'
CSVFILE = '/home/aistudio/labels.csv'
# 设置迭代轮数
EPOCH_NUM = 5
# 模型保存路径
PATH='/home/aistudio/output/'
if not os.path.exists(PATH):
os.makedirs(PATH)
# 启动训练过程
runner.train_pm(DATADIR, DATADIR2,
num_epochs=EPOCH_NUM, csv_file=CSVFILE, save_path=PATH)
start training ...
epoch: 0, batch_id: 0, loss is: 0.3287
epoch: 0, batch_id: 20, loss is: 0.0716
[validation] accuracy/loss: 0.9625/5.9818
/opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/paddle/framework/io.py:729: UserWarning: The input state dict is empty, no need to save.
warnings.warn("The input state dict is empty, no need to save.")
epoch: 1, batch_id: 0, loss is: 0.1286
epoch: 1, batch_id: 20, loss is: 0.4718
[validation] accuracy/loss: 0.9650/5.9824
epoch: 2, batch_id: 0, loss is: 0.0892
epoch: 2, batch_id: 20, loss is: 0.0313
[validation] accuracy/loss: 0.9625/5.9801
epoch: 3, batch_id: 0, loss is: 0.1362
epoch: 3, batch_id: 20, loss is: 0.0569
[validation] accuracy/loss: 0.9625/5.9746
epoch: 4, batch_id: 0, loss is: 0.1036
epoch: 4, batch_id: 20, loss is: 0.0873
[validation] accuracy/loss: 0.9575/5.9856
通过运行结果可以发现,使用ResNet在眼疾筛查数据集iChallenge-PM上,经过5个epoch的训练,在验证集上的准确率可以达到96%左右。
4.3.5 模型评估
使用测试数据对在训练过程中保存的最佳模型进行评价,观察模型在评估集上的准确率。代码实现如下:
In [ ]
# 加载最优模型
runner.load_model('/home/aistudio/output/palm.pdparams')
# 模型评价
score = runner.evaluate_pm(DATADIR2, CSVFILE)
[validation] accuracy/loss: 0.9725/5.9591
4.3.6 模型预测
同样地,也可以使用保存好的模型,对测试集中的某一个数据进行模型预测,观察模型效果。代码实现如下:
In [18]
import cv2
from PIL import Image
import matplotlib.pyplot as plt
import paddle
import paddle.nn.functional as F
%matplotlib inline
# 加载最优模型
runner.load_model('/home/aistudio/output/palm.pdparams')
# 获取测试集中第一条数据
DATADIRv2 = '/home/aistudio/work/palm/PALM-Validation400'
filelists = open('/home/aistudio/labels.csv').readlines()
# 可以通过修改filelists列表的数字获取其他测试图片,可取值1-400
line = filelists[1].strip().split(',')
name, label = line[1], int(line[2])
# 读取测试图片
img = cv2.imread(os.path.join(DATADIRv2, name))
# 测试图片预处理
trans_img = transform_img(img)
unsqueeze_img = paddle.unsqueeze(paddle.to_tensor(trans_img), axis=0)
# 模型预测
logits = runner.predict_pm(unsqueeze_img)
result=F.softmax(logits)
pred_class = paddle.argmax(result).numpy()
# 输出真实类别与预测类别
print("The true category is {} and the predicted category is {}".format(label, pred_class))
# 图片可视化
show_img = Image.fromarray(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
plt.imshow(show_img)
plt.show()
The true category is 0 and the predicted category is [0]
<Figure size 432x288 with 1 Axes>
小结
在这一节里,我们通过ResNet模型实现眼疾识别,在验证集上的预测精度在95%左右,通过这个案例熟悉了基础的视觉任务构建流程。如果读者有兴趣的话,可以进一步调整学习率和训练轮数等超参数,观察是否能够得到更高的精度。
作业
本节通过调用飞桨高层API Resnet50模型from paddle.vision.models import resnet50
实现了眼疾识别。更换其他模型,看是否能得到更高的精度,