【天文】星光超分辨图像增强

天文望远镜是天文学家的另一双眼睛,其空间分辨率和空间采样率是两个重要指标。传统光学望远镜的分辨率受到光的衍射极限的物理限制,而禁止人们通过望远镜获取无限细节。 STAR,这是一个大规模的天文SR数据集,包含54,738个通量一致的星场图像对,覆盖了广泛的天体区域。这些对将哈勃太空望远镜的高分辨率观测与通过保存通量的数据生成管道生成的物理忠实的低分辨率对应物相结合,从而能够系统地开发场级 ASR 模型。

超分辨率(SR)技术通过实现经济高效的高分辨率图像捕获,推动了天文成像的发展,这对于探测遥远的天体和进行精确的结构分析至关重要。然而,现有的天文超分辨率(ASR)数据集存在三个关键局限性:通量不一致、目标裁剪设置以及数据多样性不足,这严重阻碍了ASR的发展。

STAR:天文星场超分辨率的基准

STAR是一个大规模的天文超分辨率数据集,包含54,738对通量一致的星场图像对,覆盖了广阔的天区。这些图像对将哈勃太空望远镜的高分辨率观测数据与通过通量保留数据生成管道生成的物理上可信的低分辨率对应图像相结合,从而能够系统地开发场级ASR模型。为了进一步推动ASR领域的发展,STAR提供了一种新颖的通量误差(FE)指标,用于从物理角度评估SR模型。利用这个基准,提出了一种通量不变超分辨率(FISR)模型,该模型可以根据输入的光度信息准确推断出通量一致的高分辨率图像,在新设计的通量一致性指标上比几种SR最先进方法高出24.84%,显示在天体物理学中的优势。大量实验证明了提出的方法的有效性和数据集的价值。

复制代码
from .common import *
# import common
import time
import torch
# print(torch.__version__)
import torch.nn as nn
import torch.nn.functional as F
from pdb import set_trace as stx
import numbers


from einops.layers.torch import Rearrange
import time
from . import MODEL
from .base_model import Base_Model
from .model_init import *
from einops import repeat, rearrange
import torch.nn.functional as F

url = {
    'r16f64x2': 'https://cv.snu.ac.kr/research/EDSR/models/edsr_baseline_x2-1bc95232.pt',
    'r16f64x3': 'https://cv.snu.ac.kr/research/EDSR/models/edsr_baseline_x3-abf2a44e.pt',
    'r16f64x4': 'https://cv.snu.ac.kr/research/EDSR/models/edsr_baseline_x4-6b446fab.pt',
    'r32f256x2': 'https://cv.snu.ac.kr/research/EDSR/models/edsr_x2-0edfb8a3.pt',
    'r32f256x3': 'https://cv.snu.ac.kr/research/EDSR/models/edsr_x3-ea3ef2c6.pt',
    'r32f256x4': 'https://cv.snu.ac.kr/research/EDSR/models/edsr_x4-4f62e9ef.pt'
}

# def make_model(args, parent=False):
#     return EDSR(args)

@MODEL.register
class EDSR(Base_Model):
    def __init__(self,  
                
                n_resblocks=32,
                n_feats =64,
                scale= 2,
                res_scale = 0.1,
                n_colors=1,
                rgb_range = 255,
                **kwargs):
        super(EDSR, self).__init__(**kwargs)

        self.n_resblocks = n_resblocks
        self.n_feats = n_feats
        kernel_size = 3 
        self.scale = scale
        self.res_scale = res_scale
        self.n_colors = n_colors
        self.rgb_range = rgb_range
        conv=default_conv

        act = nn.ReLU(True)
        url_name = 'r{}f{}x{}'.format(self.n_resblocks, self.n_feats, self.scale)
        if url_name in url:
            self.url = url[url_name]
        else:
            self.url = None
        self.sub_mean = MeanShift(self.rgb_range)
        self.add_mean = MeanShift(self.rgb_range, sign=1)

        # define head module
        m_head = [conv(self.n_colors, self.n_feats, kernel_size)]

        # define body module
        m_body = [
            ResBlock(
                conv, self.n_feats, kernel_size, act=act, res_scale=self.res_scale
            ) for _ in range(self.n_resblocks)
        ]
        m_body.append(conv(self.n_feats, self.n_feats, kernel_size))

        # define tail module
        m_tail = [
            Upsampler(conv, self.scale, self.n_feats, act=False),
            conv(self.n_feats, self.n_colors, kernel_size)
        ]

        self.head = nn.Sequential(*m_head)
        self.body = nn.Sequential(*m_body)
        self.tail = nn.Sequential(*m_tail)

    def forward(self, x,targets):
        # x = self.sub_mean(x)
        x = self.head(x)

        res = self.body(x)
        res += x

        x = self.tail(res)
        pred_img = x
        # x = self.add_mean(x)
        if self.training:
            # 提取 targets 中的数据
            attn_map = targets['attn_map']
            mask_float = targets['mask']
            attn_map = torch.nan_to_num(attn_map, nan=0.0)
            # 计算 L1 损失
            l1_loss = (torch.abs(pred_img - targets['hr']) * mask_float).sum() / (mask_float.sum() + 1e-3)
            weighted_diff = torch.abs(pred_img - targets['hr']) * attn_map
            flux_loss = weighted_diff.sum() / (attn_map.sum() + 1e-3)
            total_loss = l1_loss + 0.01 * flux_loss
            losses = dict(l1_loss=l1_loss, flux_loss=0.01*flux_loss)
            return total_loss, losses
        else:
            return dict(pred_img = pred_img)
        # return x 

损失函数非常巧妙,它不是一个简单的L1或L2损失,而是一个复合损失函数 ,由两部分加权组成:一个带掩码(Masked)的L1损失 和一个带注意力权重(Attention-weighted)的L1损失

GuoCheng12/STAR: This repo is used for super-resolution in astronomy.

天文学家的修图技术,可比AI厉害多了

相关推荐
春末的南方城市5 小时前
苏大团队联合阿丘科技发表异常生成新方法:创新双分支训练法,同步攻克异常图像生成、分割及下游模型性能提升难题。
人工智能·科技·深度学习·计算机视觉·aigc
WeiJingYu.6 小时前
P3.7计算机视觉
人工智能·opencv·计算机视觉
春末的南方城市8 小时前
AI视频生成进入多镜头叙事时代!字节发布 Waver 1.:一句话生成 10 秒 1080p 多风格视频,创作轻松“一键”达!
人工智能·深度学习·机器学习·计算机视觉·aigc
春末的南方城市12 小时前
阿里开源视频修复方法Vivid-VR:以独特策略与架构革新,引领生成视频修复高质量可控新时代。
人工智能·深度学习·机器学习·计算机视觉·aigc
sali-tec15 小时前
C# 基于halcon的视觉工作流-章40-OCR训练识别
开发语言·图像处理·算法·计算机视觉·c#·ocr
CoovallyAIHub18 小时前
机器人“大脑”遭遇认知冻结攻击!复旦等提出FreezeVLA,一张图片即可瘫痪多模态大模型
深度学习·算法·计算机视觉
tirvideo19 小时前
RK3588芯片与板卡全面解析:旗舰级AIoT与边缘计算的核心
人工智能·嵌入式硬件·深度学习·目标检测·机器学习·计算机视觉·边缘计算
程序猿小D19 小时前
【完整源码+数据集+部署教程】【智慧工地监控】建筑工地设备分割系统: yolov8-seg-efficientViT
python·yolo·计算机视觉·数据集·yolov8·yolo11·建筑工地设备分割系统
红苕稀饭66621 小时前
RISE论文阅读
论文阅读·人工智能·计算机视觉
lovod1 天前
【视觉SLAM十四讲】视觉里程计 1
人工智能·线性代数·计算机视觉·矩阵·机器人