使用Deeplabv3+进行遥感影像土地利用分类

文章目录

环境配置

可参考这个:环境配置

数据集

可参考这个:数据集

代码

  1. ASPP模块

    多尺度空洞卷积捕获不同范围的上下文信息

    适合遥感影像中不同大小的地物目标

    包含全局平均池化捕获全局上下文

  2. 编码器-解码器结构

    编码器: ResNet backbone提取多层次特征

    解码器: 融合高层语义信息和低层空间细节

Deeplabv3+

python 复制代码
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import models
import config

class ASPP(nn.Module):
    """Atrous Spatial Pyramid Pooling head.

    Runs five parallel branches over the same feature map -- a 1x1
    convolution, three 3x3 dilated convolutions (rates 6, 12 and 18) and an
    image-level global-average-pooling branch -- concatenates them along the
    channel axis and projects the result back to ``out_channels``.

    Args:
        in_channels: Number of channels of the incoming feature map.
        out_channels: Number of channels produced by every branch and by the
            final projection.

    Note:
        The pooling branch applies ``BatchNorm2d`` to a 1x1 map, so in
        training mode a batch of size 1 is rejected by PyTorch.
    """

    @staticmethod
    def _conv_bn_relu(in_channels, out_channels, kernel_size, dilation=1):
        """Build a bias-free Conv2d -> BatchNorm2d -> ReLU stack."""
        # A dilated 3x3 kernel keeps the spatial size when padding == dilation;
        # a 1x1 kernel needs no padding at all.
        padding = 0 if kernel_size == 1 else dilation
        return nn.Sequential(
            nn.Conv2d(
                in_channels,
                out_channels,
                kernel_size,
                padding=padding,
                dilation=dilation,
                bias=False,
            ),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        )

    def __init__(self, in_channels, out_channels=256):
        super(ASPP, self).__init__()
        make = self._conv_bn_relu

        # Pointwise branch.
        self.conv_1x1 = make(in_channels, out_channels, 1)

        # Dilated 3x3 branches with growing receptive fields.
        self.conv_3x3_1 = make(in_channels, out_channels, 3, dilation=6)
        self.conv_3x3_2 = make(in_channels, out_channels, 3, dilation=12)
        self.conv_3x3_3 = make(in_channels, out_channels, 3, dilation=18)

        # Image-level branch: squeeze to 1x1, then the usual conv stack.
        self.global_avg_pool = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            *make(in_channels, out_channels, 1),
        )

        # Projection that fuses the five concatenated branches.
        self.conv_out = nn.Sequential(
            *make(out_channels * 5, out_channels, 1),
            nn.Dropout(0.5),
        )

    def forward(self, x):
        """Return the fused multi-scale features, same spatial size as ``x``."""
        spatial_size = x.size()[2:]

        local_branches = (
            self.conv_1x1,
            self.conv_3x3_1,
            self.conv_3x3_2,
            self.conv_3x3_3,
        )
        outputs = [branch(x) for branch in local_branches]

        # Broadcast the pooled 1x1 descriptor back to the input resolution.
        pooled = self.global_avg_pool(x)
        outputs.append(
            F.interpolate(pooled, size=spatial_size, mode='bilinear', align_corners=True)
        )

        return self.conv_out(torch.cat(outputs, dim=1))


class DeepLabV3Plus(nn.Module):
    """DeepLabV3+ segmentation network on a torchvision ResNet encoder.

    The encoder is a pretrained ResNet with its average-pool and
    fully-connected layers removed. The output of ``layer1`` is kept as the
    low-level (high-resolution) feature map and the output of ``layer4`` is
    passed through ASPP. The decoder upsamples the ASPP output to the
    low-level resolution, fuses both maps and predicts per-pixel logits at
    the input resolution.

    Args:
        n_channels: Number of input bands. When it differs from 3 the first
            convolution is replaced by a freshly initialised one.
        n_classes: Number of output classes (channels of the returned logits).
        backbone: ``'resnet50'`` or ``'resnet101'``; any other value falls
            back to ``resnet34``.
    """

    # Position of ``layer1`` inside the truncated backbone, whose children are
    # (conv1, bn1, relu, maxpool, layer1, layer2, layer3, layer4).
    _LOW_LEVEL_INDEX = 4

    def __init__(self, n_channels=config.NUM_BANDS, n_classes=config.NUM_CLASSES, backbone='resnet50'):
        super(DeepLabV3Plus, self).__init__()
        self.n_channels = n_channels
        self.n_classes = n_classes

        # Pick the encoder and record the channel widths of layer1 / layer4.
        if backbone == 'resnet50':
            resnet = models.resnet50(pretrained=True)
            low_level_channels = 256
            high_level_channels = 2048
        elif backbone == 'resnet101':
            resnet = models.resnet101(pretrained=True)
            low_level_channels = 256
            high_level_channels = 2048
        else:  # resnet34
            resnet = models.resnet34(pretrained=True)
            low_level_channels = 64
            high_level_channels = 512

        # Multi-band imagery cannot reuse the pretrained RGB stem, so swap in
        # a new first convolution and give it the He initialisation that the
        # rest of the new layers receive.
        if n_channels != 3:
            resnet.conv1 = nn.Conv2d(n_channels, 64, kernel_size=7, stride=2, padding=3, bias=False)
            nn.init.kaiming_normal_(resnet.conv1.weight, mode='fan_out', nonlinearity='relu')

        # Drop the trailing average pool and fully-connected classifier.
        self.backbone = nn.Sequential(*list(resnet.children())[:-2])

        # Multi-scale context on the deepest features.
        self.aspp = ASPP(high_level_channels, 256)

        # Reduce the low-level features to 48 channels before fusion.
        self.low_level_conv = nn.Sequential(
            nn.Conv2d(low_level_channels, 48, 1, bias=False),
            nn.BatchNorm2d(48),
            nn.ReLU(inplace=True)
        )

        # Decoder: refine the fused high-level + low-level features.
        self.decoder_conv = nn.Sequential(
            nn.Conv2d(256 + 48, 256, 3, padding=1, bias=False),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, 256, 3, padding=1, bias=False),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),
            nn.Dropout(0.1)
        )

        # Per-pixel classifier.
        self.classifier = nn.Conv2d(256, n_classes, 1)

        self._init_weights()

    def _init_weights(self):
        """Initialise the newly added heads, leaving the backbone untouched.

        Walking ``self.modules()`` would also re-initialise the encoder and
        throw away the pretrained weights it was loaded with, so only the
        ASPP, low-level, decoder and classifier modules are visited.
        """
        heads = (self.aspp, self.low_level_conv, self.decoder_conv, self.classifier)
        for head in heads:
            for m in head.modules():
                if isinstance(m, nn.Conv2d):
                    nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                elif isinstance(m, nn.BatchNorm2d):
                    nn.init.constant_(m.weight, 1)
                    nn.init.constant_(m.bias, 0)

    def forward(self, x):
        """Return class logits with the same spatial size as ``x``."""
        input_size = x.size()[2:]

        # Run the encoder stage by stage so the layer1 output can be captured;
        # the Sequential wrapper does not expose ``layer1`` by name.
        features = x
        low_level_features = None
        for index, stage in enumerate(self.backbone):
            features = stage(features)
            if index == self._LOW_LEVEL_INDEX:
                low_level_features = features

        # Multi-scale context on the deepest features.
        high_level_features = self.aspp(features)

        # Channel-reduce the low-level features.
        low_level_features = self.low_level_conv(low_level_features)

        # Upsample to the exact low-level size rather than by a fixed factor,
        # so the two maps always line up regardless of input size.
        high_level_features = F.interpolate(
            high_level_features,
            size=low_level_features.size()[2:],
            mode='bilinear',
            align_corners=True
        )

        # Fuse, refine and classify.
        x = torch.cat([high_level_features, low_level_features], dim=1)
        x = self.decoder_conv(x)
        x = self.classifier(x)

        # Restore the original input resolution.
        x = F.interpolate(x, size=input_size, mode='bilinear', align_corners=True)

        return x


class DeepLabV3PlusWithAuxiliary(DeepLabV3Plus):
    """DeepLabV3+ with an auxiliary classifier on the ASPP output.

    In training mode ``forward`` returns ``(logits, aux_logits)`` so an extra
    loss can be applied to the auxiliary head; in eval mode it returns only
    ``logits``. Both outputs have the spatial size of the input.

    Args:
        n_channels: Number of input bands.
        n_classes: Number of output classes.
        backbone: Encoder name, forwarded to ``DeepLabV3Plus``.
    """

    # Position of ``layer1`` inside the truncated ResNet Sequential
    # (conv1, bn1, relu, maxpool, layer1, layer2, layer3, layer4).
    _LAYER1_INDEX = 4

    def __init__(self, n_channels=config.NUM_BANDS, n_classes=config.NUM_CLASSES, backbone='resnet50'):
        super(DeepLabV3PlusWithAuxiliary, self).__init__(n_channels, n_classes, backbone)

        # Auxiliary head operating on the 256-channel ASPP output.
        self.aux_classifier = nn.Sequential(
            nn.Conv2d(256, 256, 3, padding=1, bias=False),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),
            nn.Dropout(0.1),
            nn.Conv2d(256, n_classes, 1)
        )

        # The parent initialises its heads before this module exists, so give
        # the auxiliary head the same initialisation here.
        for m in self.aux_classifier.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

    def forward(self, x):
        """Return ``(logits, aux_logits)`` when training, else ``logits``."""
        input_size = x.size()[2:]

        # Step through the encoder to capture the layer1 output; the backbone
        # is an nn.Sequential and has no ``layer1`` attribute.
        features = x
        low_level_features = None
        for index, stage in enumerate(self.backbone):
            features = stage(features)
            if index == self._LAYER1_INDEX:
                low_level_features = features

        # Multi-scale context on the deepest features.
        aspp_features = self.aspp(features)

        # Main branch: align the ASPP output with the low-level map by size,
        # then fuse, refine and classify.
        low_level_features = self.low_level_conv(low_level_features)
        high_level_features = F.interpolate(
            aspp_features,
            size=low_level_features.size()[2:],
            mode='bilinear',
            align_corners=True
        )
        x = torch.cat([high_level_features, low_level_features], dim=1)
        x = self.decoder_conv(x)
        x = self.classifier(x)
        x = F.interpolate(x, size=input_size, mode='bilinear', align_corners=True)

        if not self.training:
            return x

        # The auxiliary output is only consumed by the training loss, so skip
        # the extra work entirely at inference time.
        aux_output = self.aux_classifier(aspp_features)
        aux_output = F.interpolate(aux_output, size=input_size, mode='bilinear', align_corners=True)
        return x, aux_output


# Convenience factory
def create_deeplabv3plus(model_type='standard', **kwargs):
    """Build a DeepLabV3+ model.

    Args:
        model_type: ``'with_auxiliary'`` selects the variant with an
            auxiliary classifier; any other value selects the standard model.
        **kwargs: Forwarded unchanged to the chosen model's constructor.

    Returns:
        The constructed model instance.
    """
    wants_auxiliary = model_type == 'with_auxiliary'
    model_cls = DeepLabV3PlusWithAuxiliary if wants_auxiliary else DeepLabV3Plus
    return model_cls(**kwargs)

训练时可参考以下示例:

python 复制代码
# Build the model. The auxiliary variant must be requested explicitly:
# only it returns (output, aux_output) in training mode.
model = create_deeplabv3plus(
    model_type='with_auxiliary',
    n_channels=config.NUM_BANDS,
    n_classes=config.NUM_CLASSES,
    backbone='resnet50'
)

# In training mode the auxiliary variant returns two outputs; combine the
# main loss with a down-weighted auxiliary loss.
if model.training:
    output, aux_output = model(x)
    loss = main_loss(output, target) + 0.4 * aux_loss(aux_output, target)
else:
    output = model(x)

推理时(model.eval() 模式下)模型只返回最终输出(logits),与原有评估 / 预测流程兼容,无需修改 eval.py 和 predict.py

相关推荐
AI即插即用21 小时前
即插即用系列 | 2024 SOTA LAM-YOLO : 无人机小目标检测模型
pytorch·深度学习·yolo·目标检测·计算机视觉·视觉检测·无人机
金融小师妹1 天前
基于机器学习与深度强化学习:非农数据触发AI多因子模型预警!12月降息预期骤降的货币政策预测
大数据·人工智能·深度学习·1024程序员节
brave and determined1 天前
可编程逻辑器件学习(day29):Verilog HDL可综合代码设计规范与实践指南
深度学习·fpga开发·verilog·fpga·设计规范·硬件编程·嵌入式设计
算法与编程之美1 天前
提升minist的准确率并探索分类指标Precision,Recall,F1-Score和Accuracy
人工智能·算法·机器学习·分类·数据挖掘
大雷神1 天前
HarmonyOS 横竖屏切换与响应式布局实战指南
python·深度学习·harmonyos
青瓷程序设计1 天前
水果识别系统【最新版】Python+TensorFlow+Vue3+Django+人工智能+深度学习+卷积神经网络算法
人工智能·python·深度学习
AI模块工坊1 天前
CVPR 即插即用 | 当RetNet遇见ViT:一场来自曼哈顿的注意力革命,中科院刷新SOTA性能榜!
人工智能·深度学习·计算机视觉·transformer
强化学习与机器人控制仿真1 天前
Meta 最新开源 SAM 3 图像视频可提示分割模型
人工智能·深度学习·神经网络·opencv·目标检测·计算机视觉·目标跟踪
长不大的蜡笔小新1 天前
从0到1学AlexNet:用经典网络搞定花分类任务
图像处理·深度学习·机器学习
WWZZ20251 天前
快速上手大模型:深度学习5(实践:过、欠拟合)
人工智能·深度学习·神经网络·算法·机器人·大模型·具身智能