SSF-CNN:空间光谱融合的卷积光谱图像超分网络

SSF-CNN: SPATIAL AND SPECTRAL FUSION WITH CNN FOR HYPERSPECTRAL IMAGE SUPER-RESOLUTION

文章目录

简介

​ 本文提出了一种利用空间和光谱进行高光谱融合图像超分辨率的新型CNN架构,首先是对高光谱图像进行双三次插值 ,使其空间分辨率大小和多光谱一致,然后进行concat操作 。使用类似于SRCNN的网络框架对融合超分的图像进行优化,最后输出高分辨率高光谱超分图像。

​ 对于PDCon,也就是引入了部分密集连接,将输入concat到每一个卷积层后面。
Hyperspectral-Image-Super-Resolution-Benchmark------光谱图像超分基准-CSDN博客
Paper : IEEE
Codehttps://github.com/miraclefan777/SSFCNN

解决问题

  1. 传统方法通过基于优化的方法恢复 HR-HS 图像的质量在很大程度上取决于预定义的约束。此外,由于约束项数量较多,优化过程通常涉及较高的计算成本。
  2. 执行HSI SR的一个直接想法是直接应用这样的网络来放大LR-HS图像的空间维度或HR-RGB图像的光谱维度,我们称之为Spatial-CNN和Spectral-CNN,这两种单图像方法忽略了两种图像特有的信息互补优势。

网络框架

  1. 原始的SRCNN是将图片映射到Ycbcr空间,并只使用其中的 Y 分量作为输入来预测 HR Y 图像,该论文则是将图片的通道信息以及空间信息整个进行输入
  2. 原始SRCNN卷积核大小第1,2修改为3*3,增加上下文信息,同时为了避免高维数据(padding为same,保持和原有特征图大小一致)

代码实现

python 复制代码
class SSFCNNnet(nn.Module):
    def __init__(self, num_spectral=31, scale_factor=8, pdconv=False):
        super(SSFCNNnet, self).__init__()
        self.scale_factor = scale_factor
        self.pdconv = pdconv

        self.Upsample = nn.Upsample(mode='bicubic', scale_factor=self.scale_factor)

        self.conv1 = nn.Conv2d(num_spectral + 3, 64, kernel_size=3, padding="same")
        if pdconv:
            self.conv2 = nn.Conv2d(64 + 3, 32, kernel_size=3, padding="same")
            self.conv3 = nn.Conv2d(32 + 3, num_spectral, kernel_size=5, padding="same")
        else:
            self.conv2 = nn.Conv2d(64, 32, kernel_size=3, padding="same")
            self.conv3 = nn.Conv2d(32, num_spectral, kernel_size=5, padding="same")

        self.relu = nn.ReLU(inplace=True)

    def forward(self, lr_hs, hr_ms):
        """
            :param lr_hs:LR-HSI低分辨率的高光谱图像
            :param hr_ms:高分辨率的多光谱图像
            :return:
        """
        # 对LR-HSI低分辨率图像进行上采样,让其分辨率更高
        lr_hs_up = self.Upsample(lr_hs)
        # 将上采样后的LR-HSI低分辨率图像与高分辨率的多光谱图像进行拼接
        x = torch.cat((lr_hs_up, hr_ms), dim=1)

        x = self.relu(self.conv1(x))
        if self.pdconv:
            x = torch.cat((x, hr_ms), dim=1)
            x = self.relu(self.conv2(x))
            x = torch.cat((x, hr_ms), dim=1)
        else:
            x = self.relu(self.conv2(x))

        out = self.conv3(x)
        return out

如果需要使用密集连接,只需要在初始化网络模型时,传参pdconv=True

训练部分

未提供自定义dataset类,根据自己的dateset进行参数的修改即可。

python 复制代码
import argparse
from calculate_metrics import Loss_SAM, Loss_RMSE, Loss_PSNR
from models.SSFCNNnet import SSFCNNnet
from torch.utils.data.dataloader import DataLoader
from tqdm import tqdm
from train_dataloader import CAVEHSIDATAprocess
from utils import create_F, fspecial,AverageMeter
import os
import copy
import torch
import torch.nn as nn

if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--model', type=str, default="SSFCNNnet")
    parser.add_argument('--train-file', type=str, required=True)
    parser.add_argument('--eval-file', type=str, required=True)
    parser.add_argument('--outputs-dir', type=str, required=True)
    parser.add_argument('--scale', type=int, default=2)
    parser.add_argument('--lr', type=float, default=1e-4)
    parser.add_argument('--batch-size', type=int, default=32)
    parser.add_argument('--num-workers', type=int, default=0)
    parser.add_argument('--num-epochs', type=int, default=400)
    parser.add_argument('--seed', type=int, default=123)
    args = parser.parse_args()

    assert args.model in ['SSFCNNnet', 'PDcon_SSF']

    outputs_dir = os.path.join(args.outputs_dir, '{}'.format(args.model))
    if not os.path.exists(outputs_dir):
        os.makedirs(outputs_dir)

    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

    torch.manual_seed(args.seed)

    # 训练参数
    # loss_func = nn.L1Loss(reduction='mean').cuda()
    criterion = nn.MSELoss()


    #################数据集处理#################
    R = create_F()
    PSF = fspecial('gaussian', 8, 3)
    downsample_factor = 8
    training_size = 64
    stride = 32
    stride1 = 32

    train_dataset = CAVEHSIDATAprocess(args.train_file, R, training_size, stride, downsample_factor, PSF, 20)
    train_dataloader = DataLoader(dataset=train_dataset, batch_size=args.batch_size, shuffle=True)

    eval_dataset = CAVEHSIDATAprocess(args.eval_file, R, training_size, stride, downsample_factor, PSF, 12)
    eval_dataloader = DataLoader(dataset=eval_dataset, batch_size=1)

    #################数据集处理#################

    # 模型
    if args.model == 'SSFCNNnet':
        model = SSFCNNnet().cuda()
    else:
        model = SSFCNNnet(pdconv=True).cuda()

    best_weights = copy.deepcopy(model.state_dict())
    best_epoch = 0
    best_psnr = 0.0

    # 模型初始化
    for m in model.modules():
        if isinstance(m, (nn.Conv2d, nn.Linear)):
            nn.init.xavier_uniform_(m.weight)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    optimizer = torch.optim.Adam([
        {'params': model.conv1.parameters()},
        {'params': model.conv2.parameters()},
        {'params': model.conv3.parameters(), 'lr': args.lr * 0.1}
    ], lr=args.lr)

    start_epoch = 0


    for epoch in range(start_epoch, args.num_epochs):
        model.train()
        epoch_losses = AverageMeter()

        with tqdm(total=(len(train_dataset) - len(train_dataset) % args.batch_size)) as t:
            t.set_description('epoch:{}/{}'.format(epoch, args.num_epochs - 1))

            for data in train_dataloader:
                label, lr_hs, hr_ms = data

                label = label.to(device)
                lr_hs = lr_hs.to(device)
                hr_ms = hr_ms.to(device)
                lr = optimizer.param_groups[0]['lr']
                pred = model(hr_ms, lr_hs)
                loss = criterion(pred, label)

                epoch_losses.update(loss.item(), len(label))

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                t.set_postfix(loss='{:.6f}'.format(epoch_losses.avg), lr='{0:1.8f}'.format(lr))
                t.update(len(label))

        # torch.save(model.state_dict(), os.path.join(outputs_dir, 'epoch_{}.pth'.format(epoch)))

        if epoch % 5 == 0:
            model.eval()
            val_loss = AverageMeter()

            SAM = Loss_SAM()
            RMSE = Loss_RMSE()
            PSNR = Loss_PSNR()

            sam = AverageMeter()
            rmse = AverageMeter()
            psnr = AverageMeter()

            for data in eval_dataloader:
                label, lr_hs, hr_ms = data
                lr_hs = lr_hs.to(device)
                hr_ms = hr_ms.to(device)
                label = label.cpu().numpy()

                with torch.no_grad():
                    preds = model(hr_ms, lr_hs).cpu().numpy()

                sam.update(SAM(preds, label), len(label))
                rmse.update(RMSE(preds, label), len(label))
                psnr.update(PSNR(preds, label), len(label))

            if psnr.avg > best_psnr:
                best_epoch = epoch
                best_psnr = psnr.avg
                best_weights = copy.deepcopy(model.state_dict())

            print('eval psnr: {:.2f}  RMSE: {:.2f}  SAM: {:.2f} '.format(psnr.avg, rmse.avg, sam.avg))

运行结果

相关推荐
小陈phd2 小时前
OpenCV从入门到精通实战(九)——基于dlib的疲劳监测 ear计算
人工智能·opencv·计算机视觉
Guofu_Liao3 小时前
大语言模型---LoRA简介;LoRA的优势;LoRA训练步骤;总结
人工智能·语言模型·自然语言处理·矩阵·llama
秀儿还能再秀6 小时前
神经网络(系统性学习三):多层感知机(MLP)
神经网络·学习笔记·mlp·多层感知机
ZHOU_WUYI7 小时前
3.langchain中的prompt模板 (few shot examples in chat models)
人工智能·langchain·prompt
如若1237 小时前
主要用于图像的颜色提取、替换以及区域修改
人工智能·opencv·计算机视觉
老艾的AI世界8 小时前
AI翻唱神器,一键用你喜欢的歌手翻唱他人的曲目(附下载链接)
人工智能·深度学习·神经网络·机器学习·ai·ai翻唱·ai唱歌·ai歌曲
DK221518 小时前
机器学习系列----关联分析
人工智能·机器学习
Robot2518 小时前
Figure 02迎重大升级!!人形机器人独角兽[Figure AI]商业化加速
人工智能·机器人·微信公众平台
浊酒南街9 小时前
Statsmodels之OLS回归
人工智能·数据挖掘·回归
畅联云平台9 小时前
美畅物联丨智能分析,安全管控:视频汇聚平台助力智慧工地建设
人工智能·物联网