python
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
from torchvision import transforms
from PIL import Image
# 定义卷积神经网络模型
class SimpleCNN(nn.Module):
def __init__(self):
super(SimpleCNN, self).__init__()
self.conv = nn.Conv2d(3, 8, kernel_size=3, stride=1, padding=1)
self.bn = nn.BatchNorm2d(8)
self.relu = nn.ReLU()
self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
def forward(self, x):
x = self.conv(x)
# x = self.bn(x)
# x = self.relu(x)
# x = self.pool(x)
return x
if __name__ == '__main__':
# 设置 CPU 张量的随机数种子
torch.manual_seed(42)
# 创建模型实例
model = SimpleCNN()
# 加载并预处理图片
img_path = r'E:\photo\123.jpg'
img = Image.open(img_path).convert('RGB') # 读取的默认格式为 RGB,这里可去掉 convert()
preprocess = transforms.Compose([transforms.Resize((960, 960)),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])])
img_tensor = preprocess(img).unsqueeze(0) # (1, C, H, W)
# 不计算梯度,进行一次前向传播
with torch.no_grad():
output = model(img_tensor)
# 模型输出的图片大小
print("Output size after conv layer:", output.size())
# 可视化原始图片
plt.imshow(img)
plt.title("Original Image")
plt.axis('off')
plt.show()
# 可视化卷积层后的图片
for i in range(output.size()[1]):
plt.subplot(output.size()[1]//4, 4, i+1)
plt.imshow(output[0, i, :, :].cpu().detach().numpy())
plt.axis('off')
plt.tight_layout()
plt.subplots_adjust(hspace=0.05)
plt.suptitle('After Conv2d')
plt.show()
原图大小为:(872, 1280, 3)
原图如下所示:
图片经过卷积层后,得到的特征图大小为:torch.Size([1, 8, 960, 960])
图片经过卷积层后,得到的特征图如下所示:
图片经过 BN 层后,得到的特征图大小为:torch.Size([1, 8, 960, 960])
图片经过 BN 层后,得到的特征图如下所示:
图片经过 ReLU 层后,得到的特征图大小为:torch.Size([1, 8, 960, 960])
图片经过 ReLU 层后,得到的特征图如下所示:
图片经过最大池化层后,得到的特征图大小为:torch.Size([1, 8, 480, 480])
图片经过最大池化层后,得到的特征图如下所示:
PIL 中的 Image.open(img_path) 读取的图片维度为 (W, H, C),读取的图片模式默认为 RGB;
OpenCV 中的 cv2.imread(img_path) 读取的图像维度为 (H, W, C),读取的图片模式默认为 BGR;
Image 图像数据转换为 np.ndarray 时,格式会从 (W, H, C) 转换为 (H, W, C)。
transforms.Resize((960, 960):旨在改变图像的大小,默认使用双线性插值(Bilinear);还支持最近邻插值(Nearest)、双三次插值(Bicubic)。
transforms.ToTensor():将 PIL 图像或 Numpy 数组中的整数像素值转换为 torch.FloatTensor 类型的浮点数;如果 PIL Image 属于 (L, LA, P, I, F, RGB, YCbCr, RGBA, CMYK, 1) 中的一种图像类型,或者 numpy.ndarray 的数据类型是 np.uint8,则将像素值从 [0, 255] 归一化到 [0.0, 1.0],这是通过将每个像素值除以 255 来实现的;将 [H, W, C] 的图像格式转换为 [C, H, W] 的 tensor 格式。