IS-Net 教程:基于 PyTorch 的图像分割网络
IS-Net
(Image Structure Network)是 DIS 项目 中的核心模块之一,用于进行复杂的图像结构化任务,尤其在图像分割、图像修复、去噪等任务中表现优异。本教程将介绍如何在 PyTorch 中使用 IS-Net 进行图像分割任务,并展示如何运行预训练模型和自定义数据集进行训练。
1. IS-Net 概述
IS-Net
是一个基于深度学习的图像分割网络,专注于图像中结构化信息的重建。它的网络结构类似于 UNet ,使用卷积操作来提取图像特征,并通过下采样和上采样层逐步进行分割任务。IS-Net
主要应用于以下任务:
- 图像分割:将图像中的每个像素分类为前景或背景,生成分割图像。
- 图像修复和去噪:利用图像的局部和全局结构信息来修复或去除噪声。
2. 环境设置
在使用 IS-Net
之前,我们需要确保安装了 PyTorch 以及项目依赖项。
2.1 安装 PyTorch 和依赖项
首先,确保你已经安装了 PyTorch 和相关的依赖项:
bash
pip install torch torchvision
2.2 克隆 IS-Net 代码库
你可以从 GitHub 克隆 DIS
仓库,它包含 IS-Net
模块。这里我们假设你已经将项目克隆到了本地:
bash
git clone https://github.com/xuebinqin/DIS.git
cd DIS
3. IS-Net 网络结构
IS-Net
的核心思想源自 UNet 的编码器-解码器架构。网络首先通过编码器部分提取图像的多尺度特征,然后通过解码器部分逐步恢复原始图像大小,同时生成结构化的分割结果。
3.1 IS-Net 模型定义
IS-Net 的具体网络结构定义在 DIS
项目的 models/
目录下。为了演示,我们简化了网络结构的定义:
python
import torch
import torch.nn as nn
import torch.nn.functional as F
# 定义 IS-Net 的基础结构,类似于 UNet
class ISNet(nn.Module):
def __init__(self):
super(ISNet, self).__init__()
# 编码器部分(下采样)
self.encoder1 = self.double_conv(1, 64)
self.encoder2 = self.double_conv(64, 128)
self.encoder3 = self.double_conv(128, 256)
self.encoder4 = self.double_conv(256, 512)
# 中间部分
self.middle = self.double_conv(512, 1024)
# 解码器部分(上采样)
self.upconv4 = self.up_conv(1024, 512)
self.decoder4 = self.double_conv(1024, 512)
self.upconv3 = self.up_conv(512, 256)
self.decoder3 = self.double_conv(512, 256)
self.upconv2 = self.up_conv(256, 128)
self.decoder2 = self.double_conv(256, 128)
self.upconv1 = self.up_conv(128, 64)
self.decoder1 = self.double_conv(128, 64)
# 最后的分类层(输出二分类结果)
self.final = nn.Conv2d(64, 1, kernel_size=1)
def double_conv(self, in_channels, out_channels):
return nn.Sequential(
nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
nn.ReLU(inplace=True),
nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
nn.ReLU(inplace=True),
)
def up_conv(self, in_channels, out_channels):
return nn.ConvTranspose2d(in_channels, out_channels, kernel_size=2, stride=2)
def forward(self, x):
# 编码器
e1 = self.encoder1(x)
e2 = self.encoder2(F.max_pool2d(e1, 2))
e3 = self.encoder3(F.max_pool2d(e2, 2))
e4 = self.encoder4(F.max_pool2d(e3, 2))
# 中间部分
middle = self.middle(F.max_pool2d(e4, 2))
# 解码器
d4 = self.upconv4(middle)
d4 = torch.cat((e4, d4), dim=1)
d4 = self.decoder4(d4)
d3 = self.upconv3(d4)
d3 = torch.cat((e3, d3), dim=1)
d3 = self.decoder3(d3)
d2 = self.upconv2(d3)
d2 = torch.cat((e2, d2), dim=1)
d2 = self.decoder2(d2)
d1 = self.upconv1(d2)
d1 = torch.cat((e1, d1), dim=1)
d1 = self.decoder1(d1)
return torch.sigmoid(self.final(d1))
# 创建 ISNet 模型实例
model = ISNet()
3.2 说明
- 双卷积层(double_conv) :这是每个卷积块的基础构造,使用两个 3x3 卷积核,并使用
ReLU
激活函数。 - 上采样层(up_conv):用于逐步恢复图像的原始尺寸。
- 最终输出层(final):通过一个 1x1 卷积层将网络输出的特征映射到所需的分割结果(通常是二分类的概率图)。
4. 使用预训练模型进行推理
你可以下载预训练的 IS-Net 模型并使用其进行推理任务。
4.1 下载预训练模型
首先,下载预训练的 IS-Net 模型,并将其放置到合适的目录下,例如 ./checkpoints/
:
bash
wget https://path_to_pretrained_model/isnet_pretrained.pth
4.2 加载预训练模型
使用 PyTorch 加载预训练模型的权重,并对图像进行分割:
python
# 加载预训练模型权重
model.load_state_dict(torch.load('./checkpoints/isnet_pretrained.pth', map_location=torch.device('cpu')))
model.eval() # 切换到评估模式
# 推理图像分割
from PIL import Image
import torchvision.transforms as transforms
# 图像预处理
transform = transforms.Compose([
transforms.Resize((256, 256)), # 调整输入图像大小
transforms.ToTensor() # 转换为张量
])
# 加载图像
image = Image.open('./input_image.jpg').convert('L') # 转换为灰度图像
input_tensor = transform(image).unsqueeze(0) # 增加 batch 维度
# 执行推理
with torch.no_grad():
output = model(input_tensor)
# 将分割结果转换为可视化格式
output_image = output.squeeze().cpu().numpy()
output_image = (output_image > 0.5).astype('uint8') # 二值化处理
# 保存结果
import matplotlib.pyplot as plt
plt.imshow(output_image, cmap='gray')
plt.savefig('./output_image.png')
5. 在自定义数据集上训练 IS-Net
如果你有自定义数据集,想在上面训练 IS-Net 模型,可以按照以下步骤操作。
5.1 准备数据集
确保你的数据集包含输入图像和对应的分割标签。可以将数据集组织为以下结构:
data/
├── train/
│ ├── images/
│ └── masks/
├── val/
│ ├── images/
│ └── masks/
5.2 编写数据加载器
使用 PyTorch 的 Dataset
类自定义数据加载器。
python
from torch.utils.data import Dataset
from PIL import Image
import os
class CustomSegmentationDataset(Dataset):
def __init__(self, image_dir, mask_dir, transform=None):
self.image_dir = image_dir
self.mask_dir = mask_dir
self.transform = transform
self.image_list = os.listdir(image_dir)
def __len__(self):
return len(self.image_list)
def __getitem__(self, idx):
img_path = os.path.join(self.image_dir, self.image_list[idx])
mask_path = os.path.join(self.mask_dir, self.image_list[idx]) # 假设掩码与图像名称相同
image = Image.open(img_path).convert('L')
mask = Image.open(mask_path).convert('L')
if self.transform:
image = self.transform(image)
mask = self.transform(mask)
return image, mask
5.3 训练 IS-Net 模型
使用定义好的数据集类和 PyTorch 的 DataLoader
进行训练。
python
from torch.utils.data import DataLoader
# 数据预处理
transform = transforms.Compose([
transforms.Resize((256, 256)),
transforms.ToTensor()
])
# 创建数据集和数据加载器
train_dataset = CustomSegmentationDataset('./data/train/images', './data/train/masks', transform=transform)
train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True)
# 定义优化器和损失函数
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
criterion = nn.BCELoss()
# 训练模型
model.train()
for epoch in range(5): # 假设训练 5 个 epoch
for images, masks in train_loader:
images, masks = images.to(device), masks.to(device)
optimizer.zero_grad()
outputs = model(images)
loss = criterion(outputs, masks)
loss.backward()
optimizer.step()
print(f'Epoch [{epoch+1}/5], Loss: {loss.item():.4f}')
6. 总结
通过本教程,我们学习了如何使用 IS-Net 进行图像分割任务,包括加载预训练模型和在自定义数据集上进行训练。IS-Net
作为一种强大的图像分割工具,可以用于多种图像结构重建任务,如去噪、图像修复等。