文章目录
- 导入PyTorch并设置与设备无关的代码
- [1. 获取数据](#1. 获取数据)
- [2. 数据准备](#2. 数据准备)
- [3. 转换数据](#3. 转换数据)
- [4.1 加载数据 ImageFolder](#4.1 加载数据 ImageFolder)
- [4.2 加载数据 自定义Dataset](#4.2 加载数据 自定义Dataset)
- [6. 数据增强](#6. 数据增强)
- [7. 定义简单模型](#7. 定义简单模型)
- [8. 理想的损失曲线](#8. 理想的损失曲线)
- [9. 使用数据增强功模型](#9. 使用数据增强功模型)
- [10. 比较模型结果](#10. 比较模型结果)
- [11. 对自定义图像进行预测](#11. 对自定义图像进行预测)
大致流程介绍:
【导入PyTorch并设置与设备无关的代码】
- 获取数据
- 数据准备:在任何新的机器学习问题开始时,了解您正在使用的数据至关重要。
- 转换数据:通常获得的数据不会 100% 准备好用于机器学习模型,在这里我们将了解可以采取的一些步骤来转换图像,以便它们准备好用于模型。
- 加载数据:使用
ImageFolder
加载数据or
使用自定义Dataset
加载图像数据 - 数据增强:其他形式的变换,数据增强是扩展训练数据多样性的常用技术。
- 定义模型,训练预测,探索损失曲线。
导入PyTorch并设置与设备无关的代码
python
import torch
from torch import nn
# Note: this notebook requires torch >= 1.10.0
torch.__version__
python
# Setup device-agnostic code
device = "cuda" if torch.cuda.is_available() else "cpu"
device
1. 获取数据
数据集采用Food-101 -- Mining Discriminative Components with Random Forests
也可以通过下面代码下载:【需要修改下载到的文件地址】
python
import requests
import zipfile
from pathlib import Path
# Setup path to data folder
data_path = Path("E:\PycharmProjects\python_study\pytorch\data")
image_path = data_path / "pizza_steak_sushi"
# If the image folder doesn't exist, download it and prepare it...
if image_path.is_dir():
print(f"{image_path} directory exists.")
else:
print(f"Did not find {image_path} directory, creating one...")
image_path.mkdir(parents=True, exist_ok=True)
# Download pizza, steak, sushi data
with open(data_path / "pizza_steak_sushi.zip", "wb") as f:
request = requests.get("https://github.com/mrdbourke/pytorch-deep-learning/raw/main/data/pizza_steak_sushi.zip")
print("Downloading pizza, steak, sushi data...")
f.write(request.content)
# Unzip pizza, steak, sushi data
with zipfile.ZipFile(data_path / "pizza_steak_sushi.zip", "r") as zip_ref:
print("Unzipping pizza, steak, sushi data...")
zip_ref.extractall(image_path)
可以看得数据集包括训练集和测试集,且是一个多分类问题。
2. 数据准备
目标是采用数据集存储结构,将其转换为可用于 PyTorch 的数据集。
查数据目录中的内容,以遍历每个子目录并计算存在的文件数:
python
import os
def walk_through_dir(dir_path):
"""
Walks through dir_path returning its contents.
Args:
dir_path (str or pathlib.Path): target directory
Returns:
A print out of:
number of subdiretories in dir_path
number of images (files) in each subdirectory
name of each subdirectory
"""
for dirpath, dirnames, filenames in os.walk(dir_path):
print(f"There are {len(dirnames)} directories and {len(filenames)} images in '{dirpath}'.")
python
walk_through_dir(image_path)
每个训练类有大约 75 个图像,每个测试类有 25 个图像。
设置我们的训练和测试路径:
python
# Setup train and testing paths
train_dir = image_path / "train"
test_dir = image_path / "test"
train_dir, test_dir
python
(WindowsPath('E:/PycharmProjects/python_study/pytorch/data/pizza_steak_sushi/train'),
WindowsPath('E:/PycharmProjects/python_study/pytorch/data/pizza_steak_sushi/test'))
- 可视化图像
- 使用
pathlib.Path.glob()
获取所有图像路径,以查找以.jpg
结尾的所有文件。 - 使用 Python 的
random.choice()
选择随机图像路径。 - 使用
pathlib.Path.parent.stem
获取图像类名称。 - 由于我们正在处理图像,因此我们将使用
PIL.Image.open()
打开随机图像路径(PIL 代表 Python 图像库)。 - 然后我们将显示图像并打印一些元数据。
python
import random
from PIL import Image
# Set seed
random.seed(42) # <- try changing this and see what happens
# 1. Get all image paths (* means "any combination")
image_path_list = list(image_path.glob("*/*/*.jpg"))
# 2. Get random image path
random_image_path = random.choice(image_path_list)
# 3. Get image class from path name (the image class is the name of the directory where the image is stored)
image_class = random_image_path.parent.stem
# 4. Open image
img = Image.open(random_image_path)
# 5. Print metadata
print(f"Random image path: {random_image_path}")
print(f"Image class: {image_class}")
print(f"Image height: {img.height}")
print(f"Image width: {img.width}")
img
我们可以对 matplotlib.pyplot.imshow() 执行相同的操作,只不过我们必须首先将图像转换为 NumPy 数组。
python
import numpy as np
import matplotlib.pyplot as plt
# Turn the image into an array
img_as_array = np.asarray(img)
# Plot the image with matplotlib
plt.figure(figsize=(10, 7))
plt.imshow(img_as_array)
plt.title(f"Image class: {image_class} | Image shape: {img_as_array.shape} -> [height, width, color_channels]")
plt.axis(False);
3. 转换数据
将图像数据加载到 PyTorch 中,需要:
- 将其转换为张量(图像的数字表示)。
- 将其转换为
torch.utils.data.Dataset
,然后转换为torch.utils.data.DataLoader
,我们简称为Dataset
和DataLoader
。
PyTorch 有几种不同类型的预构建数据集和数据集加载器:
Problem space | Pre-built Datasets and Functions |
---|---|
Vision | torchvision.datasets |
Audio | torchaudio.datasets |
Text | torchtext.datasets |
Recommendation system | torchrec.datasets |
由于我们正在处理视觉问题,因此我们将使用 torchvision.datasets
来获取数据加载功能,并使用 torchvision.transforms
来准备数据。
导入基础库:
python
import torch
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
- 转换数据
将上面的图片数据集转换为张量,使用 torchvision.transforms
模块。
torchvision.transforms
包含许多预先构建的方法,用于格式化图像、将它们转换为张量,甚至操纵它们以进行数据增强(更改数据以使模型更难学习的做法,我们稍后会看到)上)的目的。
编写一系列转换步骤:
- 使用
transforms.Resize()
调整图像大小(从大约 512x512 到 64x64)。 - 使用
transforms.RandomHorizontalFlip()
在水平方向上随机翻转图像(这可以被视为数据增强的一种形式,因为它会人为地改变我们的图像数据)。 - 使用
transforms.ToTensor()
将我们的图像从 PIL 图像转换为 PyTorch 张量。
可以使用 torchvision.transforms.Compose()
编译所有这些步骤:
python
# Write transform for image
data_transform = transforms.Compose([
# Resize the images to 64x64
transforms.Resize(size=(64, 64)),
# Flip the images randomly on the horizontal
transforms.RandomHorizontalFlip(p=0.5), # p = probability of flip, 0.5 = 50% chance
# Turn the image into a torch.Tensor
transforms.ToTensor() # this also converts all pixel values from 0 to 255 to be between 0.0 and 1.0
])
编写一个函数来在各种图像上尝试它们:
python
def plot_transformed_images(image_paths, transform, n=3, seed=42):
"""Plots a series of random images from image_paths.
Will open n image paths from image_paths, transform them
with transform and plot them side by side.
Args:
image_paths (list): List of target image paths.
transform (PyTorch Transforms): Transforms to apply to images.
n (int, optional): Number of images to plot. Defaults to 3.
seed (int, optional): Random seed for the random generator. Defaults to 42.
"""
random.seed(seed)
random_image_paths = random.sample(image_paths, k=n)
for image_path in random_image_paths:
with Image.open(image_path) as f:
fig, ax = plt.subplots(1, 2)
ax[0].imshow(f)
ax[0].set_title(f"Original \nSize: {f.size}")
ax[0].axis("off")
# Transform and plot image
# Note: permute() will change shape of image to suit matplotlib
# (PyTorch default is [C, H, W] but Matplotlib is [H, W, C])
transformed_image = transform(f).permute(1, 2, 0)
ax[1].imshow(transformed_image)
ax[1].set_title(f"Transformed \nSize: {transformed_image.shape}")
ax[1].axis("off")
fig.suptitle(f"Class: {image_path.parent.stem}", fontsize=16)
python
plot_transformed_images(image_path_list,
transform=data_transform,
n=3)
4.1 加载数据 ImageFolder
由于我们的数据采用标准图像分类格式,因此我们可以使用类 torchvision.datasets.ImageFolder
。
可以向它传递目标图像目录的文件路径以及我们想要对图像执行的一系列转换。
在数据文件夹 train_dir
和 test_dir
上进行测试,传入 transform=data_transform
将图像转换为张量。
python
# Use ImageFolder to create dataset(s)
from torchvision import datasets
train_data = datasets.ImageFolder(root=train_dir, # target folder of images
transform=data_transform, # transforms to perform on data (images)
target_transform=None) # transforms to perform on labels (if necessary)
test_data = datasets.ImageFolder(root=test_dir,
transform=data_transform)
print(f"Train data:\n{train_data}\nTest data:\n{test_data}")
到此,Dataset构建完成,通过检查 classes 和 class_to_idx 属性以及训练集和测试集的长度来检查它们。
查看图像和标签是怎么样的:
python
img, label = train_data[0][0], train_data[0][1]
print(f"Image tensor:\n{img}")
print(f"Image shape: {img.shape}")
print(f"Image datatype: {img.dtype}")
print(f"Image label: {label}")
print(f"Label datatype: {type(label)}")
图像现在采用张量的形式(形状 [3, 64, 64]
),标签采用与特定类相关的整数形式(由 class_to_idx
属性引用) 。
使用 matplotlib 绘制单个图像张量:
重新排列其维度的顺序,现在图像尺寸采用格式 CHW (颜色通道、高度、宽度),但 matplotlib 更喜欢 HWC (高度、宽度、颜色通道)。
python
# Rearrange the order of dimensions
img_permute = img.permute(1, 2, 0)
# Print out different shapes (before and after permute)
print(f"Original shape: {img.shape} -> [color_channels, height, width]")
print(f"Image permute shape: {img_permute.shape} -> [height, width, color_channels]")
# Plot the image
plt.figure(figsize=(10, 7))
plt.imshow(img.permute(1, 2, 0))
plt.axis("off")
plt.title(class_names[label], fontsize=14);
- 将加载的图片
Dataset
转为DataLoader
使用torch.utils.data.DataLoader
来实现
python
# Turn train and test Datasets into DataLoaders
from torch.utils.data import DataLoader
train_dataloader = DataLoader(dataset=train_data,
batch_size=1, # how many samples per batch?
num_workers=1, # how many subprocesses to use for data loading? (higher = more)
shuffle=True) # shuffle the data?
test_dataloader = DataLoader(dataset=test_data,
batch_size=1,
num_workers=1,
shuffle=False) # don't usually need to shuffle testing data
train_dataloader, test_dataloader
检查一下形状:
python
img, label = next(iter(train_dataloader))
# Batch size will now be 1, try changing the batch_size parameter above and see what happens
print(f"Image shape: {img.shape} -> [batch_size, color_channels, height, width]")
print(f"Label shape: {label.shape}")
python
Image shape: torch.Size([1, 3, 64, 64]) -> [batch_size, color_channels, height, width]
Label shape: torch.Size([1])
现在可以使用这些 DataLoader
进行训练和测试循环来训练模型。
4.2 加载数据 自定义Dataset
为了看到这一点的实际效果,让我们通过子类化 torch.utils.data.Dataset (PyTorch 中所有 Dataset 的基类)来复制 torchvision.datasets.ImageFolder()
。
- Python的 os 用于处理目录(我们的数据存储在目录中)。
- Python 的 pathlib 用于处理文件路径(我们的每个图像都有一个唯一的文件路径)。
- torch 适用于 PyTorch 的所有内容。
- PIL 的 Image 类用于加载图像。
- torch.utils.data.Dataset 子类化并创建我们自己的自定义 Dataset 。
- torchvision.transforms 将我们的图像转换为张量。
- Python typing 模块中的各种类型可将类型提示添加到我们的代码中。
python
import os
import pathlib
import torch
from PIL import Image
from torch.utils.data import Dataset
from torchvision import transforms
from typing import Tuple, Dict, List
- 创建辅助函数来获取类名
编写一个辅助函数,能够在给定目录路径的情况下创建类名列表和类名及其索引的字典。
- 使用
os.scandir()
遍历目标目录(理想情况下该目录是标准图像分类格式)获取类名。 - 如果未找到类名,则引发错误(如果发生这种情况,目录结构可能有问题)。
- 将类名转换为数字标签字典,每个类一个。
单看第一个步骤:
python
# Setup path for target directory
target_directory = train_dir
print(f"Target directory: {target_directory}")
# Get the class names from the target directory
class_names_found = sorted([entry.name for entry in list(os.scandir(image_path / "train"))])
print(f"Class names found: {class_names_found}")
python
Target directory: E:\PycharmProjects\python_study\pytorch\data\pizza_steak_sushi\train
Class names found: ['pizza', 'steak', 'sushi']
完整的函数:
python
# Make function to find classes in target directory
def find_classes(directory: str) -> Tuple[List[str], Dict[str, int]]:
"""Finds the class folder names in a target directory.
Assumes target directory is in standard image classification format.
Args:
directory (str): target directory to load classnames from.
Returns:
Tuple[List[str], Dict[str, int]]: (list_of_class_names, dict(class_name: idx...))
Example:
find_classes("food_images/train")
>>> (["class_1", "class_2"], {"class_1": 0, ...})
"""
# 1. Get the class names by scanning the target directory
classes = sorted(entry.name for entry in os.scandir(directory) if entry.is_dir())
# 2. Raise an error if class names not found
if not classes:
raise FileNotFoundError(f"Couldn't find any classes in {directory}.")
# 3. Create a dictionary of index labels (computers prefer numerical rather than string labels)
class_to_idx = {cls_name: i for i, cls_name in enumerate(classes)}
return classes, class_to_idx
python
find_classes(train_dir)
python
(['pizza', 'steak', 'sushi'], {'pizza': 0, 'steak': 1, 'sushi': 2})
- 创建自定义 Dataset 来复制 ImageFolder
1) 子类 torch.utils.data.Dataset
。
2) 使用 targ_dir
参数(目标数据目录)和 transform
参数初始化我们的子类(因此我们可以选择在需要时转换数据)。
3) 为 paths
(目标图像的路径)、 transform
(我们可能想要使用的转换,可以是 None )、 classes
和 class_to_idx
(来自我们的 find_classes
() 函数)。
4) 创建一个函数来从文件加载图像并返回它们,这可以使用 PIL
或 torchvision.io
(用于视觉数据的输入/输出)。
5) 覆盖 torch.utils.data.Dataset
的 len
方法以返回 Dataset
中的样本数,建议但不是必需的。这样您就可以调用 len(Dataset)
。
6) 覆盖 torch.utils.data.Dataset
的 getitem
方法以从 Dataset
返回单个样本,这是必需的。
python
# Write a custom dataset class (inherits from torch.utils.data.Dataset)
from torch.utils.data import Dataset
# 1. Subclass torch.utils.data.Dataset
class ImageFolderCustom(Dataset):
# 2. Initialize with a targ_dir and transform (optional) parameter
def __init__(self, targ_dir: str, transform=None) -> None:
# 3. Create class attributes
# Get all image paths
self.paths = list(pathlib.Path(targ_dir).glob("*/*.jpg")) # note: you'd have to update this if you've got .png's or .jpeg's
# Setup transforms
self.transform = transform
# Create classes and class_to_idx attributes
self.classes, self.class_to_idx = find_classes(targ_dir)
# 4. Make function to load images
def load_image(self, index: int) -> Image.Image:
"Opens an image via a path and returns it."
image_path = self.paths[index]
return Image.open(image_path)
# 5. Overwrite the __len__() method (optional but recommended for subclasses of torch.utils.data.Dataset)
def __len__(self) -> int:
"Returns the total number of samples."
return len(self.paths)
# 6. Overwrite the __getitem__() method (required for subclasses of torch.utils.data.Dataset)
def __getitem__(self, index: int) -> Tuple[torch.Tensor, int]:
"Returns one sample of data, data and label (X, y)."
img = self.load_image(index)
class_name = self.paths[index].parent.name # expects path in data_folder/class_name/image.jpeg
class_idx = self.class_to_idx[class_name]
# Transform if necessary
if self.transform:
return self.transform(img), class_idx # return data, label (X, y)
else:
return img, class_idx # return data, label (X, y)
在测试ImageFolderCustom
类之前,创建一些转换来准备图像:
python
# Augment train data
train_transforms = transforms.Compose([
transforms.Resize((64, 64)),
transforms.RandomHorizontalFlip(p=0.5),
transforms.ToTensor()
])
# Don't augment test data, only reshape
test_transforms = transforms.Compose([
transforms.Resize((64, 64)),
transforms.ToTensor()
])
转化为Dataset
python
train_data_custom = ImageFolderCustom(targ_dir=train_dir,
transform=train_transforms)
test_data_custom = ImageFolderCustom(targ_dir=test_dir,
transform=test_transforms)
train_data_custom, test_data_custom
检测是否有效:
- 创建显示随机图像的函数
创建一个名为 display_random_images(
) 的辅助函数,它可以帮助我们可视化 Dataset
中的图像。
- (1)接受
Dataset
和许多其他参数,例如classes
(我们的目标类的名称)、要显示的图像数量 ( n ) 和随机种子。 - (2)为了防止显示失控,我们将 n 限制为 10 个图像。
- (3)设置可重现图的随机种子(如果设置了 seed )。
- (4)获取随机样本索引列表(我们可以使用 Python 的
random.sample()
来绘制)。 - (5)设置 matplotlib 绘图。
- (6)循环遍历步骤 4 中找到的随机样本索引,并用 matplotlib 绘制它们。
- (7)确保示例图像的形状为 HWC (高度、宽度、颜色通道),以便我们可以绘制它们。
python
# 1. Take in a Dataset as well as a list of class names
def display_random_images(dataset: torch.utils.data.dataset.Dataset,
classes: List[str] = None,
n: int = 10,
display_shape: bool = True,
seed: int = None):
# 2. Adjust display if n too high
if n > 10:
n = 10
display_shape = False
print(f"For display purposes, n shouldn't be larger than 10, setting to 10 and removing shape display.")
# 3. Set random seed
if seed:
random.seed(seed)
# 4. Get random sample indexes
random_samples_idx = random.sample(range(len(dataset)), k=n)
# 5. Setup plot
plt.figure(figsize=(16, 8))
# 6. Loop through samples and display random samples
for i, targ_sample in enumerate(random_samples_idx):
targ_image, targ_label = dataset[targ_sample][0], dataset[targ_sample][1]
# 7. Adjust image tensor shape for plotting: [color_channels, height, width] -> [color_channels, height, width]
targ_image_adjust = targ_image.permute(1, 2, 0)
# Plot adjusted samples
plt.subplot(1, n, i+1)
plt.imshow(targ_image_adjust)
plt.axis("off")
if classes:
title = f"class: {classes[targ_label]}"
if display_shape:
title = title + f"\nshape: {targ_image_adjust.shape}"
plt.title(title)
测试:
- 将自定义加载的图像转换为 DataLoader
python
# Turn train and test custom Dataset's into DataLoader's
from torch.utils.data import DataLoader
train_dataloader_custom = DataLoader(dataset=train_data_custom, # use custom created train Dataset
batch_size=1, # how many samples per batch?
num_workers=0, # how many subprocesses to use for data loading? (higher = more)
shuffle=True) # shuffle the data?
test_dataloader_custom = DataLoader(dataset=test_data_custom, # use custom created test Dataset
batch_size=1,
num_workers=0,
shuffle=False) # don't usually need to shuffle testing data
train_dataloader_custom, test_dataloader_custom
python
# Get image and label from custom DataLoader
img_custom, label_custom = next(iter(train_dataloader_custom))
# Batch size will now be 1, try changing the batch_size parameter above and see what happens
print(f"Image shape: {img_custom.shape} -> [batch_size, color_channels, height, width]")
print(f"Label shape: {label_custom.shape}")
python
Image shape: torch.Size([1, 3, 64, 64]) -> [batch_size, color_channels, height, width]
Label shape: torch.Size([1])
6. 数据增强
数据增强:【通过人为增加训练集多样性的方式更改数据的过程】 裁剪它或随机擦除一部分或随机旋转它们。
机器学习就是利用随机性的力量,研究表明随机变换(如 transforms.RandAugment()
和 transforms.TrivialAugmentWide()
)通常比手工选择的变换表现更好。
TrivialAugment
是最近对各种 PyTorch 视觉模型进行最先进的训练升级时使用的成分之一。
transforms.TrivialAugmentWide()
中需要注意的主要参数是 num_magnitude_bins=31
它定义了将选择多少范围的强度值来应用某种变换, 0 是无范围, 31 是最大范围(最高强度的最高机会)。
可以将 transforms.TrivialAugmentWide()
合并到 transforms.Compose()
中:
python
from torchvision import transforms
train_transforms = transforms.Compose([
transforms.Resize((224, 224)),
transforms.TrivialAugmentWide(num_magnitude_bins=31), # how intense
transforms.ToTensor() # use ToTensor() last to get everything between 0 & 1
])
# Don't need to perform augmentation on the test data
test_transforms = transforms.Compose([
transforms.Resize((224, 224)),
transforms.ToTensor()
])
通常不会在测试集上执行数据增强。数据增强的想法是人为地增加训练集的多样性,以更好地对测试集进行预测。
测试效果:
python
# Get all image paths
image_path_list = list(image_path.glob("*/*/*.jpg"))
# Plot random images
plot_transformed_images(
image_paths=image_path_list,
transform=train_transforms,
n=3,
seed=None
)
7. 定义简单模型
- 【为模型 创建转换并加载数据】没有数据增强,简单的数据转换:
python
# Create simple transform
simple_transform = transforms.Compose([
transforms.Resize((64, 64)),
transforms.ToTensor(),
])
接下来:
- 加载数据,首先将每个训练和测试文件夹变成 Dataset 和 torchvision.datasets.ImageFolder()
- 然后使用 torch.utils.data.DataLoader() 进入 DataLoader 。
python
# 1. Load and transform data
from torchvision import datasets
train_data_simple = datasets.ImageFolder(root=train_dir, transform=simple_transform)
test_data_simple = datasets.ImageFolder(root=test_dir, transform=simple_transform)
# 2. Turn data into DataLoaders
import os
from torch.utils.data import DataLoader
# Setup batch size and number of workers
BATCH_SIZE = 32
NUM_WORKERS = os.cpu_count()
print(f"Creating DataLoader's with batch size {BATCH_SIZE} and {NUM_WORKERS} workers.")
# Create DataLoader's
train_dataloader_simple = DataLoader(train_data_simple,
batch_size=BATCH_SIZE,
shuffle=True,
num_workers=NUM_WORKERS)
test_dataloader_simple = DataLoader(test_data_simple,
batch_size=BATCH_SIZE,
shuffle=False,
num_workers=NUM_WORKERS)
train_dataloader_simple, test_dataloader_simple
- 创建TinyVGG模型
模型采用:https://poloclub.github.io/cnn-explainer/
python
class TinyVGG(nn.Module):
"""
Model architecture copying TinyVGG from:
https://poloclub.github.io/cnn-explainer/
"""
def __init__(self, input_shape: int, hidden_units: int, output_shape: int) -> None:
super().__init__()
self.conv_block_1 = nn.Sequential(
nn.Conv2d(in_channels=input_shape,
out_channels=hidden_units,
kernel_size=3, # how big is the square that's going over the image?
stride=1, # default
padding=1), # options = "valid" (no padding) or "same" (output has same shape as input) or int for specific number
nn.ReLU(),
nn.Conv2d(in_channels=hidden_units,
out_channels=hidden_units,
kernel_size=3,
stride=1,
padding=1),
nn.ReLU(),
nn.MaxPool2d(kernel_size=2,
stride=2) # default stride value is same as kernel_size
)
self.conv_block_2 = nn.Sequential(
nn.Conv2d(hidden_units, hidden_units, kernel_size=3, padding=1),
nn.ReLU(),
nn.Conv2d(hidden_units, hidden_units, kernel_size=3, padding=1),
nn.ReLU(),
nn.MaxPool2d(2)
)
self.classifier = nn.Sequential(
nn.Flatten(),
# Where did this in_features shape come from?
# It's because each layer of our network compresses and changes the shape of our inputs data.
nn.Linear(in_features=hidden_units*16*16,
out_features=output_shape)
)
def forward(self, x: torch.Tensor):
x = self.conv_block_1(x)
# print(x.shape)
x = self.conv_block_2(x)
# print(x.shape)
x = self.classifier(x)
# print(x.shape)
return x
# return self.classifier(self.conv_block_2(self.conv_block_1(x))) # <- leverage the benefits of operator fusion
torch.manual_seed(42)
model_0 = TinyVGG(input_shape=3, # number of color channels (3 for RGB)
hidden_units=10,
output_shape=len(train_data.classes)).to(device)
model_0
- 测试模型:尝试对单个图像进行前向传递
- 从 DataLoader 获取一批图像和标签。
- 从批次中获取单个图像和 unsqueeze() 图像,使其批次大小为 1 (因此其形状适合模型)。
- 对单个图像执行推理(确保将图像发送到目标 device )。
- 打印出正在发生的情况,并使用 torch.softmax() 将模型的原始输出 logits 转换为预测概率(因为我们正在处理多类数据),并使用 torch.argmax() 将预测概率转换为预测标签。
python
# 1. Get a batch of images and labels from the DataLoader
img_batch, label_batch = next(iter(train_dataloader_simple))
# 2. Get a single image from the batch and unsqueeze the image so its shape fits the model
img_single, label_single = img_batch[0].unsqueeze(dim=0), label_batch[0]
print(f"Single image shape: {img_single.shape}\n")
# 3. Perform a forward pass on a single image
model_0.eval()
with torch.inference_mode():
pred = model_0(img_single.to(device))
# 4. Print out what's happening and convert model logits -> pred probs -> pred label
print(f"Output logits:\n{pred}\n")
print(f"Output prediction probabilities:\n{torch.softmax(pred, dim=1)}\n")
print(f"Output prediction label:\n{torch.argmax(torch.softmax(pred, dim=1), dim=1)}\n")
print(f"Actual label:\n{label_single}")
- 使用 torchinfo 了解模型中的形状【可跳过】
安装:
python
# Install torchinfo if it's not available, import it if it is
!pip install torchinfo
torchinfo
附带一个 summary
() 方法,该方法采用 PyTorch 模型以及 input_shape 并返回张量在模型中移动时发生的情况。
python
# Install torchinfo if it's not available, import it if it is
import torchinfo
from torchinfo import summary
summary(model_0, input_size=[1, 3, 64, 64]) # do a test pass through of an example input size
- 创建训练和测试循环函数
创建3个函数:
(1)train_step()
- 接受模型、 DataLoader
、损失函数和优化器,并在 DataLoader
上训练模型。
(2)test_step()
- 接受模型、 DataLoader
和损失函数,并在 DataLoader
上评估模型。
(3)train()
- 对给定数量的 epoch 一起执行 1. 和 2. 并返回结果字典。
python
def train_step(model: torch.nn.Module,
dataloader: torch.utils.data.DataLoader,
loss_fn: torch.nn.Module,
optimizer: torch.optim.Optimizer):
# Put model in train mode
model.train()
# Setup train loss and train accuracy values
train_loss, train_acc = 0, 0
# Loop through data loader data batches
for batch, (X, y) in enumerate(dataloader):
# Send data to target device
X, y = X.to(device), y.to(device)
# 1. Forward pass
y_pred = model(X)
# 2. Calculate and accumulate loss
loss = loss_fn(y_pred, y)
train_loss += loss.item()
# 3. Optimizer zero grad
optimizer.zero_grad()
# 4. Loss backward
loss.backward()
# 5. Optimizer step
optimizer.step()
# Calculate and accumulate accuracy metric across all batches
y_pred_class = torch.argmax(torch.softmax(y_pred, dim=1), dim=1)
train_acc += (y_pred_class == y).sum().item()/len(y_pred)
# Adjust metrics to get average loss and accuracy per batch
train_loss = train_loss / len(dataloader)
train_acc = train_acc / len(dataloader)
return train_loss, train_acc
python
def test_step(model: torch.nn.Module,
dataloader: torch.utils.data.DataLoader,
loss_fn: torch.nn.Module):
# Put model in eval mode
model.eval()
# Setup test loss and test accuracy values
test_loss, test_acc = 0, 0
# Turn on inference context manager
with torch.inference_mode():
# Loop through DataLoader batches
for batch, (X, y) in enumerate(dataloader):
# Send data to target device
X, y = X.to(device), y.to(device)
# 1. Forward pass
test_pred_logits = model(X)
# 2. Calculate and accumulate loss
loss = loss_fn(test_pred_logits, y)
test_loss += loss.item()
# Calculate and accumulate accuracy
test_pred_labels = test_pred_logits.argmax(dim=1)
test_acc += ((test_pred_labels == y).sum().item()/len(test_pred_labels))
# Adjust metrics to get average loss and accuracy per batch
test_loss = test_loss / len(dataloader)
test_acc = test_acc / len(dataloader)
return test_loss, test_acc
- 创建一个
train()
函数来组合train_step()
和test_step()
该函数将训练模型并对其进行评估。
(1)采用模型、用于训练和测试集的 DataLoader 、优化器、损失函数以及执行每个训练和测试步骤的轮数。
(2)为 train_loss 、 train_acc 、 test_loss 和 test_acc 值创建一个空结果字典(我们可以在训练进行时填充它)。
(3)循环执行多个时期的训练和测试步骤函数。
(4)打印出每个时期结束时发生的情况。
(5)使用每个时期更新的指标更新空结果字典。
(7)return 填充的数据。
python
from tqdm.auto import tqdm
# 1. Take in various parameters required for training and test steps
def train(model: torch.nn.Module,
train_dataloader: torch.utils.data.DataLoader,
test_dataloader: torch.utils.data.DataLoader,
optimizer: torch.optim.Optimizer,
loss_fn: torch.nn.Module = nn.CrossEntropyLoss(),
epochs: int = 5):
# 2. Create empty results dictionary
results = {"train_loss": [],
"train_acc": [],
"test_loss": [],
"test_acc": []
}
# 3. Loop through training and testing steps for a number of epochs
for epoch in tqdm(range(epochs)):
train_loss, train_acc = train_step(model=model,
dataloader=train_dataloader,
loss_fn=loss_fn,
optimizer=optimizer)
test_loss, test_acc = test_step(model=model,
dataloader=test_dataloader,
loss_fn=loss_fn)
# 4. Print out what's happening
print(
f"Epoch: {epoch+1} | "
f"train_loss: {train_loss:.4f} | "
f"train_acc: {train_acc:.4f} | "
f"test_loss: {test_loss:.4f} | "
f"test_acc: {test_acc:.4f}"
)
# 5. Update results dictionary
results["train_loss"].append(train_loss)
results["train_acc"].append(train_acc)
results["test_loss"].append(test_loss)
results["test_acc"].append(test_acc)
# 6. Return the filled results at the end of the epochs
return results
- 训练和评估模型
将 TinyVGG 模型、 DataLoader 和 train() 函数放在一起,看看是否可以构建一个能够区分披萨、牛排和寿司的模型。
python
# Set random seeds
torch.manual_seed(42)
torch.cuda.manual_seed(42)
# Set number of epochs
NUM_EPOCHS = 5
# Recreate an instance of TinyVGG
model_0 = TinyVGG(input_shape=3, # number of color channels (3 for RGB)
hidden_units=10,
output_shape=len(train_data.classes)).to(device)
# Setup loss function and optimizer
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(params=model_0.parameters(), lr=0.001)
# Start the timer
from timeit import default_timer as timer
start_time = timer()
# Train model_0
model_0_results = train(model=model_0,
train_dataloader=train_dataloader_simple,
test_dataloader=test_dataloader_simple,
optimizer=optimizer,
loss_fn=loss_fn,
epochs=NUM_EPOCHS)
# End the timer and print out how long it took
end_time = timer()
print(f"Total training time: {end_time-start_time:.3f} seconds")
模型表现很差!
- 绘制模型0的损失曲线
python
# Check the model_0_results keys
model_0_results.keys()
python
dict_keys(['train_loss', 'train_acc', 'test_loss', 'test_acc'])
python
def plot_loss_curves(results: Dict[str, List[float]]):
"""Plots training curves of a results dictionary.
Args:
results (dict): dictionary containing list of values, e.g.
{"train_loss": [...],
"train_acc": [...],
"test_loss": [...],
"test_acc": [...]}
"""
# Get the loss values of the results dictionary (training and test)
loss = results['train_loss']
test_loss = results['test_loss']
# Get the accuracy values of the results dictionary (training and test)
accuracy = results['train_acc']
test_accuracy = results['test_acc']
# Figure out how many epochs there were
epochs = range(len(results['train_loss']))
# Setup a plot
plt.figure(figsize=(15, 7))
# Plot loss
plt.subplot(1, 2, 1)
plt.plot(epochs, loss, label='train_loss')
plt.plot(epochs, test_loss, label='test_loss')
plt.title('Loss')
plt.xlabel('Epochs')
plt.legend()
# Plot accuracy
plt.subplot(1, 2, 2)
plt.plot(epochs, accuracy, label='train_accuracy')
plt.plot(epochs, test_accuracy, label='test_accuracy')
plt.title('Accuracy')
plt.xlabel('Epochs')
plt.legend();
python
plot_loss_curves(model_0_results)
8. 理想的损失曲线
左:如果您的训练和测试损失曲线没有您想要的那么低,则这被认为是欠拟合。 中:当您的测试/验证损失高于训练损失时,这被认为是过度拟合。右图:理想的情况是训练和测试损失曲线随着时间的推移保持一致。这意味着您的模型具有良好的泛化能力。
- 处理过拟合
由于过度拟合的主要问题是您的模型对训练数据拟合得太好,因此您需要使用技术来"控制它"。
防止过度拟合的常用技术称为正则化。Regularization
方法:
防止过拟合的方法 | what |
---|---|
获取更多数据 | 拥有更多数据使模型有更多机会学习模式,这些模式可能更适用于新示例。 |
简化模型 | 如果当前模型已经过度拟合训练数据,则模型可能过于复杂。这意味着它对数据模式的学习太好了,无法很好地泛化到未见过的数据。简化模型的一种方法是减少其使用的层数或减少每层中隐藏单元的数量。 |
使用数据增强 | 数据增强以某种方式操纵训练数据,使模型更难学习,因为它人为地为数据添加了更多多样性。如果模型能够学习增强数据中的模式,则该模型可能能够更好地泛化到未见过的数据。 |
使用迁移学习 | 迁移学习涉及利用模型已学会的模式(也称为预训练权重)作为您自己的任务的基础。在我们的例子中,我们可以使用一种在多种图像上进行预训练的计算机视觉模型,然后稍微调整它以更加专门用于食品图像。 |
使用 dropout 层 | Dropout 层随机删除神经网络中隐藏层之间的连接,有效地简化了模型,同时也使剩余的连接变得更好。有关更多信息,请参阅 torch.nn.Dropout() 。 |
Use learning rate decay | 这里的想法是在模型训练时慢慢降低学习率。这类似于伸手去拿沙发后面的硬币。距离越近,脚步就越小。与学习率相同,越接近收敛,您希望权重更新越小。 |
使用提前停止 | 提前停止会在模型开始过度拟合之前停止训练。例如,假设模型的损失在过去 10 个时期内已停止减少(该数字是任意的),您可能希望在此处停止模型训练并使用损失最低的模型权重(之前的 10 个时期)。 |
- 处理欠拟合
当模型拟合不足时,它被认为对训练和测试集的预测能力很差。
从本质上讲,欠拟合模型将无法将损失值降低到所需的水平。
看看我们当前的损失曲线,我认为我们的 TinyVGG 模型 model_0 与数据拟合不足。
处理欠拟合的主要思想是提高模型的预测能力。
防止欠拟合的方法 | What |
---|---|
向模型添加更多层/单元 | 如果您的模型拟合不足,它可能没有足够的能力来学习预测所需的数据模式/权重/表示。为模型添加更多预测能力的一种方法是增加这些层中隐藏层/单元的数量。 |
调整学习率 | 也许你的模型的学习率一开始就太高了。而且它在每个时期都试图过多地更新权重,结果却没有学到任何东西。在这种情况下,您可以降低学习率并看看会发生什么。 |
使用迁移学习 | 迁移学习能够防止过度拟合和欠拟合。它涉及使用以前工作模型中的模式并根据您自己的问题进行调整。 |
训练时间加长 | 有时模型只是需要更多时间来学习数据的表示。如果您发现在较小的实验中您的模型没有学到任何东西,也许让它训练更多的时期可能会带来更好的性能。 |
使用较少的正则化 | 也许您的模型拟合不足,因为您试图防止过度拟合。抑制正则化技术可以帮助您的模型更好地拟合数据。 |
- 过拟合和欠拟合之间的平衡
当涉及到处理自身问题的过度拟合和欠拟合时,迁移学习可能是最强大的技术之一。
迁移学习不是手动设计不同的过拟合和欠拟合技术,而是使您能够在与您的问题空间类似的问题空间中采用已经工作的模型并将其应用到您自己的数据集。
9. 使用数据增强功模型
编写一个训练转换以包含 transforms.TrivialAugmentWide() 以及调整图像大小并将图像转换为张量。
- 使用数据增强创建转换
python
# Create training transform with TrivialAugment
train_transform_trivial_augment = transforms.Compose([
transforms.Resize((64, 64)),
transforms.TrivialAugmentWide(num_magnitude_bins=31),
transforms.ToTensor()
])
# Create testing transform (no data augmentation)
test_transform = transforms.Compose([
transforms.Resize((64, 64)),
transforms.ToTensor()
])
- 创建训练和测试 Dataset 和 DataLoader
python
# Turn image folders into Datasets
train_data_augmented = datasets.ImageFolder(train_dir, transform=train_transform_trivial_augment)
test_data_simple = datasets.ImageFolder(test_dir, transform=test_transform)
train_data_augmented, test_data_simple
python
# Turn Datasets into DataLoader's
import os
BATCH_SIZE = 32
NUM_WORKERS = os.cpu_count()
torch.manual_seed(42)
train_dataloader_augmented = DataLoader(train_data_augmented,
batch_size=BATCH_SIZE,
shuffle=True,
num_workers=NUM_WORKERS)
test_dataloader_simple = DataLoader(test_data_simple,
batch_size=BATCH_SIZE,
shuffle=False,
num_workers=NUM_WORKERS)
train_dataloader_augmented, test_dataloader
- 构建和训练模型
python
# Create model_1 and send it to the target device
torch.manual_seed(42)
model_1 = TinyVGG(
input_shape=3,
hidden_units=10,
output_shape=len(train_data_augmented.classes)).to(device)
model_1
使用与 model_0 相同的设置,仅 train_dataloader 参数发生变化:
python
# Set random seeds
torch.manual_seed(42)
torch.cuda.manual_seed(42)
# Set number of epochs
NUM_EPOCHS = 5
# Setup loss function and optimizer
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(params=model_1.parameters(), lr=0.001)
# Start the timer
from timeit import default_timer as timer
start_time = timer()
# Train model_1
model_1_results = train(model=model_1,
train_dataloader=train_dataloader_augmented,
test_dataloader=test_dataloader_simple,
optimizer=optimizer,
loss_fn=loss_fn,
epochs=NUM_EPOCHS)
# End the timer and print out how long it took
end_time = timer()
print(f"Total training time: {end_time-start_time:.3f} seconds")
模型表现也不是很好。
- 绘制模型的损失曲线
python
plot_loss_curves(model_1_results)
10. 比较模型结果
尽管模型都表现很差,仍然可以编写代码来比较它们。
将模型结果转换为 pandas DataFrames:
python
import pandas as pd
model_0_df = pd.DataFrame(model_0_results)
model_1_df = pd.DataFrame(model_1_results)
model_0_df
使用 matplotlib 编写一些绘图代码来一起可视化 model_0 和 model_1 的结果:
python
# Setup a plot
plt.figure(figsize=(15, 10))
# Get number of epochs
epochs = range(len(model_0_df))
# Plot train loss
plt.subplot(2, 2, 1)
plt.plot(epochs, model_0_df["train_loss"], label="Model 0")
plt.plot(epochs, model_1_df["train_loss"], label="Model 1")
plt.title("Train Loss")
plt.xlabel("Epochs")
plt.legend()
# Plot test loss
plt.subplot(2, 2, 2)
plt.plot(epochs, model_0_df["test_loss"], label="Model 0")
plt.plot(epochs, model_1_df["test_loss"], label="Model 1")
plt.title("Test Loss")
plt.xlabel("Epochs")
plt.legend()
# Plot train accuracy
plt.subplot(2, 2, 3)
plt.plot(epochs, model_0_df["train_acc"], label="Model 0")
plt.plot(epochs, model_1_df["train_acc"], label="Model 1")
plt.title("Train Accuracy")
plt.xlabel("Epochs")
plt.legend()
# Plot test accuracy
plt.subplot(2, 2, 4)
plt.plot(epochs, model_0_df["test_acc"], label="Model 0")
plt.plot(epochs, model_1_df["test_acc"], label="Model 1")
plt.title("Test Accuracy")
plt.xlabel("Epochs")
plt.legend();
看起来我们的模型表现同样糟糕并且有点零星(指标急剧上升和下降)。
11. 对自定义图像进行预测
找一张需要进行预测的图片,例如下图:
python
# Setup custom image path
custom_image_path = "E:\\PycharmProjects\\python_study\\pytorch\\data\\predicate\\04-pizza-dad.jpeg"
- 使用 PyTorch 加载自定义图像
torchvision.io
可以读取和写入图像和视频,要加载图像,因此我们将使用 torchvision.io.read_image()
。
此方法将读取 JPEG 或 PNG 图像,并将其转换为 3 维 RGB 或灰度 torch.Tensor ,数据类型 uint8 的值在 [0, 255] 范围内。
python
import torchvision
# Read in custom image
custom_image_uint8 = torchvision.io.read_image(str(custom_image_path))
# Print out image data
print(f"Custom image tensor:\n{custom_image_uint8}\n")
print(f"Custom image shape: {custom_image_uint8.shape}\n")
print(f"Custom image dtype: {custom_image_uint8.dtype}")
custom_image 张量的数据类型为 torch.uint8 ,其值在 [0, 255] 之间。
模型采用数据类型 torch.float32 且值在 [0, 1] 之间的图像张量。
需要将其转换为与模型训练数据相同的格式,不然会报错:
python
RuntimeError: Input type (torch.cuda.ByteTensor) and weight type (torch.cuda.FloatTensor) should be the same
将自定义图像转换为与模型训练时相同的数据类型 ( torch.float32 ):
python
# Load in custom image and convert the tensor values to float32
custom_image = torchvision.io.read_image(str(custom_image_path)).type(torch.float32)
# Divide the image pixel values by 255 to get them between [0, 1]
custom_image = custom_image / 255.
# Print out image data
print(f"Custom image tensor:\n{custom_image}\n")
print(f"Custom image shape: {custom_image.shape}\n")
print(f"Custom image dtype: {custom_image.dtype}")
- 使用经过训练的 PyTorch 模型预测自定义图像
我们的模型是在形状 [3, 64, 64] 的图像上进行训练的,而我们的自定义图像目前是 [3, 4032, 3024] 。
让我们用 matplotlib 绘制图像以确保它看起来不错,记住我们必须将尺寸从 CHW 排列为 HWC 以满足 matplotlib 的要求。
python
# Plot custom image
plt.imshow(custom_image.permute(1, 2, 0)) # need to permute image dimensions from CHW -> HWC otherwise matplotlib will error
plt.title(f"Image shape: {custom_image.shape}")
plt.axis(False);
一种方法是使用 torchvision.transforms.Resize() :
python
# Create transform pipleine to resize image
custom_image_transform = transforms.Compose([
transforms.Resize((64, 64)),
])
# Transform target image
custom_image_transformed = custom_image_transform(custom_image)
# Print out original shape and new shape
print(f"Original shape: {custom_image.shape}")
print(f"New shape: {custom_image_transformed.shape}")
python
Original shape: torch.Size([3, 4032, 3024])
New shape: torch.Size([3, 64, 64])
进行预测:
两个注意点:
(1)需要to(device),没有报错:RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cpu and cuda:0! (when checking argument for argument weight in method wrapper___slow_conv2d_forward)
(2)上面的数据结构是 CHW,模型需要的是NCHW,N是批量大小,没有修改结构会报错:RuntimeError: mat1 and mat2 shapes cannot be multiplied (10x256 and 2560x3)
,可以使用 torch.unsqueeze(dim=0) 添加批量大小维度来为图像添加额外维度并最终进行预测
python
mmodel_1.eval()
with torch.inference_mode():
# Add an extra dimension to image
custom_image_transformed_with_batch_size = custom_image_transformed.unsqueeze(dim=0)
# Print out different shapes
print(f"Custom image transformed shape: {custom_image_transformed.shape}")
print(f"Unsqueezed custom image shape: {custom_image_transformed_with_batch_size.shape}")
# Make a prediction on image with an extra dimension
custom_image_pred = model_1(custom_image_transformed.unsqueeze(dim=0).to(device))
python
Custom image transformed shape: torch.Size([3, 64, 64])
Unsqueezed custom image shape: torch.Size([1, 3, 64, 64])
补充易错的3个点
(1)错误的数据类型 - 我们的模型需要 torch.float32 ,而我们的原始自定义图像是 uint8 。
(2)错误的设备 - 我们的模型位于目标 device (在我们的例子中为 GPU)上,而我们的目标数据尚未移动到目标 device 。
(3)错误的形状 - 我们的模型期望输入图像的形状为 [N, C, H, W] 或 [batch_size, color_channels, height, width] ,而我们的自定义图像张量的形状为 [color_channels, height, width] 。
打印logit 形式(模型的原始输出称为 logits):
python
custom_image_pred
python
tensor([[ 0.1159, 0.0208, -0.1422]], device='cuda:0')
从 logits -> 预测概率 -> 预测标签转换:
python
# Print out prediction logits
print(f"Prediction logits: {custom_image_pred}")
# Convert logits -> prediction probabilities (using torch.softmax() for multi-class classification)
custom_image_pred_probs = torch.softmax(custom_image_pred, dim=1)
print(f"Prediction probabilities: {custom_image_pred_probs}")
# Convert prediction probabilities -> prediction labels
custom_image_pred_label = torch.argmax(custom_image_pred_probs, dim=1)
print(f"Prediction label: {custom_image_pred_label}")
python
Prediction logits: tensor([[ 0.1159, 0.0208, -0.1422]], device='cuda:0')
Prediction probabilities: tensor([[0.3729, 0.3391, 0.2881]], device='cuda:0')
Prediction label: tensor([0], device='cuda:0')
python
# Find the predicted label
custom_image_pred_class = class_names[custom_image_pred_label.cpu()] # put pred label to CPU, otherwise will error
custom_image_pred_class
python
'pizza'
- 将自定义图像预测放在一起:构建函数
将上面步骤全部放在一个函数中,以便我们可以轻松地反复使用。
(1)获取目标图像路径并转换为适合我们模型的正确数据类型 ( torch.float32 )。
(2)确保目标图像像素值在 [0, 1] 范围内。
(3)如有必要,变换目标图像。
(4)确保该模型位于目标设备上。
(5)使用经过训练的模型对目标图像进行预测(确保图像大小正确且与模型位于同一设备上)。
(6)将模型的输出 logits 转换为预测概率。
(7)将预测概率转换为预测标签。
(8)将目标图像与模型预测和预测概率一起绘制。
python
def pred_and_plot_image(model: torch.nn.Module,
image_path: str,
class_names: List[str] = None,
transform=None,
device: torch.device = device):
"""Makes a prediction on a target image and plots the image with its prediction."""
# 1. Load in image and convert the tensor values to float32
target_image = torchvision.io.read_image(str(image_path)).type(torch.float32)
# 2. Divide the image pixel values by 255 to get them between [0, 1]
target_image = target_image / 255.
# 3. Transform if necessary
if transform:
target_image = transform(target_image)
# 4. Make sure the model is on the target device
model.to(device)
# 5. Turn on model evaluation mode and inference mode
model.eval()
with torch.inference_mode():
# Add an extra dimension to the image
target_image = target_image.unsqueeze(dim=0)
# Make a prediction on image with an extra dimension and send it to the target device
target_image_pred = model(target_image.to(device))
# 6. Convert logits -> prediction probabilities (using torch.softmax() for multi-class classification)
target_image_pred_probs = torch.softmax(target_image_pred, dim=1)
# 7. Convert prediction probabilities -> prediction labels
target_image_pred_label = torch.argmax(target_image_pred_probs, dim=1)
# 8. Plot the image alongside the prediction and prediction probability
plt.imshow(target_image.squeeze().permute(1, 2, 0)) # make sure it's the right size for matplotlib
if class_names:
title = f"Pred: {class_names[target_image_pred_label.cpu()]} | Prob: {target_image_pred_probs.max().cpu():.3f}"
else:
title = f"Pred: {target_image_pred_label} | Prob: {target_image_pred_probs.max().cpu():.3f}"
plt.title(title)
plt.axis(False);
python
# Pred on our custom image
pred_and_plot_image(model=model_1,
image_path=custom_image_path,
class_names=class_names,
transform=custom_image_transform,
device=device)