Pytorch图像去噪实战(四):Attention UNet图像去噪实战,让模型重点恢复边缘和纹理区域
一、问题场景:模型降噪了,但重点区域恢复不够好
在实际图像去噪项目中,我发现一个很常见的问题:
模型整体去噪效果不错,但关键区域恢复不够好。
比如:
- OCR图片中文字边缘变虚
- 人脸图像中眼睛、头发纹理被抹平
- 医学影像中细小结构不清晰
- 老照片修复中衣服纹理丢失
这类问题不是简单提高模型层数就能解决的。
我一开始尝试加深 UNet、换 ResUNet,但提升有限。
后来发现根因是:
模型没有显式区分"哪里重要"。
普通卷积网络会平等处理整张图像,但图像去噪里,不同区域的重要性是不一样的。
因此这一篇我们引入:Attention UNet。
二、为什么图像去噪需要注意力机制?
图像里不同区域的信息价值不同:
平坦区域
比如天空、墙面、背景。
这类区域主要目标是去掉噪声,保持平滑。
边缘区域
比如文字边缘、物体轮廓。
这类区域既要降噪,又不能模糊。
纹理区域
比如头发、布料、医学细节。
这类区域最容易被模型误认为噪声。
普通UNet通过 skip connection 传递浅层信息,但它不会判断哪些浅层信息更重要。
Attention模块的作用就是:
给重要区域更高权重,给无关区域更低权重。
三、Attention UNet核心思想
Attention UNet不是把注意力加在所有地方,而是通常加在 skip connection 上。
普通UNet:
text
encoder feature -> concat -> decoder
Attention UNet:
text
encoder feature -> attention gate -> concat -> decoder
这样做的好处是:
- 减少无关噪声特征传递
- 强化边缘和结构信息
- 提升细节恢复能力
四、工程目录结构
attention_unet_denoise/
├── data/
│ ├── train/
│ └── val/
├── models/
│ └── attention_unet.py
├── dataset.py
├── train.py
├── eval.py
└── utils.py
五、Attention Gate模块实现
Attention Gate的输入一般有两个:
- x:encoder传来的浅层特征
- g:decoder当前的引导特征
核心流程:
- 对 x 和 g 做通道映射
- 相加后经过激活
- 生成注意力权重
- 用权重重新加权 x
代码如下:
python
import torch
import torch.nn as nn
class AttentionGate(nn.Module):
def __init__(self, x_channels, g_channels, inter_channels):
super().__init__()
self.theta_x = nn.Conv2d(x_channels, inter_channels, kernel_size=1)
self.phi_g = nn.Conv2d(g_channels, inter_channels, kernel_size=1)
self.psi = nn.Sequential(
nn.Conv2d(inter_channels, 1, kernel_size=1),
nn.Sigmoid()
)
self.relu = nn.ReLU(inplace=True)
def forward(self, x, g):
theta_x = self.theta_x(x)
phi_g = self.phi_g(g)
attention = self.relu(theta_x + phi_g)
attention = self.psi(attention)
return x * attention
六、完整Attention UNet模型代码
python
import torch
import torch.nn as nn
class ConvBlock(nn.Module):
def __init__(self, in_channels, out_channels):
super().__init__()
self.block = nn.Sequential(
nn.Conv2d(in_channels, out_channels, 3, padding=1),
nn.BatchNorm2d(out_channels),
nn.ReLU(inplace=True),
nn.Conv2d(out_channels, out_channels, 3, padding=1),
nn.BatchNorm2d(out_channels),
nn.ReLU(inplace=True)
)
def forward(self, x):
return self.block(x)
class AttentionGate(nn.Module):
def __init__(self, x_channels, g_channels, inter_channels):
super().__init__()
self.theta_x = nn.Conv2d(x_channels, inter_channels, 1)
self.phi_g = nn.Conv2d(g_channels, inter_channels, 1)
self.psi = nn.Sequential(
nn.Conv2d(inter_channels, 1, 1),
nn.Sigmoid()
)
self.relu = nn.ReLU(inplace=True)
def forward(self, x, g):
attn = self.relu(self.theta_x(x) + self.phi_g(g))
attn = self.psi(attn)
return x * attn
class AttentionUNetDenoise(nn.Module):
def __init__(self, in_channels=1, out_channels=1):
super().__init__()
self.pool = nn.MaxPool2d(2)
self.enc1 = ConvBlock(in_channels, 64)
self.enc2 = ConvBlock(64, 128)
self.enc3 = ConvBlock(128, 256)
self.bottleneck = ConvBlock(256, 512)
self.up3 = nn.ConvTranspose2d(512, 256, 2, 2)
self.att3 = AttentionGate(256, 256, 128)
self.dec3 = ConvBlock(512, 256)
self.up2 = nn.ConvTranspose2d(256, 128, 2, 2)
self.att2 = AttentionGate(128, 128, 64)
self.dec2 = ConvBlock(256, 128)
self.up1 = nn.ConvTranspose2d(128, 64, 2, 2)
self.att1 = AttentionGate(64, 64, 32)
self.dec1 = ConvBlock(128, 64)
self.out = nn.Conv2d(64, out_channels, 1)
def forward(self, x):
e1 = self.enc1(x)
e2 = self.enc2(self.pool(e1))
e3 = self.enc3(self.pool(e2))
b = self.bottleneck(self.pool(e3))
d3 = self.up3(b)
e3_att = self.att3(e3, d3)
d3 = torch.cat([d3, e3_att], dim=1)
d3 = self.dec3(d3)
d2 = self.up2(d3)
e2_att = self.att2(e2, d2)
d2 = torch.cat([d2, e2_att], dim=1)
d2 = self.dec2(d2)
d1 = self.up1(d2)
e1_att = self.att1(e1, d1)
d1 = torch.cat([d1, e1_att], dim=1)
d1 = self.dec1(d1)
return self.out(d1)
七、训练代码
python
import torch
from torch.utils.data import DataLoader
from dataset import DenoiseDataset
from models.attention_unet import AttentionUNetDenoise
def train():
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
dataset = DenoiseDataset("data/train")
loader = DataLoader(dataset, batch_size=6, shuffle=True, num_workers=4)
model = AttentionUNetDenoise().to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4, weight_decay=1e-4)
l1_loss = torch.nn.L1Loss()
for epoch in range(1, 61):
model.train()
total_loss = 0
for noisy, clean in loader:
noisy = noisy.to(device)
clean = clean.to(device)
pred = model(noisy)
loss = l1_loss(pred, clean)
optimizer.zero_grad()
loss.backward()
torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
optimizer.step()
total_loss += loss.item()
avg_loss = total_loss / len(loader)
print(f"Epoch {epoch}, Loss: {avg_loss:.6f}")
if epoch % 10 == 0:
torch.save(model.state_dict(), f"attention_unet_epoch_{epoch}.pth")
if __name__ == "__main__":
train()
八、数据集代码
python
import os
import random
import torch
from PIL import Image
from torch.utils.data import Dataset
import torchvision.transforms as transforms
class DenoiseDataset(Dataset):
def __init__(self, root_dir, patch_size=128):
self.paths = [
os.path.join(root_dir, name)
for name in os.listdir(root_dir)
if name.lower().endswith((".jpg", ".jpeg", ".png"))
]
self.patch_size = patch_size
self.to_tensor = transforms.ToTensor()
def __len__(self):
return len(self.paths)
def __getitem__(self, idx):
img = Image.open(self.paths[idx]).convert("L")
w, h = img.size
if w >= self.patch_size and h >= self.patch_size:
x = random.randint(0, w - self.patch_size)
y = random.randint(0, h - self.patch_size)
img = img.crop((x, y, x + self.patch_size, y + self.patch_size))
else:
img = img.resize((self.patch_size, self.patch_size))
clean = self.to_tensor(img)
sigma = random.choice([10, 15, 25, 35, 50])
noise = torch.randn_like(clean) * sigma / 255.0
noisy = torch.clamp(clean + noise, 0.0, 1.0)
return noisy, clean
九、为什么Attention加在Skip Connection上?
这是很多人容易写错的地方。
Attention不是随便加在哪里都有效。
在图像恢复任务里,skip connection 传递的是浅层纹理信息,同时也可能传递噪声信息。
如果直接 concat:
python
torch.cat([decoder_feature, encoder_feature], dim=1)
模型会把所有浅层信息都拿来用,包括噪声。
加上 Attention Gate 后,模型会先筛选:
python
encoder_feature -> attention gate -> useful feature
这样可以减少噪声特征污染 decoder。
十、效果验证
实际实验中,Attention UNet相比普通UNet,提升主要体现在:
- 文字边缘更清楚
- 纹理区域更自然
- 背景噪声控制更稳定
- 过度平滑问题减轻
但注意:
Attention UNet不一定让PSNR大幅提升,但肉眼效果往往更好。
这是图像恢复任务里非常常见的现象。
十一、踩坑记录
坑1:Attention尺寸不一致
Attention Gate中的 x 和 g 尺寸必须一致。
如果不一致,可以使用插值:
python
g = torch.nn.functional.interpolate(g, size=x.shape[2:], mode="bilinear", align_corners=False)
坑2:显存占用增加
Attention模块会增加计算量。
解决方式:
- batch_size 减小
- patch_size 从 128 开始
- 不要一开始就用 256 或 512
坑3:注意力权重过强导致细节消失
如果 attention 过度抑制浅层信息,反而会丢细节。
可以改成残差形式:
python
return x * attention + x
这在一些数据集上更稳。
十二、适合收藏总结
Attention UNet适合什么场景?
- 文字图像去噪
- 人脸图像去噪
- 医学图像去噪
- 细节区域很重要的图像恢复任务
避坑清单
- Attention最好加在skip connection上
- x和g尺寸必须一致
- batch_size要控制
- 不要盲目堆注意力模块
- PSNR不是唯一标准,要看视觉效果
十三、优化建议
可以继续尝试:
- Attention + ResUNet
- SE注意力
- CBAM注意力
- 多尺度Attention
- Transformer Attention
结尾总结
Attention UNet的核心价值不是"让模型更复杂",而是:
让模型知道图像中哪些区域更值得恢复。
在真实工程里,图像去噪并不是追求整张图平均变干净,而是要让关键区域更清楚。
Attention UNet正是朝这个方向迈出的一步。
下一篇预告
Pytorch图像去噪实战(五):FFDNet实战,用噪声图控制不同强度的去噪效果