Pytorch超分辨率模型实现与详细解释

下面我将提供一个完整的Pytorch超分辨率模型实现,并对每一行代码进行详细解释,包括所有引用的头文件。

import torch #导入pytorch库,用于构建和训练神经网络的主要框架

import torch.nn as nn #导入pytorch的神经网络模块-包含各种神经网络的函数

import torch.nn.functional as F #导入Pytorch的神经网络函数模块,--包含激活函数,损失函数等

import torch.utils.data import DataLoader #导入pytorch的数据加载工具,用于创建和管理数据加载器

from torchvision import datasets, transforms #导入torchvision的数据集和变换模块--提供常用数据集和图像预处理方法

import matplotlib.pyplot as plt #导入matplotlib 的pyplot模块-用于绘制图表和可视化结果

import numpy as np #导入numpy库,用于数值计算,特别是在处理图像数据时

import os #导入操作系统接口模块,用于处理文件的目录路径。

import time #导入时间模块--用于测量训练时间等。

设置设备GPU如果可用,否则CPU

torch.cuda.is_avaiable() 检查当前系统是否可用的CUDA GPU

#如果有,使用GPU加速计算,否则使用CPU

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

1 定义ESPCN模型

class ESPCN(nn.Module):

#类的初始化方法,定义模型的结构

def init(self, upscale_factor, num_channels=1):

初始化ESPCN模型

参数:

upscale_factor 放大倍数

num_channels 输入图像的通道数,默认为1(灰度图)

#调用父类nn.Module 的初始化方法

super(ESPCN, self).init()

#第一个卷积层,提取特征

nn.Conv2d: 2D卷积层,用于处理图像数据

参数:输入通道数,输出通道数,卷积核大小,填充大小。

这里使用5x5卷积核,填充2保持空间尺寸不变。

self.conv1 = nn.Conv2d(num_channels, 64, kernel_size=5, padding=2)

#第二卷积层:进一步处理特征

输入64通道,输出32通道,3x3卷积核,填充1保持尺寸。

self.conv2 = nn.Conv2d(64,32,kernel_size=3,padding=1)

#最后一个卷积层,生成放大的特征图

输出通道数为;num_channels (upscale_factor * 2 )

这是因为我们将通过像素重排pixel_shuffle 来提升分辨率

self.conv3 = nn.Conv2d(32, num_channels (upscale_factor * 2 ), kernel_szie=3, padding=1)

#像素重排操作,子像素卷积层

pixelshuffle 将形状为 C * r^2, H, W的张量

重新排列

self.pixel_shuffle = nn.PixelShuffle(upscale_factor)

#定义向前传播过程,描述数据如何通过网络的各层

def forward(self, x):

向前传播

参数

x 输入的第分辨率图像,形状为(batch_size, num_channels, height, width)

返回: 高分辨率图像,形状为(batch_szie, num_channels, heightupscale_factor, width u p s c a l e _ f a c t o r )

#第一层卷积后使用tanh激活函数

torch.tanh 双曲正切激活函数,将值压缩到(-1,1)范围

x = torch.tanh(self.conv1(x))

#第二层卷积后使用tanh激活函数

x = torch.tanh(self.conv2(x))

#第三层卷积

x = self.conv3(x)

#应用像素重排操作,将通道维度转换为空间维度。

x = self.pixel_shuffle(x)

#使用sigmoid激活函数,将值压缩到(0,1)范围

#这是因为图像像素值通常在0-1之间

x = torch.sigmoid(x)

return x

2 准备数据

def prepare_data(batch_szie, upscale_factor, dataset_name='MNIST'):

准备训练和预测数据

参数

batch_szie 批处理大小

upscale_factor 放大倍数

dataset_name 数据集名称,默认为MNIST

返回:

训练和测试数据加载器

数据转换管道

transforms.Compose 将多个变换组合在一起

transform = transforms.Compose([

#transforms.ToTensor 将PIL图像或者numpy数组转换为Pytorch张量

#同时 将像素值从[0.255]缩放到[0,1]范围

transforms.ToTensor(),

#transforms.Normalize 对张量进行标准化

#参数为均值和标准差,这里将标准化到[-1,1]范围

transforms.Normalize(0.5, (0.5,))

])

#根据数据集名称选择不同的数据集

if dataset_name == 'MNIST':

#下载并加载MNIST训练数据集

#MNIST是一个手写数字数据集,包含60 000个训练样本和10 000个测试样本

train_dataset = datasets.MNIST(

root='./data' #数据存储路径

train=True, #加载训练集

download=True,#如果数据不存在则下载

transform = transform #应用上面定义的数据转换

)

#下载并加载MNIST测试数据集

test_dataset = datasets.MNIST(

root='./data',

train=False, #加载测试集

download=True,

transform=transform

)

else:

#可以在这里添加对其他数据集的支持

raise ValueError("不支持的数据集:{dataset_name}")

#创建训练数据加载起

DataLoader 包装数据集并提供批量加载,shuffling 等功能

train_loader = DataLoader(

train_dataset,

batch_size = batch_size, #每个批次的样本数量

shuffle=True, 每个epoch 开始打乱数据顺序

num_works = 2, 使用2个子进程加载数据

pin_memory=True #将数据固定在内存中,加速GPU传输

)

#创建测试数据加载器

test_loader = DataLoader(

test_dataset,

batch_size=batch_size,

shuffle=False, #测试时不需要打乱顺序

num_works=2,

pin_memory=True

)

#返回训练和测试数据加载器

return train_laoder, test_loader

3 训练函数

def train_model(model, train_loader, criterion, optimizer, num_epochs, upscale_factor):

训练模型,

参数

model:要训练的模型

train_loader: 训练数据加载器

criterion 损失函数

optimizer: 优化器

num_epochs: 训练轮数

upscale_factor: 放大倍数

#设置模型为训练模式

#这会启用dropout和batch normalization 等训练特定行为

model.train()

#记录训练过程中的损失值

losses = []

#记录训练开始时间

start_time = time.time()

#循环遍历每个epoch

for epoch in range(num_epochs):

#初始化当前epoch的损失值

epoch_loss = 0

#遍历训练数据加载起中的每个批次

for batch_idx, (data, target) in enumerate(train_loader):

#将数据移动到相应的设备 GPU或者CPU

data = data.to(device)

#创建低分辨率输入

#首先将图像下采样,然后上采样回原始大小以模拟低分辨率图像

#F.interpolate: 对图像进行上采样或者下采样

#scale_factor= 1/upscale_factor 下采样比例

mode = 'bicubic' 使用双三次循环算法

align_corners= False: 差值算法参数

lr_data = F.interpolate(

data,

scale_factor=1/upscale_factor,

mode='bicubic',

align_corners = False

)

#将下采样后的图像上采样回原始尺寸

lr_data = F.interpolate(

lr_data,

scale_factor = upscale_factor,

mode = 'bicubic',

align_corners=False

)

#清零梯度

在pytorch 中,梯度是累加的,所以在每个批次开始时需要清零

optimizer.zero_grad()

#前向传播,降低分辨率图像输入模型,得到高分辨率输出

output = model(lr_data)

#计算损失,比较模型输出和原始高分辨率图像

loss = criterion(output, data)

#反向传播,计算梯度

loss.backward()

#更新权重,根据梯度调整模型参数

optimizer.step()

#累加当前批次的损失值

epoch_loss += loss.item()

#计算当前epoch的平均损失

losses.append(avg_loss)

#打印训练进度

if (epoch + 1) % 5 == 0:

#计算已用时间

elapsed_time = time.time() - start_time

#打印当前epoch,总epoch数,损失值和已用时间

#训练完后,绘制损失曲线

plt.figure(figsize=(10,5))

plt.plot(losses)

plt.title('Training loss over epochs')

plt.xlabel('Epoch')

plt.ylabel('Loss')

plt.grid(True)

#保存损失曲线图像

plt.savefig('training_loss.png')

plt.show()

#打印总训练时间

total_time = time.time() - start_time

4 测试函数

#定义模型测试函数

def test_model(model, test_loader, upscale_factor, num_examples=5):

测试模型并显示结果

参数:

model: 要测试的模型

test_loader: 测试数据加载器

upscale_factor: 放大倍数

num_examples: 要显示的示例数量

#设置模型为评估模式

#这会禁用dropout和batch normalization等训练特定行为

model.eval()

#初始化示例计数器

examples_shown = 0

#不计算梯度,节省内存和计算资源

with torch.no_grad():

#遍历测试数据加载器

for i, (data, target) in enumerate(test_loader):

#如果已经显示了足够多的示例,退出循环

if examples_shown >= num_examples:

break

#将数据移动到相应设备

data = data.to(device)

#创建低分辨率输入(与训练时间相同的方法)

lr_data = F.interpolate(

data,

scale_factor = 1/upscale_factor,

model='bicubic',

align_corners = False

)

lr_data = F.interpolate(

lr_data,

scale_factor = upscale_factor,

mode = 'bicubic',

align_corners = False

)

#生成高分辨率图像

hr_output = model(lr_data)

#将图像移回CPU并转换为numpy数组以便显示

lr_image = lr_data[0].cpu().sequeeze().numpy()

hr_image = hr_output[0.cpu().sequeeze().numpy()

original_image = data[0].cpu().squeeze().numpy()

#显示结果

plt.figure(figsize=(12,4))

#显示低分辨率输入图像

plt.subplot(1,3,1)

plt.imshow(lr_image, cmap='gray')

plt.title('Low Resolution Input')

plt.axis('off')

#显示超分辨率输出图像

plt.subplot(1,3,2)

plt.imshow(hr_iamge, cmap='gray')

plt.title('Super Resolution Output')

plt.axis('off')

#显示原始高分辨率图像

plt.subplot(1,3,3)

plt.imshow(original_image, cmap='gray')

plt.title('Original high Resolution')

plt.axis('off')

#保存对比图像

plt.savefig(f'comparsion_example_{examples_shown+1}.png')

plt.show()

#增加示例计数器

examples_shown += 1

5 计算PSNR指标函数

def calculate_psnr(model, test_loader, upscale_factor):

计算模型的峰值信噪比PSNR

参数model要评估的模型

test_loader :测试数据加载器

upscale_factor 放大倍数

返回:平均PSNR值

#设置模型为评估模式

moel.eval()

#初始化PSNR总和和样本计数

total_psnr=0.0

total_samples=0

#不计算梯度

with torch.no_grad():

#遍历测试数据加载器

for data, _ in test_loader:

#将数据移动到相应设备

data = data.to(device)

#创建低分辨率输入

lr_data = F.interpolate(

data,

scale_factor=1/upscale_factor,

mode = 'bicubic',

align_corners = False

)

lr_data = F.interpolate(

lr_data,

scale_factor = upscale_factor,

mode='bicubic',

align_corners=False

)

#生成高分辨率图像

hr_output=model(lr_data)

#计算每个样本PSNR

for i in range(data.size(0)):

#将张量转换为numpy数组

original=data[i].cpu().numpy()

reconstructed=hr_output[i].cpu().numpy()

#计算均方差误差(MSE)

mse = np.mean((original - reconstructed) ** 2)

#避免除以零

if mse == 0:

psnr = 100 #无穷大的PSNR,这里设为100

else :

#计算PSNR 20 log10(MAX) - 10 l o g 1 0 ( MSE)

#对于[0,1]范围的图像 MAX = 1

psnr=20 np.log10(1.0) - 10 n p . l o g 1 0 ( mse)

#累加PSNR

total_psnr += psnr

total_samples += 1

#计算平均PSNR

avg_psnr = total_psnr / total_samples

return avg_psnr

6 主函数

#定义主函数,组织整个训练和测试流程

def main():

#超参数设置

upscale_factor=2 #放大倍数

num_epochs = 20 #训练轮数

batch_szie = 64 #批处理大小

learning_rate = 0.001 学习率

#创建输出目录,如果不存在

if not os.path.exists('results'):

os.makedirs('results')

#准备数据

train_loader, test_loader = prepare_data(batch_size, upscale_factor)

#初始化模型

model = ESPCN(upscale_factor=upscale_factor).to(device)

#打印模型结构

print(model)

#计算模型参数数量

total_params = sum(p.numel() for p in model.parameters())

#定义损失函数 - 均方误差损失

criterion = nn.MSELoss()

#定义优化器,Adam优化器

optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

#训练模型

train_model(model, train_loader, criterion, optimizer, num_epochs, upscale_factor)

#测试模型

test_model(model, test_loader, upscale_factor)

#计算PSNR

calculate_psnr(model, test_loader, upscale_factor)

#保存模型

model_path='results/espcn_model.pth'

torch.save(model.state_dcit(), model_path)

if name=="main":

main()

头文件解释总结

1 torch:pytorch主库,提供张量操作和自动求导功能

2 torch.nn:pytorch神经网络模块,包含各种层和损失函数

3 torch.nn.functional:pytorch函数式接口,包含激活函数,损失函数等

4 torch.utils.data pytorch视觉库,提供常用数据集和图像变换

5 matplotlib.pyplot:会图库,用于可视化结果

6 torchvision pytorch视觉库,提供常用数据集和图像变换

7 numpy 数据计算库,用于处理数组数据

8 os 操作系统接口,用于处理文件和目录

9 time时间模块,用于测量运动时间

相关推荐
makerjack0011 小时前
Java中使用Spring Boot+Ollama实现本地AI的MCP接入
java·人工智能·spring boot
陈敬雷-充电了么-CEO兼CTO1 小时前
深度拆解判别式推荐大模型RankGPT!生成式精排落地提速94.8%,冷启动效果飙升,还解决了传统推荐3大痛点
大数据·人工智能·机器学习·chatgpt·大模型·推荐算法·agi
灰阳阳1 小时前
替身演员的艺术:pytest-mock 从入门到飙戏
自动化测试·python·pytest·unit testing·pytest-mock
stbomei1 小时前
生成式 AI 的 “魔法”:以 GPT 为例,拆解大语言模型(LLM)的训练与推理过程
人工智能
有才不一定有德1 小时前
多代理系统架构:Supervisor 与 Swarm 架构详解
人工智能·chatgpt·架构·系统架构
测试19981 小时前
单元测试到底是什么?该怎么做?
自动化测试·软件测试·python·测试工具·职场和发展·单元测试·测试用例
计算机sci论文精选3 小时前
CVPR 强化学习模块深度分析:连多项式不等式+自驾规划
人工智能·深度学习·机器学习·计算机视觉·机器人·强化学习·cvpr
华略创新4 小时前
用KPI导航数字化转型:制造企业如何科学评估系统上线成效
人工智能·制造·crm·管理系统·erp·软件·mes
嘀咕博客4 小时前
Komo Searc-AI驱动的搜索引擎
人工智能·搜索引擎·ai工具
小马过河R5 小时前
GPT-5原理
人工智能·gpt·深度学习·语言模型·embedding