使用Deeplabv3+进行遥感影像土地利用分类

文章目录

环境配置

可参考这个:环境配置

数据集

可参考这个:数据集

代码

  1. ASPP模块

    多尺度空洞卷积捕获不同范围的上下文信息

    适合遥感影像中不同大小的地物目标

    包含全局平均池化捕获全局上下文

  2. 编码器-解码器结构

    编码器: ResNet backbone提取多层次特征

    解码器: 融合高层语义信息和低层空间细节

Deeplabv3+

python 复制代码
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import models
import config

class ASPP(nn.Module):
    """ASPP模块 - 多尺度空洞卷积"""
    def __init__(self, in_channels, out_channels=256):
        super(ASPP, self).__init__()
        
        # 1x1卷积
        self.conv_1x1 = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, 1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        )
        
        # 3x3空洞卷积,rate=6
        self.conv_3x3_1 = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, 3, padding=6, dilation=6, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        )
        
        # 3x3空洞卷积,rate=12
        self.conv_3x3_2 = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, 3, padding=12, dilation=12, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        )
        
        # 3x3空洞卷积,rate=18
        self.conv_3x3_3 = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, 3, padding=18, dilation=18, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        )
        
        # 全局平均池化
        self.global_avg_pool = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(in_channels, out_channels, 1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        )
        
        # 输出卷积
        self.conv_out = nn.Sequential(
            nn.Conv2d(out_channels * 5, out_channels, 1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Dropout(0.5)
        )

    def forward(self, x):
        x1 = self.conv_1x1(x)
        x2 = self.conv_3x3_1(x)
        x3 = self.conv_3x3_2(x)
        x4 = self.conv_3x3_3(x)
        
        # 全局平均池化并上采样到原始尺寸
        x5 = self.global_avg_pool(x)
        x5 = F.interpolate(x5, size=x.size()[2:], mode='bilinear', align_corners=True)
        
        # 拼接所有特征
        x = torch.cat([x1, x2, x3, x4, x5], dim=1)
        x = self.conv_out(x)
        
        return x


class DeepLabV3Plus(nn.Module):
    """Deeplabv3+模型 - 适合遥感影像土地利用分类"""
    def __init__(self, n_channels=config.NUM_BANDS, n_classes=config.NUM_CLASSES, backbone='resnet50'):
        super(DeepLabV3Plus, self).__init__()
        self.n_channels = n_channels
        self.n_classes = n_classes
        
        # 选择backbone
        if backbone == 'resnet50':
            self.backbone = models.resnet50(pretrained=True)
            low_level_channels = 256
            high_level_channels = 2048
        elif backbone == 'resnet101':
            self.backbone = models.resnet101(pretrained=True)
            low_level_channels = 256
            high_level_channels = 2048
        else:  # resnet34
            self.backbone = models.resnet34(pretrained=True)
            low_level_channels = 64
            high_level_channels = 512
        
        # 修改第一层卷积以适应多波段输入
        if n_channels != 3:
            self.backbone.conv1 = nn.Conv2d(n_channels, 64, kernel_size=7, stride=2, padding=3, bias=False)
        
        # 移除最后的全连接层和平均池化层
        self.backbone = nn.Sequential(*list(self.backbone.children())[:-2])
        
        # ASPP模块
        self.aspp = ASPP(high_level_channels, 256)
        
        # 低层特征处理
        self.low_level_conv = nn.Sequential(
            nn.Conv2d(low_level_channels, 48, 1, bias=False),
            nn.BatchNorm2d(48),
            nn.ReLU(inplace=True)
        )
        
        # 解码器
        self.decoder_conv = nn.Sequential(
            nn.Conv2d(256 + 48, 256, 3, padding=1, bias=False),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, 256, 3, padding=1, bias=False),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),
            nn.Dropout(0.1)
        )
        
        # 分类器
        self.classifier = nn.Conv2d(256, n_classes, 1)
        
        # 初始化权重
        self._init_weights()

    def _init_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

    def forward(self, x):
        # 获取输入尺寸
        input_size = x.size()[2:]
        
        # Backbone特征提取
        features = self.backbone(x)
        
        # 低层特征(浅层特征,保留更多空间信息)
        low_level_features = None
        if hasattr(self.backbone, 'layer1'):
            # 对于标准ResNet
            low_level_features = self.backbone.layer1(x)
        else:
            # 对于Sequential包装的backbone,需要手动获取中间特征
            # 这里简化处理,实际使用时可能需要根据具体backbone结构调整
            low_level_features = features
        
        # 高层特征通过ASPP
        high_level_features = self.aspp(features)
        
        # 上采样高层特征
        high_level_features = F.interpolate(
            high_level_features, 
            scale_factor=4, 
            mode='bilinear', 
            align_corners=True
        )
        
        # 处理低层特征
        low_level_features = self.low_level_conv(low_level_features)
        
        # 拼接高低层特征
        x = torch.cat([high_level_features, low_level_features], dim=1)
        
        # 解码器卷积
        x = self.decoder_conv(x)
        
        # 分类
        x = self.classifier(x)
        
        # 上采样到原始输入尺寸
        x = F.interpolate(x, size=input_size, mode='bilinear', align_corners=True)
        
        return x


class DeepLabV3PlusWithAuxiliary(DeepLabV3Plus):
    """带辅助损失的Deeplabv3+,用于训练时提升性能"""
    def __init__(self, n_channels=config.NUM_BANDS, n_classes=config.NUM_CLASSES, backbone='resnet50'):
        super(DeepLabV3PlusWithAuxiliary, self).__init__(n_channels, n_classes, backbone)
        
        # 辅助分类器
        self.aux_classifier = nn.Sequential(
            nn.Conv2d(256, 256, 3, padding=1, bias=False),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),
            nn.Dropout(0.1),
            nn.Conv2d(256, n_classes, 1)
        )

    def forward(self, x):
        input_size = x.size()[2:]
        features = self.backbone(x)
        
        # 低层特征
        low_level_features = self.backbone.layer1(x) if hasattr(self.backbone, 'layer1') else features
        
        # 高层特征通过ASPP
        high_level_features = self.aspp(features)
        
        # 辅助输出
        aux_output = self.aux_classifier(high_level_features)
        aux_output = F.interpolate(aux_output, size=input_size, mode='bilinear', align_corners=True)
        
        # 主分支
        high_level_features = F.interpolate(high_level_features, scale_factor=4, mode='bilinear', align_corners=True)
        low_level_features = self.low_level_conv(low_level_features)
        x = torch.cat([high_level_features, low_level_features], dim=1)
        x = self.decoder_conv(x)
        x = self.classifier(x)
        x = F.interpolate(x, size=input_size, mode='bilinear', align_corners=True)
        
        if self.training:
            return x, aux_output
        else:
            return x


# 便捷创建函数
def create_deeplabv3plus(model_type='standard', **kwargs):
    """
    创建Deeplabv3+模型
    Args:
        model_type: 'standard' 或 'with_auxiliary'
        **kwargs: 模型参数
    """
    if model_type == 'with_auxiliary':
        return DeepLabV3PlusWithAuxiliary(**kwargs)
    else:
        return DeepLabV3Plus(**kwargs)

训练时参考这个

python 复制代码
# 创建模型
model = create_deeplabv3plus(
    n_channels=config.NUM_BANDS,
    n_classes=config.NUM_CLASSES,
    backbone='resnet50'
)

# 训练时(如果使用辅助损失版本)
if model.training:
    output, aux_output = model(x)
    loss = main_loss(output, target) + 0.4 * aux_loss(aux_output, target)
else:
    output = model(x)

推理时模型自动返回最终输出(logits),与原有评估 / 预测流程兼容,无需修改eval.py和predict.py

相关推荐
java1234_小锋5 小时前
PyTorch2 Python深度学习 - 自动微分(Autograd)与梯度优化
开发语言·python·深度学习·pytorch2
AI模块工坊5 小时前
CVPR 即插即用 | PConv:重新定义高效卷积,一个让模型“跑”得更快、更省的新范式
人工智能·深度学习·计算机视觉·transformer
java1234_小锋5 小时前
PyTorch2 Python深度学习 - 简介以及入门
python·深度学习·pytorch2
麻雀无能为力9 小时前
深度学习计算
人工智能·深度学习
机器学习之心HML11 小时前
TCN-Transformer-LSTM多特征分类预测Matlab实现
分类·lstm·transformer
北诺南兮12 小时前
大模型算法面试笔记——多头潜在注意力(MLA)
笔记·深度学习·算法
Fuxiao___12 小时前
OpenVLA-OFT+ 在真实世界 ALOHA 机器人任务中的应用
人工智能·深度学习·计算机视觉
Dev7z12 小时前
基于Swin Transformer的肝脏肿瘤MRI图像分类与诊断系统
人工智能·深度学习·transformer
java1234_小锋13 小时前
PyTorch2 Python深度学习 - 张量(Tensor)的定义与操作
开发语言·python·深度学习·pytorch2