前言
上期第10期,我们吃透了CNN卷积、池化、激活函数三大底层逻辑,厘清了新手最大误区:通用RGB图像模型,不可直接照搬用于遥感俯拍影像训练。
基础堆叠CNN网络极易出现两大遥感训练通病:网络越深梯度消失、遥感地物特征流失,小样本遥感数据集训练直接过拟合、精度崩盘。
由此衍生出两大遥感领域封神分类骨干网络:ResNet残差网络、DenseNet密集连接网络 。截至目前,二者依旧是遥感土地利用分类、场景分类、长时序地物监测的论文基线首选模型,适配绝大多数本科毕设、小范围工程项目、期刊对比实验。
本期结合遥感专属特性,通俗拆解两大网络架构、落地遥感场景选型逻辑、手把手完成数据集加载→模型训练→超参调优→结果后处理→精度优化全流程,附带遥感专属训练代码,解决训练不收敛、Kappa偏低、过拟合、地物混分全痛点。
一、遥感两大分类任务区分+精准模型选型
很多新手混淆遥感分类概念,导致模型选错、标注白费。开篇先界定行业两大标准分类任务,精准匹配ResNet/DenseNet适用场景。
1.1 遥感两大图像分类任务定义
① 遥感场景分类(Scene Classification)
以整张512/256遥感切片为单位,判定整片区域属性,一图一类。
✅ 适用场景:城区/林地/水域/耕地全域研判、卫星影像归档、灾害全域判别
② 土地利用像素分类(Land Use Classification)
依托骨干网络提取特征,配合解码头,完成像素级土地覆被判别,对接国土三调分类标准。
✅ 适用场景:土地利用普查、生态覆被监测、耕地非农化研判、毕设主流选题
1.2 ResNet VS DenseNet 遥感场景选型对照表
| 网络模型 | 遥感优势 | 短板局限 | 最优适用场景 |
|---|---|---|---|
| ResNet系列 | 残差旁路防梯度消失、算力开销低、遥感预训练权重丰富、调参简单、收敛速度快 | 浅层特征复用率低,细碎地物表现力一般 | 常规土地分类、大数据集、快速实验、期刊基线、新手首选(ResNet18/34) |
| DenseNet系列 | 全层级特征复用、边缘纹理保留极强、抗云雾干扰、小样本精度更高 | 参数量大、显存占用高、训练耗时久、调参门槛高 | 云雾多发区域、细碎地物分类、小样本数据集、高精度科研实验 |
新手选型定论:零基础、显存一般、赶毕设进度 → 优先ResNet18;小样本、影像多云雾、追求高精度 → 选用DenseNet121
二、两大骨干网络通俗结构解读
基于上期CNN基础,零复杂公式,拆解网络核心改良逻辑,弄懂为什么适配遥感影像。
2.1 ResNet残差网络:解决深层网络特征退化
核心痛点:普通CNN堆叠层数越多,遥感河道、田块、小路浅层特征越容易丢失,梯度反向传播失效,模型精度反向下降。
核心创新:残差旁路连接
新增直连旁路,跳过卷积层直接传递原始遥感浅层纹理、边缘特征,做到:卷积学习深层光谱特征、旁路保留浅层地物轮廓,完美适配遥感多尺度地物特征。
遥感常用版本:ResNet18(轻量)、ResNet34(均衡)、ResNet50(高精度)
2.2 DenseNet密集连接网络:最大化复用遥感特征
核心创新:层与层全连接
每一层卷积特征,全部传递给后续所有网络层,最大限度保留遥感微弱特征:山间小路、零散光伏、田间裸土等易丢失小目标特征,天然适配遥感同谱异物、小目标密集难题。
遥感常用版本:DenseNet121(性价比最高)
🖼️配图1:ResNet+DenseNet简化结构对比图
科技蓝扁平化双栏结构图,左侧ResNet残差跳转结构标注、右侧DenseNet密集连线结构标注,中文标注特征流向、旁路作用,适配公众号排版。
三、遥感数据集专属训练全流程+参数调优
3.1 遥感分类数据集标准格式
沿用专栏统一数据集格式,适配往期切片数据,无需重新标注:
bash
rs_class_dataset/
├─train/ #训练集切片
├─val/ #验证集切片
└─test/ #测试集切片
3.2 环境前置配置(国内清华镜像零报错)
bash
pip install torch torchvision pillow numpy rasterio -i https://pypi.tuna.tsinghua.edu.cn/simple/
3.3 ResNet遥感专属完整训练代码(修改通道即可运行)
适配3波段RGB/4波段多光谱遥感影像,修改首层卷积通道,规避通用模型适配bug,自带日志保存、精度计算、模型保存功能
python
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import models, datasets, transforms
from torch.utils.data import DataLoader
import os
# =====================遥感专属超参配置区【直接修改】=====================
DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
IN_CHANNEL = 4 # 3=RGB影像 4=哨兵多光谱影像
CLASS_NUM = 5 # 地物类别:水体/建筑/植被/道路/耕地
BATCH_SIZE = 16
EPOCHS = 50
LR = 1e-3 # 遥感分类最优学习率
DATA_PATH = "./rs_class_dataset"
# =====================================================================
# 遥感影像专属数据增强(拒绝无效翻转)
train_transform = transforms.Compose([
transforms.Resize((256,256)),
transforms.RandomHorizontalFlip(p=0.5),
transforms.ToTensor(),
])
val_transform = transforms.Compose([
transforms.Resize((256,256)),
transforms.ToTensor(),
])
# 加载数据集
train_dataset = datasets.ImageFolder(os.path.join(DATA_PATH,"train"),transform=train_transform)
val_dataset = datasets.ImageFolder(os.path.join(DATA_PATH,"val"),transform=val_transform)
train_loader = DataLoader(train_dataset,batch_size=BATCH_SIZE,shuffle=True)
val_loader = DataLoader(val_dataset,batch_size=BATCH_SIZE,shuffle=False)
# 加载ResNet18,修改首层适配多波段遥感影像
model = models.resnet18(pretrained=False)
# 核心修改:适配4波段输入,解决通用模型通道报错
model.conv1 = nn.Conv2d(IN_CHANNEL,64,kernel_size=7,stride=2,padding=3,bias=False)
model.fc = nn.Linear(model.fc.in_features,CLASS_NUM)
model = model.to(DEVICE)
# 遥感专属损失函数+优化器
criterion = nn.CrossEntropyLoss()
optimizer = optim.AdamW(model.parameters(),lr=LR,weight_decay=5e-4) #权重衰减防遥感过拟合
# 训练循环
def train_one_epoch():
model.train()
total_loss = 0
for img,label in train_loader:
img,label = img.to(DEVICE),label.to(DEVICE)
optimizer.zero_grad()
pred = model(img)
loss = criterion(pred,label)
loss.backward()
optimizer.step()
total_loss += loss.item()
return total_loss/len(train_loader)
# 验证评估
def val_one_epoch():
model.eval()
correct = 0
with torch.no_grad():
for img,label in val_loader:
img,label = img.to(DEVICE),label.to(DEVICE)
pred = model(img)
correct += (torch.argmax(pred,dim=1)==label).sum().item()
acc = correct/len(val_dataset)
return acc
# 启动训练
if __name__ == "__main__":
best_acc = 0
for epoch in range(EPOCHS):
train_loss = train_one_epoch()
val_acc = val_one_epoch()
print(f"第{epoch+1}轮|训练损失:{train_loss:.4f}|验证精度:{val_acc:.4f}")
# 保存最优权重
if val_acc > best_acc:
best_acc = val_acc
torch.save(model.state_dict(),"./rs_resnet_best.pth")
print("✅ 训练结束,最优模型已保存!")
3.4 遥感模型黄金调参手册(解决不收敛/精度低)
-
学习率LR:遥感固定1e-3 ~ 5e-4,大于1e-3梯度爆炸,小于1e-4收敛极慢
-
权重衰减weight_decay:必须开启5e-4,抑制遥感纹理噪声过拟合
-
批次BatchSize:2060/3060显卡设8-16,4090显卡可设32
-
预训练权重:放弃ImageNet权重,选用BigEarthNet遥感预训练权重
-
早停机制:验证集精度10轮不上涨,直接停止训练,防止过拟合
🖼️配图2:模型训练损失精度曲线图
标准双轴折线图:蓝色训练损失递减曲线、橙色验证精度上升曲线,标注收敛节点、过拟合拐点,直观判断训练状态。
四、分类结果后处理+落地精度优化方案
模型输出原始分类图普遍存在椒盐噪点、细碎错分斑块,遥感项目、论文出图必须做后处理,大幅提升OA、Kappa系数。
4.1 遥感三大后处理手段(行业通用)
-
形态学滤波去噪:剔除孤立单点错分像素,消除遥感椒盐噪声
-
连通域面积阈值筛选:剔除小于阈值的细碎斑块,贴合地物连片分布特征
-
同类别邻域平滑:优化地物边界锯齿,贴合目视解译边界标准
4.2 简易OpenCV后处理优化代码
python
import cv2
import numpy as np
# 遥感分类图后处理去噪优化
def rs_post_process(pred_mask):
# 形态学开运算:去除小白点噪点
kernel = cv2.getStructuringElement(cv2.MORPH_RECT,(3,3))
new_mask = cv2.morphologyEx(pred_mask,cv2.MORPH_OPEN,kernel)
# 连通域剔除细碎斑块
contours,_ = cv2.findContours(new_mask,cv2.RETR_EXTERNAL,cv2.CHAIN_APPROX_SIMPLE)
for cnt in contours:
if cv2.contourArea(cnt) < 20: #小于20像素斑块直接剔除
cv2.fillPoly(new_mask,[cnt],0)
return new_mask
4.3 精度低效万能优化步骤
-
清洗数据集:剔除云雾、畸变无效切片,均衡各类样本数量
-
更换主干:细碎地物精度低,ResNet切换DenseNet
-
特征补强:叠加NDVI/NDWI指数特征,辅助模型区分水体植被
-
后处理降噪:必加形态学滤波,Kappa普遍提升0.03-0.08
🖼️配图3:分类优化前后对比效果图
左右分栏对比图:左图原始模型分类图(噪点多、边界锯齿、细碎错分);右图后处理优化效果图(边界平滑、噪点清零、地物规整),附带OA/Kappa数值对比标注。
✅ 本期全文核心总结
-
ResNet靠残差旁路防梯度消失,适合大数据、快速分类;DenseNet全特征复用,适合小样本、云雾遥感影像
-
遥感训练必改首层卷积通道,禁止直接使用RGB通用预训练权重
-
遥感专属超参:学习率1e-3、开启权重衰减、小batchsize稳定训练
-
分类后处理是提分关键,滤波+连通域筛选,低成本提升分类精度
📌 下期预告
遥感语义分割开山模型:U-Net结构详解+遥感像素级分割从零训练,适配512大图分割,打通精细化地物提取全流程!