MambaIRv2: Tiled Inference Script for Large-Image Super-Resolution

Super-resolving a large image in one pass can easily exhaust GPU memory. Instead, you can estimate the available VRAM, pick a suitable tile size, split the image into tiles, run them through the model, and stitch the outputs back together. To suppress the visible seam lines that tend to appear along tile borders, the overlapping regions are blended with linearly graded weights.
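
The seam removal boils down to weighted blending in the overlap. As a quick illustration (a minimal 1-D sketch with toy values, separate from the actual script below), each tile's contribution is weighted by a ramp that falls off towards its edges; accumulating weighted values and dividing by the accumulated weights gives a smooth transition across the overlap instead of a hard seam:

python
import numpy as np

tile_size, overlap = 8, 4
stride = tile_size - overlap

# Two neighbouring 1-D "tiles" with different constant values, so a hard seam would be obvious
a = np.full(tile_size, 10.0)
b = np.full(tile_size, 20.0)

# Linear ramp: weight grows from each tile edge towards its centre
ramp = np.ones(tile_size)
for i in range(overlap):
    w = (i + 1) / (overlap + 1)  # strictly positive so border samples keep a contribution
    ramp[i] *= w
    ramp[-(i + 1)] *= w

length = stride + tile_size
acc = np.zeros(length)
wsum = np.zeros(length)
for start, tile in ((0, a), (stride, b)):
    acc[start:start + tile_size] += tile * ramp
    wsum[start:start + tile_size] += ramp

blended = acc / np.maximum(wsum, 1e-8)
print(blended)  # ramps smoothly from 10 to 20 across the 4-sample overlap

The full script below applies the same idea in 2-D with a per-tile weight mask.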

Inference script

This post provides a helper inference script for the MambaIRv2 model that super-resolves large images tile by tile, so that readers with limited GPU memory can run inference on their own local images.
MambaIR source code repository
Mamba environment installation

Taking the Base SR model (buildMambaIRv2Base) as an example:

python
import torch
import os
import numpy as np
import cv2
from tqdm import tqdm
from typing import Tuple
from analysis.model_zoo.mambairv2 import buildMambaIRv2Base
from basicsr.utils import img2tensor, tensor2img

def split_image(img: np.ndarray, tile_size: int = 512, overlap: int = 64) -> Tuple[list, list]:
    """
    将大图像分割成重叠的小块
    
    Args:
        img: 输入图像 (H, W, C)
        tile_size: 分块大小
        overlap: 重叠区域大小
    
    Returns:
        tiles: 分块列表
        positions: 每个分块在原图中的位置 (x_start, y_start, x_end, y_end)
    """
    h, w = img.shape[:2]
    tiles = []
    positions = []
    
    stride = tile_size - overlap
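    # With this stride, neighbouring tiles overlap by `overlap` pixels, which merge_tiles later blends over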
 
    # Number of tiles needed along each axis (ceiling division over the stride)
    y_steps = max(1, (h - overlap) // stride + (1 if (h - overlap) % stride > 0 else 0))
    x_steps = max(1, (w - overlap) // stride + (1 if (w - overlap) % stride > 0 else 0))
    
    for y in range(y_steps):
        for x in range(x_steps):

            y_start = max(0, y * stride)  # top-left corner of this tile
            x_start = max(0, x * stride)
            y_end = min(h, y_start + tile_size)
            x_end = min(w, x_start + tile_size)
            
            if y_end - y_start < tile_size:  # shift the start so edge tiles keep the full tile size
                y_start = max(0, y_end - tile_size)
            if x_end - x_start < tile_size:
                x_start = max(0, x_end - tile_size)
            
            tile = img[y_start:y_end, x_start:x_end]  # extract the tile
            tiles.append(tile)
            positions.append((x_start, y_start, x_end, y_end))
    
    return tiles, positions

def merge_tiles(tiles: list, positions: list, original_shape: Tuple[int, int], 
                overlap: int = 64) -> np.ndarray:
    """
    将处理后的分块合并回原图,使用加权融合消除接缝
    
    Args:
        tiles: 处理后的分块列表
        positions: 每个分块在原图中的位置
        original_shape: 原图尺寸 (H, W, C)
        overlap: 重叠区域大小
    
    Returns:
        merged_img: 合并后的图像
    """
    h, w = original_shape[:2]
    c = tiles[0].shape[2] if len(tiles[0].shape) == 3 else 1
    
    if c == 1:
        result = np.zeros((h, w), dtype=np.float32)
        weight = np.zeros((h, w), dtype=np.float32)
    else:
        result = np.zeros((h, w, c), dtype=np.float32)
        weight = np.zeros((h, w, c), dtype=np.float32)
    
    def create_weight_mask(height: int, width: int, overlap: int) -> np.ndarray:
        # Weight mask: high in the centre, ramping down towards the edges
        mask = np.ones((height, width), dtype=np.float32)

        # Linear gradient over the overlap region; (i + 1) / (overlap + 1) stays strictly
        # positive so pixels on the outer border of the image are not zeroed out
        if overlap > 0:
            for i in range(overlap):
                weight_val = (i + 1) / (overlap + 1)
                mask[i, :] *= weight_val
                mask[-(i+1), :] *= weight_val
                mask[:, i] *= weight_val
                mask[:, -(i+1)] *= weight_val
        return mask
    
    for tile, (x_start, y_start, x_end, y_end) in zip(tiles, positions):
        tile_h, tile_w = tile.shape[:2]

        weight_mask = create_weight_mask(tile_h, tile_w, overlap)

        if c > 1:
            weight_mask = np.repeat(weight_mask[:, :, np.newaxis], c, axis=2)

        # Accumulate the weighted tile and its weights into the output buffers
        result[y_start:y_start + tile_h, x_start:x_start + tile_w] += tile * weight_mask
        weight[y_start:y_start + tile_h, x_start:x_start + tile_w] += weight_mask
    
    weight[weight == 0] = 1  # avoid division by zero where no tile contributed
    merged_img = result / weight
    merged_img = np.clip(merged_img * 255, 0, 255).astype(np.uint8)
    
    return merged_img

def process_large_image(model, img: np.ndarray, tile_size: int = 512, 
                        overlap: int = 64, batch_size: int = 1) -> np.ndarray:
    """
    处理大图像的分块推理
    
    Args:
        model: 推理模型
        img: 输入图像 (H, W, C) [0-255]
        tile_size: 分块大小
        overlap: 重叠区域大小
        batch_size: 批处理大小
    
    Returns:
        output_img: 超分辨率后的图像
    """

    img_float = img.astype(np.float32) / 255.0
    tiles, positions = split_image(img_float, tile_size, overlap)  # split into overlapping tiles

    print(f"Image split into {len(tiles)} tiles")
    print(f"Tile shape: {tiles[0].shape}")
    
    processed_tiles = []
    
    with torch.no_grad():
        # Show processing progress with tqdm
        for i in tqdm(range(0, len(tiles), batch_size), desc="Processing tiles", unit="batch"):
            batch_tiles = tiles[i:i+batch_size]
            
            batch_tensors = []
            for tile in batch_tiles:
                # Convert to a tensor (BGR -> RGB, CHW)
                tensor = img2tensor(tile, bgr2rgb=True, float32=True)
                tensor = tensor.unsqueeze(0).cuda()
                batch_tensors.append(tensor)
            
            batch_tensor = torch.cat(batch_tensors, dim=0)
            outputs = model(batch_tensor)
            
            for j in range(outputs.shape[0]):
                output_tensor = outputs[j:j+1]
                output_img = tensor2img([output_tensor.detach().cpu()])
                processed_tiles.append(output_img)
    
    upscale = 4  # set according to the model's scale factor
    original_h, original_w = img.shape[:2]
    sr_h, sr_w = original_h * upscale, original_w * upscale
    
    sr_positions = []
    for (x_start, y_start, x_end, y_end) in positions:
        sr_positions.append((
            x_start * upscale, 
            y_start * upscale, 
            x_end * upscale, 
            y_end * upscale
        ))
    
    sr_tiles = []
    for tile in processed_tiles:
        sr_tiles.append(tile.astype(np.float32) / 255.0)
    
    print("合并分块...")
    merged_img = merge_tiles(sr_tiles, sr_positions, (sr_h, sr_w, 3), overlap * upscale)
    
    return merged_img

if __name__ == "__main__":
    model = buildMambaIRv2Base(upscale=4).cuda()
    ckpt_path = './experiments/pretrained_models/mambairv2_classicSR_Base_x4.pth'
    upload_folder = '/root/dn_sr_pics/test_pics/'
    result_folder = './results'
    
    print("加载模型权重...")
    checkpoint = torch.load(ckpt_path)
    if 'params' in checkpoint:
        model.load_state_dict(checkpoint['params'])
    else:
        model.load_state_dict(checkpoint)
    model.eval()
    
    os.makedirs(result_folder, exist_ok=True)
    
    file_list = [f for f in os.listdir(upload_folder) 
                 if f.lower().endswith(('.png', '.jpg', '.jpeg', '.bmp', '.tiff'))]
    
    if not file_list:
        print(f"No image files found in {upload_folder}")
    else:
        print(f"Found {len(file_list)} image files")
    
    for file_idx, file_name in enumerate(file_list, 1):
        file_path = os.path.join(upload_folder, file_name)
        print(f"\n[{file_idx}/{len(file_list)}] 处理: {file_name}")
        
        img = cv2.imread(file_path)
        if img is None:
            print(f"  无法读取图像: {file_path}")
            continue
        
        h, w = img.shape[:2]
        print(f"  原始图像尺寸: {w}x{h}")
        
        # Pick tile size and overlap based on total GPU memory
        gpu_memory = torch.cuda.get_device_properties(0).total_memory / 1024**3  # GB

        if gpu_memory < 8:  # less than 8 GB of VRAM
            tile_size = 256
            overlap = 32
        elif gpu_memory < 12:  # less than 12 GB of VRAM
            tile_size = 384
            overlap = 48
        else:  # 12 GB of VRAM or more
            tile_size = 512
            overlap = 64
        
        max_dim = max(h, w)
        if max_dim > tile_size:
            print(f"  Large image, using tiled processing (tile_size={tile_size}, overlap={overlap})")

            # Tiled inference
            output_img = process_large_image(
                model, img,
                tile_size=tile_size,
                overlap=overlap,
                batch_size=1  # adjust according to available VRAM
            )
        else:
            print("  Small image, processing in one pass")
            
            img_float = img.astype(np.float32) / 255.
            img_tensor = img2tensor(img_float, bgr2rgb=True, float32=True)
            img_tensor = img_tensor.unsqueeze(0).cuda()
            
            with torch.no_grad():
                output = model(img_tensor)
            
            output_img = tensor2img([output.detach().cpu()])
        
        save_path = os.path.join(result_folder, file_name)
        cv2.imwrite(save_path, output_img)
        print(f"  结果保存到: {save_path}")
        
        torch.cuda.empty_cache()
    
    print("\n所有图像处理完成!")

MambaIRv2 x4 super-resolution results:
