文章目录
环境配置
可参考这个:环境配置
数据集
可参考这个:数据集
代码
- ASPP模块
  - 多尺度空洞卷积(rate = 6 / 12 / 18)捕获不同感受野范围的上下文信息
  - 适合遥感影像中尺寸差异较大的地物目标
  - 包含全局平均池化分支,用于捕获图像级全局上下文
- 编码器-解码器结构
  - 编码器:ResNet backbone 提取多层次特征
  - 解码器:融合高层语义信息与低层空间细节
Deeplabv3+
python
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import models
import config
class ASPP(nn.Module):
    """Atrous Spatial Pyramid Pooling (ASPP) block.

    Five parallel branches read the same feature map:
    - a 1x1 conv;
    - three 3x3 atrous convs with dilation 6, 12 and 18;
    - an image-level branch (global average pooling plus a 1x1 conv,
      broadcast back with bilinear upsampling).

    The branch outputs are concatenated and fused by a 1x1 conv, BatchNorm,
    ReLU and Dropout(0.5). Spatial size is preserved.

    Args:
        in_channels: channel count of the input feature map.
        out_channels: channel count of every branch and of the fused output.
    """

    def __init__(self, in_channels, out_channels=256):
        super(ASPP, self).__init__()

        def conv_bn_relu(kernel_size, dilation=1):
            # padding = dilation * (k // 2) keeps H and W unchanged ("same" padding).
            return nn.Sequential(
                nn.Conv2d(
                    in_channels,
                    out_channels,
                    kernel_size,
                    padding=dilation * (kernel_size // 2),
                    dilation=dilation,
                    bias=False,
                ),
                nn.BatchNorm2d(out_channels),
                nn.ReLU(inplace=True),
            )

        # Attribute names and creation order are significant. They fix the
        # state_dict keys and the order in which parameters are initialised.
        self.conv_1x1 = conv_bn_relu(1)
        self.conv_3x3_1 = conv_bn_relu(3, dilation=6)
        self.conv_3x3_2 = conv_bn_relu(3, dilation=12)
        self.conv_3x3_3 = conv_bn_relu(3, dilation=18)
        # Image-level branch: pool to 1x1 first, then the usual conv-BN-ReLU.
        self.global_avg_pool = nn.Sequential(nn.AdaptiveAvgPool2d(1), *conv_bn_relu(1))
        # Fuse the five concatenated branches back to out_channels.
        self.conv_out = nn.Sequential(
            nn.Conv2d(5 * out_channels, out_channels, 1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Dropout(0.5),
        )

    def forward(self, x):
        spatial_size = x.shape[2:]
        branch_outputs = [
            branch(x)
            for branch in (self.conv_1x1, self.conv_3x3_1, self.conv_3x3_2, self.conv_3x3_3)
        ]
        # Broadcast the 1x1 pooled descriptor back to the input resolution.
        pooled = self.global_avg_pool(x)
        branch_outputs.append(
            F.interpolate(pooled, size=spatial_size, mode='bilinear', align_corners=True)
        )
        return self.conv_out(torch.cat(branch_outputs, dim=1))
class DeepLabV3Plus(nn.Module):
    """DeepLabV3+ segmentation network, intended for land-use classification of remote-sensing imagery.

    Encoder: a torchvision ResNet (stem + layer1..layer4) followed by ASPP.
    Decoder: the ASPP output is upsampled to the layer1 (stride-4) resolution,
    concatenated with a 48-channel projection of the layer1 features, refined
    by two 3x3 convs and classified. The logits are finally resized to the
    input size.

    Args:
        n_channels: number of input bands. When it is not 3, ``conv1`` is
            replaced. With pretrained weights, the new filters are set to the
            mean of the pretrained RGB filters.
        n_classes: number of output classes (logit channels).
        backbone: one of ``'resnet18'``, ``'resnet34'``, ``'resnet50'``,
            ``'resnet101'``.
        pretrained: if True, load torchvision's ImageNet ``IMAGENET1K_V1``
            weights for the backbone. These are the weights the legacy
            ``pretrained=True`` flag loads.

    Raises:
        ValueError: if ``backbone`` is not supported.
    """

    # backbone name -> (layer1 channels, layer4 channels, built from Bottleneck blocks)
    _BACKBONE_SPECS = {
        'resnet18': (64, 512, False),
        'resnet34': (64, 512, False),
        'resnet50': (256, 2048, True),
        'resnet101': (256, 2048, True),
    }

    def __init__(self, n_channels=config.NUM_BANDS, n_classes=config.NUM_CLASSES,
                 backbone='resnet50', pretrained=True):
        super(DeepLabV3Plus, self).__init__()
        self.n_channels = n_channels
        self.n_classes = n_classes

        if backbone not in self._BACKBONE_SPECS:
            raise ValueError(
                f"Unsupported backbone {backbone!r}; "
                f"expected one of {sorted(self._BACKBONE_SPECS)}"
            )
        low_level_channels, high_level_channels, is_bottleneck = self._BACKBONE_SPECS[backbone]

        # Bottleneck ResNets: dilate layer4 instead of striding it. This gives
        # output stride 16, the setting the ASPP rates (6/12/18) are tuned for.
        # torchvision's BasicBlock (resnet18/34) does not support dilation, so
        # those stay at output stride 32. forward() resizes to an explicit
        # target size, so both strides work.
        resnet_kwargs = (
            {'replace_stride_with_dilation': [False, False, True]} if is_bottleneck else {}
        )
        resnet = self._load_resnet(backbone, pretrained, **resnet_kwargs)

        # Adapt the stem to multi-band input.
        if n_channels != 3:
            resnet.conv1 = self._make_input_conv(resnet.conv1, n_channels, pretrained)

        # Drop avgpool and fc. The remaining children are
        # [conv1, bn1, relu, maxpool, layer1, layer2, layer3, layer4].
        # Keeping a plain Sequential preserves the state_dict keys, and forward()
        # slices it to tap layer1 as the low-level features.
        self.backbone = nn.Sequential(*list(resnet.children())[:-2])

        # ASPP over the high-level (layer4) features
        self.aspp = ASPP(high_level_channels, 256)

        # Project low-level (layer1) features to 48 channels
        self.low_level_conv = nn.Sequential(
            nn.Conv2d(low_level_channels, 48, 1, bias=False),
            nn.BatchNorm2d(48),
            nn.ReLU(inplace=True)
        )

        # Decoder
        self.decoder_conv = nn.Sequential(
            nn.Conv2d(256 + 48, 256, 3, padding=1, bias=False),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, 256, 3, padding=1, bias=False),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),
            nn.Dropout(0.1)
        )

        # Per-pixel classifier
        self.classifier = nn.Conv2d(256, n_classes, 1)

        self._init_weights()

    @staticmethod
    def _load_resnet(name, pretrained, **kwargs):
        """Build the torchvision ResNet ``name``, with or without ImageNet weights."""
        builder = getattr(models, name)
        try:
            # torchvision >= 0.13 API
            return builder(weights='IMAGENET1K_V1' if pretrained else None, **kwargs)
        except TypeError:
            # Older torchvision has no `weights` argument; use the legacy flag.
            return builder(pretrained=pretrained, **kwargs)

    @staticmethod
    def _make_input_conv(old_conv, n_channels, pretrained):
        """Return a 7x7, stride-2 stem conv that accepts ``n_channels`` input bands."""
        new_conv = nn.Conv2d(n_channels, old_conv.out_channels, kernel_size=7,
                             stride=2, padding=3, bias=False)
        with torch.no_grad():
            if pretrained:
                # Reuse the pretrained RGB filters: every band gets their mean,
                # scaled by 3 / n_channels so the summed response keeps a
                # comparable magnitude.
                mean_weight = old_conv.weight.mean(dim=1, keepdim=True)
                new_conv.weight.copy_(mean_weight.repeat(1, n_channels, 1, 1) * (3.0 / n_channels))
            else:
                nn.init.kaiming_normal_(new_conv.weight, mode='fan_out', nonlinearity='relu')
        return new_conv

    def _init_weights(self):
        """Kaiming-initialise the convolutions and reset the BatchNorms of the head.

        The backbone is skipped deliberately. Re-initialising it would discard
        the pretrained weights, and torchvision's ResNet constructor already
        initialises its own layers.
        """
        for head in (self.aspp, self.low_level_conv, self.decoder_conv, self.classifier):
            for m in head.modules():
                if isinstance(m, nn.Conv2d):
                    nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                elif isinstance(m, nn.BatchNorm2d):
                    nn.init.constant_(m.weight, 1)
                    nn.init.constant_(m.bias, 0)

    def forward(self, x):
        """Return per-pixel logits with ``n_classes`` channels at the input's H x W.

        Expects ``x`` to be an (N, n_channels, H, W) batch.
        """
        input_size = x.shape[-2:]

        # The first 5 backbone children (conv1, bn1, relu, maxpool, layer1)
        # give the stride-4 low-level features; layer2..layer4 continue from them.
        low_level_features = self.backbone[:5](x)
        high_level_features = self.aspp(self.backbone[5:](low_level_features))

        low_level_features = self.low_level_conv(low_level_features)
        # Resize to the low-level resolution explicitly. This is x4 at output
        # stride 16 and x8 at 32, and it also handles inputs whose size is not
        # a multiple of the stride.
        high_level_features = F.interpolate(
            high_level_features,
            size=low_level_features.shape[-2:],
            mode='bilinear',
            align_corners=True
        )

        x = torch.cat([high_level_features, low_level_features], dim=1)
        x = self.decoder_conv(x)
        x = self.classifier(x)

        # Upsample the logits back to the input size
        return F.interpolate(x, size=input_size, mode='bilinear', align_corners=True)
class DeepLabV3PlusWithAuxiliary(DeepLabV3Plus):
    """DeepLabV3+ with an auxiliary classifier on the ASPP output, used to add an extra training loss.

    In training mode ``forward`` returns ``(main_logits, aux_logits)``, both
    resized to the input size. In eval mode it returns only ``main_logits``
    (the auxiliary head is not evaluated), so inference code written for
    ``DeepLabV3Plus`` works unchanged.

    Extra keyword arguments are forwarded to ``DeepLabV3Plus.__init__``.
    """

    def __init__(self, n_channels=config.NUM_BANDS, n_classes=config.NUM_CLASSES,
                 backbone='resnet50', **kwargs):
        super(DeepLabV3PlusWithAuxiliary, self).__init__(n_channels, n_classes, backbone, **kwargs)

        # Auxiliary classifier on the 256-channel ASPP output
        self.aux_classifier = nn.Sequential(
            nn.Conv2d(256, 256, 3, padding=1, bias=False),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),
            nn.Dropout(0.1),
            nn.Conv2d(256, n_classes, 1)
        )

    def forward(self, x):
        input_size = x.shape[-2:]

        # self.backbone is the truncated ResNet Sequential:
        # [conv1, bn1, relu, maxpool, layer1, layer2, layer3, layer4].
        # The first 5 children give the stride-4 low-level (layer1) features.
        low_level_features = self.backbone[:5](x)
        aspp_features = self.aspp(self.backbone[5:](low_level_features))

        # Main branch: upsample the ASPP output to the low-level resolution
        # (explicit target size, so any output stride works) and decode.
        low_level_features = self.low_level_conv(low_level_features)
        high_level_features = F.interpolate(
            aspp_features, size=low_level_features.shape[-2:],
            mode='bilinear', align_corners=True
        )
        x = torch.cat([high_level_features, low_level_features], dim=1)
        x = self.decoder_conv(x)
        x = self.classifier(x)
        x = F.interpolate(x, size=input_size, mode='bilinear', align_corners=True)

        if not self.training:
            return x

        # Auxiliary output, only needed for the training loss
        aux_output = self.aux_classifier(aspp_features)
        aux_output = F.interpolate(aux_output, size=input_size, mode='bilinear', align_corners=True)
        return x, aux_output
# Convenience factory
def create_deeplabv3plus(model_type='standard', **kwargs):
    """Create a DeepLabV3+ model.

    Args:
        model_type: ``'standard'`` builds ``DeepLabV3Plus``.
            ``'with_auxiliary'`` builds ``DeepLabV3PlusWithAuxiliary``, which
            returns ``(main, aux)`` logits in training mode.
        **kwargs: forwarded to the model constructor.

    Raises:
        ValueError: if ``model_type`` is neither ``'standard'`` nor
            ``'with_auxiliary'``. Previously any other value silently produced
            the standard model.
    """
    builders = {
        'standard': DeepLabV3Plus,
        'with_auxiliary': DeepLabV3PlusWithAuxiliary,
    }
    try:
        builder = builders[model_type]
    except KeyError:
        raise ValueError(
            f"Unknown model_type {model_type!r}; expected one of {sorted(builders)}"
        ) from None
    return builder(**kwargs)
训练时可参考以下代码。注意:只有用 model_type='with_auxiliary' 创建的模型在训练模式下才会返回 (output, aux_output) 两个输出;标准模型只返回一个输出。
python
# Create the model. The auxiliary variant is required here: only it returns
# (main_logits, aux_logits) in training mode. The standard model returns a
# single tensor, which the tuple unpacking below would mishandle.
model = create_deeplabv3plus(
    model_type='with_auxiliary',
    n_channels=config.NUM_BANDS,
    n_classes=config.NUM_CLASSES,
    backbone='resnet50',
)
# x, target, main_loss and aux_loss are placeholders from your training loop
# (e.g. main_loss = aux_loss = nn.CrossEntropyLoss()).
if model.training:
    # Auxiliary loss contributes with weight 0.4.
    output, aux_output = model(x)
    loss = main_loss(output, target) + 0.4 * aux_loss(aux_output, target)
else:
    output = model(x)
推理前先调用 model.eval():此时两种模型(包括带辅助分支的版本)都只返回最终输出(logits),与原有评估 / 预测流程兼容,无需修改 eval.py 和 predict.py。