学习基于pytorch的VGG图像分类 day4

注:本系列博客在于汇总CSDN的精华帖,类似自用笔记,不做学习交流,方便以后的复习回顾,博文中的引用都注明出处,并点赞收藏原博主.

目录

VGG模型检测

一:导入必要的库和模块

二:主函数部分

1.调动cpu或者gpu

2.对图像的预处理

3.加载图像

4.运用预处理和扩展图像维度

5.读取json文件

6.创建模型和加载模型权重

[7. 对结果进行评估](#7. 对结果进行评估)

[8. 打印结果并显示图像](#8. 打印结果并显示图像)

9.运行主函数

小结


VGG模型检测

一:导入必要的库和模块

python 复制代码
# 导入所需的库和模块  
import os  # 导入操作系统相关的库  
import json  # 导入处理json数据的库  
  
import torch  # 导入PyTorch库  
from PIL import Image  # 导入处理图像数据的库  
from torchvision import transforms  # 导入PyTorch的图像预处理库  
import matplotlib.pyplot as plt  # 导入matplotlib库用于图像显示  
  
from model import vgg  # 从model模块中导入vgg模型  
  

二:主函数部分

1.调动cpu或者gpu
python 复制代码
    # 判断是否有GPU可用,并设置device变量  
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 
2.对图像的预处理
python 复制代码
    # 定义图像预处理流程  
    data_transform = transforms.Compose(  
        [transforms.Resize((224, 224)),  # 将图像尺寸调整为224x224  
         transforms.ToTensor(),  # 将图像转换为Tensor格式  
         transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])  # 对图像进行标准化处理  
3.加载图像

这里我是用的是绝对路径,可以改成基于上级文件的路径。

python 复制代码
    # 加载图像  
    img_path = "F:/code/Python/pytorch/VGG_image_classifcation/tulip.jpg"  # 定义图像路径  
    assert os.path.exists(img_path), "file: '{}' does not exist.".format(img_path)  # 断言图像文件存在  
    img = Image.open(img_path)  # 使用PIL库打开图像文件  
    plt.imshow(img)  # 使用matplotlib显示图像  
4.运用预处理和扩展图像维度
python 复制代码
    # 对图像进行预处理  
    img = data_transform(img)  # 应用预处理流程  
    # 扩展图像数据的batch维度  
    img = torch.unsqueeze(img, dim=0)  # 将图像数据扩展为batch维度为1的张量  
5.读取json文件
python 复制代码
    # 读取类别索引字典  
    json_path = './class_indices.json'  # 定义json文件路径  
    assert os.path.exists(json_path), "file: '{}' does not exist.".format(json_path)  # 断言json文件存在  
  
    with open(json_path, "r") as f:  # 打开json文件  
        class_indict = json.load(f)  # 读取json文件内容到class_indict变量中
6.创建模型和加载模型权重
python 复制代码
    # 创建模型  
    model = vgg(model_name="vgg16", num_classes=4).to(device)  # 创建vgg16模型,并指定输出类别数为4,然后移动到指定的设备上  
    # 加载模型权重  
    weights_path = "./vgg16Net.pth"  # 定义模型权重文件路径  
    assert os.path.exists(weights_path), "file: '{}' does not exist.".format(weights_path)  # 断言权重文件存在  
    model.load_state_dict(torch.load(weights_path, map_location=device))  # 加载模型权重
7. 对结果进行评估
python 复制代码
    model.eval()  # 将模型设置为评估模式  
    with torch.no_grad():  # 不计算梯度,节省计算资源  
        # 预测类别  
        output = torch.squeeze(model(img.to(device))).cpu()  # 对图像进行预测,并去除batch维度,然后将结果移动到CPU上  
        predict = torch.softmax(output, dim=0)  # 对预测结果进行softmax计算,得到每个类别的概率  
        predict_cla = torch.argmax(predict).numpy()  # 找到概率最大的类别的索引
8. 打印结果并显示图像
python 复制代码
    # 打印预测结果  
    print_res = "class: {}   prob: {:.3}".format(class_indict[str(predict_cla)],  
                                                 predict[predict_cla].numpy())  # 格式化预测结果  
    plt.title(print_res)  # 设置图像标题为预测结果  
    for i in range(len(predict)):  # 遍历每个类别的概率  
        print("class: {:10}   prob: {:.3}".format(class_indict[str(i)],  
                                                  predict[i].numpy()))  # 打印每个类别的名称和概率  
    plt.show()  # 显示图像  
9.运行主函数
python 复制代码
    # 如果当前脚本被直接运行(而不是被其他脚本导入),则执行main函数  
if __name__ == '__main__':  
    main()

小结

1.记得导入VGG模型

2.结果进行可视化处理

相关推荐
2303_Alpha9 小时前
SpringBoot
笔记·学习
萘柰奈10 小时前
Unity学习----【进阶】TextMeshPro学习(三)--进阶知识点(TMP基础设置,材质球相关,两个辅助工具类)
学习·unity
沐矢羽10 小时前
Tomcat PUT方法任意写文件漏洞学习
学习·tomcat
好奇龙猫10 小时前
日语学习-日语知识点小记-进阶-JLPT-N1阶段蓝宝书,共120语法(10):91-100语法+考え方13
学习
向阳花开_miemie10 小时前
Android音频学习(十八)——混音流程
学习·音视频
工大一只猿11 小时前
51单片机学习
嵌入式硬件·学习·51单片机
c0d1ng11 小时前
量子计算学习(第十四周周报)
学习·量子计算
Hello_Embed18 小时前
STM32HAL 快速入门(二十):UART 中断改进 —— 环形缓冲区解决数据丢失
笔记·stm32·单片机·学习·嵌入式软件
咸甜适中18 小时前
rust语言 (1.88) 学习笔记:客户端和服务器端同在一个项目中
笔记·学习·rust
Magnetic_h19 小时前
【iOS】设计模式复习
笔记·学习·ios·设计模式·objective-c·cocoa