Pytorch图像去噪实战(四):Attention UNet图像去噪实战,让模型重点恢复边缘和纹理区域

Pytorch图像去噪实战(四):Attention UNet图像去噪实战,让模型重点恢复边缘和纹理区域


一、问题场景:模型降噪了,但重点区域恢复不够好

在实际图像去噪项目中,我发现一个很常见的问题:

模型整体去噪效果不错,但关键区域恢复不够好。

比如:

  • OCR图片中文字边缘变虚
  • 人脸图像中眼睛、头发纹理被抹平
  • 医学影像中细小结构不清晰
  • 老照片修复中衣服纹理丢失

这类问题不是简单提高模型层数就能解决的。

我一开始尝试加深 UNet、换 ResUNet,但提升有限。

后来发现根因是:

模型没有显式区分"哪里重要"。

普通卷积网络会平等处理整张图像,但图像去噪里,不同区域的重要性是不一样的。

因此这一篇我们引入:Attention UNet


二、为什么图像去噪需要注意力机制?

图像里不同区域的信息价值不同:

平坦区域

比如天空、墙面、背景。

这类区域主要目标是去掉噪声,保持平滑。

边缘区域

比如文字边缘、物体轮廓。

这类区域既要降噪,又不能模糊。

纹理区域

比如头发、布料、医学细节。

这类区域最容易被模型误认为噪声。

普通UNet通过 skip connection 传递浅层信息,但它不会判断哪些浅层信息更重要。

Attention模块的作用就是:

给重要区域更高权重,给无关区域更低权重。


三、Attention UNet核心思想

Attention UNet不是把注意力加在所有地方,而是通常加在 skip connection 上。

普通UNet:

text 复制代码
encoder feature -> concat -> decoder

Attention UNet:

text 复制代码
encoder feature -> attention gate -> concat -> decoder

这样做的好处是:

  • 减少无关噪声特征传递
  • 强化边缘和结构信息
  • 提升细节恢复能力

四、工程目录结构

复制代码
attention_unet_denoise/
├── data/
│   ├── train/
│   └── val/
├── models/
│   └── attention_unet.py
├── dataset.py
├── train.py
├── eval.py
└── utils.py

五、Attention Gate模块实现

Attention Gate的输入一般有两个:

  • x:encoder传来的浅层特征
  • g:decoder当前的引导特征

核心流程:

  1. 对 x 和 g 做通道映射
  2. 相加后经过激活
  3. 生成注意力权重
  4. 用权重重新加权 x

代码如下:

python 复制代码
import torch
import torch.nn as nn


class AttentionGate(nn.Module):
    def __init__(self, x_channels, g_channels, inter_channels):
        super().__init__()

        self.theta_x = nn.Conv2d(x_channels, inter_channels, kernel_size=1)
        self.phi_g = nn.Conv2d(g_channels, inter_channels, kernel_size=1)

        self.psi = nn.Sequential(
            nn.Conv2d(inter_channels, 1, kernel_size=1),
            nn.Sigmoid()
        )

        self.relu = nn.ReLU(inplace=True)

    def forward(self, x, g):
        theta_x = self.theta_x(x)
        phi_g = self.phi_g(g)

        attention = self.relu(theta_x + phi_g)
        attention = self.psi(attention)

        return x * attention

六、完整Attention UNet模型代码

python 复制代码
import torch
import torch.nn as nn


class ConvBlock(nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()

        self.block = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, 3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),

            nn.Conv2d(out_channels, out_channels, 3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        return self.block(x)


class AttentionGate(nn.Module):
    def __init__(self, x_channels, g_channels, inter_channels):
        super().__init__()

        self.theta_x = nn.Conv2d(x_channels, inter_channels, 1)
        self.phi_g = nn.Conv2d(g_channels, inter_channels, 1)

        self.psi = nn.Sequential(
            nn.Conv2d(inter_channels, 1, 1),
            nn.Sigmoid()
        )

        self.relu = nn.ReLU(inplace=True)

    def forward(self, x, g):
        attn = self.relu(self.theta_x(x) + self.phi_g(g))
        attn = self.psi(attn)
        return x * attn


class AttentionUNetDenoise(nn.Module):
    def __init__(self, in_channels=1, out_channels=1):
        super().__init__()

        self.pool = nn.MaxPool2d(2)

        self.enc1 = ConvBlock(in_channels, 64)
        self.enc2 = ConvBlock(64, 128)
        self.enc3 = ConvBlock(128, 256)

        self.bottleneck = ConvBlock(256, 512)

        self.up3 = nn.ConvTranspose2d(512, 256, 2, 2)
        self.att3 = AttentionGate(256, 256, 128)
        self.dec3 = ConvBlock(512, 256)

        self.up2 = nn.ConvTranspose2d(256, 128, 2, 2)
        self.att2 = AttentionGate(128, 128, 64)
        self.dec2 = ConvBlock(256, 128)

        self.up1 = nn.ConvTranspose2d(128, 64, 2, 2)
        self.att1 = AttentionGate(64, 64, 32)
        self.dec1 = ConvBlock(128, 64)

        self.out = nn.Conv2d(64, out_channels, 1)

    def forward(self, x):
        e1 = self.enc1(x)
        e2 = self.enc2(self.pool(e1))
        e3 = self.enc3(self.pool(e2))

        b = self.bottleneck(self.pool(e3))

        d3 = self.up3(b)
        e3_att = self.att3(e3, d3)
        d3 = torch.cat([d3, e3_att], dim=1)
        d3 = self.dec3(d3)

        d2 = self.up2(d3)
        e2_att = self.att2(e2, d2)
        d2 = torch.cat([d2, e2_att], dim=1)
        d2 = self.dec2(d2)

        d1 = self.up1(d2)
        e1_att = self.att1(e1, d1)
        d1 = torch.cat([d1, e1_att], dim=1)
        d1 = self.dec1(d1)

        return self.out(d1)

七、训练代码

python 复制代码
import torch
from torch.utils.data import DataLoader
from dataset import DenoiseDataset
from models.attention_unet import AttentionUNetDenoise


def train():
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    dataset = DenoiseDataset("data/train")
    loader = DataLoader(dataset, batch_size=6, shuffle=True, num_workers=4)

    model = AttentionUNetDenoise().to(device)

    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4, weight_decay=1e-4)
    l1_loss = torch.nn.L1Loss()

    for epoch in range(1, 61):
        model.train()
        total_loss = 0

        for noisy, clean in loader:
            noisy = noisy.to(device)
            clean = clean.to(device)

            pred = model(noisy)
            loss = l1_loss(pred, clean)

            optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()

            total_loss += loss.item()

        avg_loss = total_loss / len(loader)
        print(f"Epoch {epoch}, Loss: {avg_loss:.6f}")

        if epoch % 10 == 0:
            torch.save(model.state_dict(), f"attention_unet_epoch_{epoch}.pth")


if __name__ == "__main__":
    train()

八、数据集代码

python 复制代码
import os
import random
import torch
from PIL import Image
from torch.utils.data import Dataset
import torchvision.transforms as transforms


class DenoiseDataset(Dataset):
    def __init__(self, root_dir, patch_size=128):
        self.paths = [
            os.path.join(root_dir, name)
            for name in os.listdir(root_dir)
            if name.lower().endswith((".jpg", ".jpeg", ".png"))
        ]

        self.patch_size = patch_size
        self.to_tensor = transforms.ToTensor()

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, idx):
        img = Image.open(self.paths[idx]).convert("L")

        w, h = img.size
        if w >= self.patch_size and h >= self.patch_size:
            x = random.randint(0, w - self.patch_size)
            y = random.randint(0, h - self.patch_size)
            img = img.crop((x, y, x + self.patch_size, y + self.patch_size))
        else:
            img = img.resize((self.patch_size, self.patch_size))

        clean = self.to_tensor(img)

        sigma = random.choice([10, 15, 25, 35, 50])
        noise = torch.randn_like(clean) * sigma / 255.0

        noisy = torch.clamp(clean + noise, 0.0, 1.0)

        return noisy, clean

九、为什么Attention加在Skip Connection上?

这是很多人容易写错的地方。

Attention不是随便加在哪里都有效。

在图像恢复任务里,skip connection 传递的是浅层纹理信息,同时也可能传递噪声信息。

如果直接 concat:

python 复制代码
torch.cat([decoder_feature, encoder_feature], dim=1)

模型会把所有浅层信息都拿来用,包括噪声。

加上 Attention Gate 后,模型会先筛选:

python 复制代码
encoder_feature -> attention gate -> useful feature

这样可以减少噪声特征污染 decoder。


十、效果验证

实际实验中,Attention UNet相比普通UNet,提升主要体现在:

  • 文字边缘更清楚
  • 纹理区域更自然
  • 背景噪声控制更稳定
  • 过度平滑问题减轻

但注意:

Attention UNet不一定让PSNR大幅提升,但肉眼效果往往更好。

这是图像恢复任务里非常常见的现象。


十一、踩坑记录

坑1:Attention尺寸不一致

Attention Gate中的 x 和 g 尺寸必须一致。

如果不一致,可以使用插值:

python 复制代码
g = torch.nn.functional.interpolate(g, size=x.shape[2:], mode="bilinear", align_corners=False)

坑2:显存占用增加

Attention模块会增加计算量。

解决方式:

  • batch_size 减小
  • patch_size 从 128 开始
  • 不要一开始就用 256 或 512

坑3:注意力权重过强导致细节消失

如果 attention 过度抑制浅层信息,反而会丢细节。

可以改成残差形式:

python 复制代码
return x * attention + x

这在一些数据集上更稳。


十二、适合收藏总结

Attention UNet适合什么场景?

  • 文字图像去噪
  • 人脸图像去噪
  • 医学图像去噪
  • 细节区域很重要的图像恢复任务

避坑清单

  • Attention最好加在skip connection上
  • x和g尺寸必须一致
  • batch_size要控制
  • 不要盲目堆注意力模块
  • PSNR不是唯一标准,要看视觉效果

十三、优化建议

可以继续尝试:

  • Attention + ResUNet
  • SE注意力
  • CBAM注意力
  • 多尺度Attention
  • Transformer Attention

结尾总结

Attention UNet的核心价值不是"让模型更复杂",而是:

让模型知道图像中哪些区域更值得恢复。

在真实工程里,图像去噪并不是追求整张图平均变干净,而是要让关键区域更清楚。

Attention UNet正是朝这个方向迈出的一步。


下一篇预告

Pytorch图像去噪实战(五):FFDNet实战,用噪声图控制不同强度的去噪效果

相关推荐
颜酱1 小时前
LLM为核,上下文为限:拆解AI Agent生态的底层逻辑
前端·人工智能
2401_833033621 小时前
如何修复固定定位头部容器中悬浮下拉菜单的错位问题
jvm·数据库·python
熊猫钓鱼>_>1 小时前
当“虾”遇上“马”:QClaw 融合 Hermes 背后的智能体进化论
人工智能·ai·腾讯云·agent·openclaw·qclaw·hermes
深念Y1 小时前
Denuvo加密被全面攻破?聊聊D加密原理和这次的破解事件
人工智能·游戏·ai·逆向·虚拟机·虚拟·d加密
KKKlucifer1 小时前
日志审计与行为分析在安全服务中的应用实践
网络·人工智能·安全
SelectDB1 小时前
Doris & SelectDB for AI 实战:从基础 RAG 到知识图谱增强的完整实现
数据库·人工智能·数据分析
Agent产品评测局1 小时前
生产排期与MES/ERP系统打通,实操方法详解:2026企业级智能体与超自动化集成实战指南
运维·人工智能·ai·chatgpt·自动化
GitCode官方1 小时前
一声唤醒 万物响应|AtomGit 首款开源鸿蒙 AI 硬件「小鸿」发布会圆满落幕 定义智能交互新入口
人工智能·开源·harmonyos
互联网志1 小时前
打通转化通道 赋能产业发展——高校科技成果转化的现状与破局
大数据·人工智能·物联网