下面我将提供一个完整的Pytorch超分辨率模型实现,并对每一行代码进行详细解释,包括所有引用的头文件。
import torch #导入pytorch库,用于构建和训练神经网络的主要框架
import torch.nn as nn #导入pytorch的神经网络模块-包含各种神经网络的函数
import torch.nn.functional as F #导入Pytorch的神经网络函数模块,--包含激活函数,损失函数等
import torch.utils.data import DataLoader #导入pytorch的数据加载工具,用于创建和管理数据加载器
from torchvision import datasets, transforms #导入torchvision的数据集和变换模块--提供常用数据集和图像预处理方法
import matplotlib.pyplot as plt #导入matplotlib 的pyplot模块-用于绘制图表和可视化结果
import numpy as np #导入numpy库,用于数值计算,特别是在处理图像数据时
import os #导入操作系统接口模块,用于处理文件的目录路径。
import time #导入时间模块--用于测量训练时间等。
设置设备GPU如果可用,否则CPU
torch.cuda.is_avaiable() 检查当前系统是否可用的CUDA GPU
#如果有,使用GPU加速计算,否则使用CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
1 定义ESPCN模型
class ESPCN(nn.Module):
#类的初始化方法,定义模型的结构
def init(self, upscale_factor, num_channels=1):
初始化ESPCN模型
参数:
upscale_factor 放大倍数
num_channels 输入图像的通道数,默认为1(灰度图)
#调用父类nn.Module 的初始化方法
super(ESPCN, self).init()
#第一个卷积层,提取特征
nn.Conv2d: 2D卷积层,用于处理图像数据
参数:输入通道数,输出通道数,卷积核大小,填充大小。
这里使用5x5卷积核,填充2保持空间尺寸不变。
self.conv1 = nn.Conv2d(num_channels, 64, kernel_size=5, padding=2)
#第二卷积层:进一步处理特征
输入64通道,输出32通道,3x3卷积核,填充1保持尺寸。
self.conv2 = nn.Conv2d(64,32,kernel_size=3,padding=1)
#最后一个卷积层,生成放大的特征图
输出通道数为;num_channels (upscale_factor * 2 )
这是因为我们将通过像素重排pixel_shuffle 来提升分辨率
self.conv3 = nn.Conv2d(32, num_channels (upscale_factor * 2 ), kernel_szie=3, padding=1)
#像素重排操作,子像素卷积层
pixelshuffle 将形状为 C * r^2, H, W的张量
重新排列
self.pixel_shuffle = nn.PixelShuffle(upscale_factor)
#定义向前传播过程,描述数据如何通过网络的各层
def forward(self, x):
向前传播
参数
x 输入的第分辨率图像,形状为(batch_size, num_channels, height, width)
返回: 高分辨率图像,形状为(batch_szie, num_channels, heightupscale_factor, width u p s c a l e _ f a c t o r )
#第一层卷积后使用tanh激活函数
torch.tanh 双曲正切激活函数,将值压缩到(-1,1)范围
x = torch.tanh(self.conv1(x))
#第二层卷积后使用tanh激活函数
x = torch.tanh(self.conv2(x))
#第三层卷积
x = self.conv3(x)
#应用像素重排操作,将通道维度转换为空间维度。
x = self.pixel_shuffle(x)
#使用sigmoid激活函数,将值压缩到(0,1)范围
#这是因为图像像素值通常在0-1之间
x = torch.sigmoid(x)
return x
2 准备数据
def prepare_data(batch_szie, upscale_factor, dataset_name='MNIST'):
准备训练和预测数据
参数
batch_szie 批处理大小
upscale_factor 放大倍数
dataset_name 数据集名称,默认为MNIST
返回:
训练和测试数据加载器
数据转换管道
transforms.Compose 将多个变换组合在一起
transform = transforms.Compose([
#transforms.ToTensor 将PIL图像或者numpy数组转换为Pytorch张量
#同时 将像素值从[0.255]缩放到[0,1]范围
transforms.ToTensor(),
#transforms.Normalize 对张量进行标准化
#参数为均值和标准差,这里将标准化到[-1,1]范围
transforms.Normalize(0.5, (0.5,))
])
#根据数据集名称选择不同的数据集
if dataset_name == 'MNIST':
#下载并加载MNIST训练数据集
#MNIST是一个手写数字数据集,包含60 000个训练样本和10 000个测试样本
train_dataset = datasets.MNIST(
root='./data' #数据存储路径
train=True, #加载训练集
download=True,#如果数据不存在则下载
transform = transform #应用上面定义的数据转换
)
#下载并加载MNIST测试数据集
test_dataset = datasets.MNIST(
root='./data',
train=False, #加载测试集
download=True,
transform=transform
)
else:
#可以在这里添加对其他数据集的支持
raise ValueError("不支持的数据集:{dataset_name}")
#创建训练数据加载起
DataLoader 包装数据集并提供批量加载,shuffling 等功能
train_loader = DataLoader(
train_dataset,
batch_size = batch_size, #每个批次的样本数量
shuffle=True, 每个epoch 开始打乱数据顺序
num_works = 2, 使用2个子进程加载数据
pin_memory=True #将数据固定在内存中,加速GPU传输
)
#创建测试数据加载器
test_loader = DataLoader(
test_dataset,
batch_size=batch_size,
shuffle=False, #测试时不需要打乱顺序
num_works=2,
pin_memory=True
)
#返回训练和测试数据加载器
return train_laoder, test_loader
3 训练函数
def train_model(model, train_loader, criterion, optimizer, num_epochs, upscale_factor):
训练模型,
参数
model:要训练的模型
train_loader: 训练数据加载器
criterion 损失函数
optimizer: 优化器
num_epochs: 训练轮数
upscale_factor: 放大倍数
#设置模型为训练模式
#这会启用dropout和batch normalization 等训练特定行为
model.train()
#记录训练过程中的损失值
losses = []
#记录训练开始时间
start_time = time.time()
#循环遍历每个epoch
for epoch in range(num_epochs):
#初始化当前epoch的损失值
epoch_loss = 0
#遍历训练数据加载起中的每个批次
for batch_idx, (data, target) in enumerate(train_loader):
#将数据移动到相应的设备 GPU或者CPU
data = data.to(device)
#创建低分辨率输入
#首先将图像下采样,然后上采样回原始大小以模拟低分辨率图像
#F.interpolate: 对图像进行上采样或者下采样
#scale_factor= 1/upscale_factor 下采样比例
mode = 'bicubic' 使用双三次循环算法
align_corners= False: 差值算法参数
lr_data = F.interpolate(
data,
scale_factor=1/upscale_factor,
mode='bicubic',
align_corners = False
)
#将下采样后的图像上采样回原始尺寸
lr_data = F.interpolate(
lr_data,
scale_factor = upscale_factor,
mode = 'bicubic',
align_corners=False
)
#清零梯度
在pytorch 中,梯度是累加的,所以在每个批次开始时需要清零
optimizer.zero_grad()
#前向传播,降低分辨率图像输入模型,得到高分辨率输出
output = model(lr_data)
#计算损失,比较模型输出和原始高分辨率图像
loss = criterion(output, data)
#反向传播,计算梯度
loss.backward()
#更新权重,根据梯度调整模型参数
optimizer.step()
#累加当前批次的损失值
epoch_loss += loss.item()
#计算当前epoch的平均损失
losses.append(avg_loss)
#打印训练进度
if (epoch + 1) % 5 == 0:
#计算已用时间
elapsed_time = time.time() - start_time
#打印当前epoch,总epoch数,损失值和已用时间
#训练完后,绘制损失曲线
plt.figure(figsize=(10,5))
plt.plot(losses)
plt.title('Training loss over epochs')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.grid(True)
#保存损失曲线图像
plt.savefig('training_loss.png')
plt.show()
#打印总训练时间
total_time = time.time() - start_time
4 测试函数
#定义模型测试函数
def test_model(model, test_loader, upscale_factor, num_examples=5):
测试模型并显示结果
参数:
model: 要测试的模型
test_loader: 测试数据加载器
upscale_factor: 放大倍数
num_examples: 要显示的示例数量
#设置模型为评估模式
#这会禁用dropout和batch normalization等训练特定行为
model.eval()
#初始化示例计数器
examples_shown = 0
#不计算梯度,节省内存和计算资源
with torch.no_grad():
#遍历测试数据加载器
for i, (data, target) in enumerate(test_loader):
#如果已经显示了足够多的示例,退出循环
if examples_shown >= num_examples:
break
#将数据移动到相应设备
data = data.to(device)
#创建低分辨率输入(与训练时间相同的方法)
lr_data = F.interpolate(
data,
scale_factor = 1/upscale_factor,
model='bicubic',
align_corners = False
)
lr_data = F.interpolate(
lr_data,
scale_factor = upscale_factor,
mode = 'bicubic',
align_corners = False
)
#生成高分辨率图像
hr_output = model(lr_data)
#将图像移回CPU并转换为numpy数组以便显示
lr_image = lr_data[0].cpu().sequeeze().numpy()
hr_image = hr_output[0.cpu().sequeeze().numpy()
original_image = data[0].cpu().squeeze().numpy()
#显示结果
plt.figure(figsize=(12,4))
#显示低分辨率输入图像
plt.subplot(1,3,1)
plt.imshow(lr_image, cmap='gray')
plt.title('Low Resolution Input')
plt.axis('off')
#显示超分辨率输出图像
plt.subplot(1,3,2)
plt.imshow(hr_iamge, cmap='gray')
plt.title('Super Resolution Output')
plt.axis('off')
#显示原始高分辨率图像
plt.subplot(1,3,3)
plt.imshow(original_image, cmap='gray')
plt.title('Original high Resolution')
plt.axis('off')
#保存对比图像
plt.savefig(f'comparsion_example_{examples_shown+1}.png')
plt.show()
#增加示例计数器
examples_shown += 1
5 计算PSNR指标函数
def calculate_psnr(model, test_loader, upscale_factor):
计算模型的峰值信噪比PSNR
参数model要评估的模型
test_loader :测试数据加载器
upscale_factor 放大倍数
返回:平均PSNR值
#设置模型为评估模式
moel.eval()
#初始化PSNR总和和样本计数
total_psnr=0.0
total_samples=0
#不计算梯度
with torch.no_grad():
#遍历测试数据加载器
for data, _ in test_loader:
#将数据移动到相应设备
data = data.to(device)
#创建低分辨率输入
lr_data = F.interpolate(
data,
scale_factor=1/upscale_factor,
mode = 'bicubic',
align_corners = False
)
lr_data = F.interpolate(
lr_data,
scale_factor = upscale_factor,
mode='bicubic',
align_corners=False
)
#生成高分辨率图像
hr_output=model(lr_data)
#计算每个样本PSNR
for i in range(data.size(0)):
#将张量转换为numpy数组
original=data[i].cpu().numpy()
reconstructed=hr_output[i].cpu().numpy()
#计算均方差误差(MSE)
mse = np.mean((original - reconstructed) ** 2)
#避免除以零
if mse == 0:
psnr = 100 #无穷大的PSNR,这里设为100
else :
#计算PSNR 20 log10(MAX) - 10 l o g 1 0 ( MSE)
#对于[0,1]范围的图像 MAX = 1
psnr=20 np.log10(1.0) - 10 n p . l o g 1 0 ( mse)
#累加PSNR
total_psnr += psnr
total_samples += 1
#计算平均PSNR
avg_psnr = total_psnr / total_samples
return avg_psnr
6 主函数
#定义主函数,组织整个训练和测试流程
def main():
#超参数设置
upscale_factor=2 #放大倍数
num_epochs = 20 #训练轮数
batch_szie = 64 #批处理大小
learning_rate = 0.001 学习率
#创建输出目录,如果不存在
if not os.path.exists('results'):
os.makedirs('results')
#准备数据
train_loader, test_loader = prepare_data(batch_size, upscale_factor)
#初始化模型
model = ESPCN(upscale_factor=upscale_factor).to(device)
#打印模型结构
print(model)
#计算模型参数数量
total_params = sum(p.numel() for p in model.parameters())
#定义损失函数 - 均方误差损失
criterion = nn.MSELoss()
#定义优化器,Adam优化器
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
#训练模型
train_model(model, train_loader, criterion, optimizer, num_epochs, upscale_factor)
#测试模型
test_model(model, test_loader, upscale_factor)
#计算PSNR
calculate_psnr(model, test_loader, upscale_factor)
#保存模型
model_path='results/espcn_model.pth'
torch.save(model.state_dcit(), model_path)
if name=="main":
main()
头文件解释总结
1 torch:pytorch主库,提供张量操作和自动求导功能
2 torch.nn:pytorch神经网络模块,包含各种层和损失函数
3 torch.nn.functional:pytorch函数式接口,包含激活函数,损失函数等
4 torch.utils.data pytorch视觉库,提供常用数据集和图像变换
5 matplotlib.pyplot:会图库,用于可视化结果
6 torchvision pytorch视觉库,提供常用数据集和图像变换
7 numpy 数据计算库,用于处理数组数据
8 os 操作系统接口,用于处理文件和目录
9 time时间模块,用于测量运动时间