引言
在遥感与地理信息领域,卫星图像分类是核心任务之一,广泛应用于土地利用监测、灾害评估、城市规划等场景。Kaggle作为全球数据科学竞赛平台,提供了丰富的卫星图像数据集(如https://www.kaggle.com/c/planet-understanding-the-amazon-from-space),是学习深度学习实战的优质资源。本文将以Kaggle卫星图像分类比赛为背景,结合深度学习技术,详细解析从数据预处理到模型调优的全流程,并通过代码案例深入探讨关键技巧。
一、核心概念与应用场景
卫星图像分类指通过算法对卫星拍摄的遥感影像(如多光谱、全色波段)进行像素级或对象级的类别划分(如森林、水体、建筑)。其挑战在于:
- 数据复杂性:卫星图像包含多波段(RGB+NIR等)、高分辨率(单图可达数MB)、光照/天气干扰;
- 类别不平衡:某些地物(如道路)样本远少于其他类别(如森林);
- 小目标识别:如车辆、船只等微小对象需高空间分辨率支持。
典型应用场景包括:
- 环境监测:识别非法砍伐区域(森林→裸地类别变化);
- 灾害响应:快速定位地震后的建筑物损毁区域;
- 农业管理:区分作物类型与病虫害区域。
二、Kaggle比赛数据特点与预处理技巧
关键预处理步骤(代码实现)
1. 数据加载与标签解析
原始数据为JPEG格式图像+CSV标签文件(每行对应一张图,标签以空格分隔,如"haze primary"
)。需将文本标签转换为多标签分类所需的二进制矩阵:
import pandas as pd
from sklearn.preprocessing import MultiLabelBinarizer
# 加载标签文件
train_df = pd.read_csv('train_v2.csv')
# 标签列拆分为列表(如"haze primary"→['haze', 'primary'])
train_df['tags'] = train_df['tags'].apply(lambda x: x.split(' '))
# 多标签二值化:将文本标签转为0/1矩阵
mlb = MultiLabelBinarizer()
labels = mlb.fit_transform(train_df['tags']) # 形状:(样本数, 类别数)
print(f"总类别数: {len(mlb.classes_)}") # 输出例如17类
2. 图像增强与归一化
卫星图像常存在光照差异,需通过增强提升模型鲁棒性。使用albumentations
库(专为遥感优化的增强库)实现随机旋转、HSV调整等:
import albumentations as A
from albumentations.pytorch import ToTensorV2
# 定义训练集增强:随机水平翻转+旋转+色彩抖动
train_transform = A.Compose([
A.HorizontalFlip(p=0.5),
A.RandomRotate90(p=0.5),
A.RandomBrightnessContrast(p=0.2),
A.HueSaturationValue(p=0.2),
A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), # ImageNet均值/标准差
ToTensorV2()
])
# 验证集仅做归一化(避免引入偏差)
val_transform = A.Compose([
A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
ToTensorV2()
])
三、深度学习模型构建与核心技巧
模型选择:ResNet50+多标签分类头
卫星图像分类通常采用预训练的CNN(如ResNet、EfficientNet)作为主干网络,修改最后的全连接层适配多标签任务(输出层用Sigmoid激活函数而非Softmax)。
import torch
import torch.nn as nn
from torchvision.models import resnet50
class SatelliteClassifier(nn.Module):
def __init__(self, num_classes):
super().__init__()
# 加载预训练ResNet50(去掉原分类头)
self.backbone = resnet50(weights='IMAGENET1K_V1')
in_features = self.backbone.fc.in_features # 获取主干输出维度(2048)
# 替换分类头:全连接层+ReLU+Dropout+Sigmoid
self.classifier = nn.Sequential(
nn.Linear(in_features, 512),
nn.ReLU(),
nn.Dropout(0.5), # 防止过拟合
nn.Linear(512, num_classes),
nn.Sigmoid() # 多标签分类需Sigmoid输出概率
)
def forward(self, x):
features = self.backbone(x) # 主干提取特征(形状:[batch, 2048])
logits = self.classifier(features) # 输出每个类别的概率(形状:[batch, 17])
return logits
核心训练技巧
-
损失函数 :多标签分类需使用
BCEWithLogitsLoss
(内置Sigmoid与二元交叉熵,数值更稳定),但本例中分类头已显式使用Sigmoid,故改用BCELoss
:criterion = nn.BCELoss() # 输入为Sigmoid后的概率,目标为0/1标签
-
优化器 :AdamW(带权重衰减)比传统Adam更抗过拟合:
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4, weight_decay=1e-5)
-
学习率调度 :余弦退火(CosineAnnealingLR)动态调整学习率,避免陷入局部最优:
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=10, eta_min=1e-6)
四、详细代码分析:训练循环与验证逻辑(重点,>500字)
以下代码完整实现了训练与验证流程,重点解析数据加载、前向传播、损失计算与指标评估:
from torch.utils.data import DataLoader, Dataset
import numpy as np
from sklearn.metrics import f1_score
# 自定义Dataset类:加载图像与标签
class SatelliteDataset(Dataset):
def __init__(self, image_paths, labels, transform=None):
self.image_paths = image_paths
self.labels = labels
self.transform = transform
def __len__(self):
return len(self.image_paths)
def __getitem__(self, idx):
# 读取图像(PIL格式)
image = Image.open(self.image_paths[idx]).convert('RGB')
label = self.labels[idx]
# 应用图像增强
if self.transform:
image = self.transform(image=np.array(image))['image'] # albumentations返回字典
return image, torch.FloatTensor(label) # 标签转为FloatTensor(适配BCELoss)
# 初始化数据集与DataLoader
train_dataset = SatelliteDataset(train_image_paths, labels, transform=train_transform)
val_dataset = SatelliteDataset(val_image_paths, val_labels, transform=val_transform)
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True, num_workers=4)
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False, num_workers=4)
# 训练与验证主循环
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = SatelliteClassifier(num_classes=len(mlb.classes_)).to(device)
best_f1 = 0.0
for epoch in range(10): # 训练10轮
model.train()
train_loss = 0.0
for images, targets in train_loader:
images, targets = images.to(device), targets.to(device)
# 前向传播:获取预测概率
outputs = model(images) # 形状:[batch, 17]
# 计算损失(BCELoss要求输入与目标均为0-1概率/标签)
loss = criterion(outputs, targets)
# 反向传播与优化
optimizer.zero_grad()
loss.backward()
optimizer.step()
train_loss += loss.item() * images.size(0) # 累加batch损失
# 计算平均训练损失
train_loss /= len(train_loader.dataset)
# 验证阶段:计算F1分数(多标签分类常用指标)
model.eval()
all_preds, all_targets = [], []
with torch.no_grad():
for images, targets in val_loader:
images, targets = images.to(device), targets.to(device)
outputs = model(images)
# 将概率阈值化为0/1(阈值0.5)
preds = (outputs > 0.5).float()
all_preds.append(preds.cpu().numpy())
all_targets.append(targets.cpu().numpy())
all_preds = np.concatenate(all_preds, axis=0) # 合并所有batch预测
all_targets = np.concatenate(all_targets, axis=0) # 合并所有batch真实标签
# 计算每个类别的F1,再取宏平均(Macro-F1)
f1 = f1_score(all_targets, all_preds, average='macro')
print(f'Epoch {epoch+1}: Train Loss={train_loss:.4f}, Val F1={f1:.4f}')
# 保存最佳模型(基于验证F1)
if f1 > best_f1:
best_f1 = f1
torch.save(model.state_dict(), 'best_model.pth')
代码关键点解析:
- 数据流控制 :通过自定义
Dataset
类封装图像读取与增强逻辑,DataLoader
实现多进程并行加载(num_workers=4
加速IO); - 损失计算 :
BCELoss
直接作用于Sigmoid输出的概率与真实标签(0/1),避免了Softmax的多分类约束; - 指标评估 :多标签分类常用宏平均F1(Macro-F1)(各类别F1的算术平均),相比准确率更能反映少数类别的性能;
- 模型保存 :仅保留验证集表现最好的模型(
best_model.pth
),防止过拟合训练集。
五、未来发展趋势
- 多模态融合:结合卫星图像(RGB/NIR)与雷达数据(如Sentinel-1的SAR),提升云层遮挡下的分类鲁棒性;
- Transformer替代CNN:Vision Transformer(ViT)通过全局注意力机制捕捉长距离依赖,在大尺寸卫星图像中表现优于CNN;
- 实时处理:边缘计算部署轻量化模型(如MobileNetV3+知识蒸馏),实现灾情现场的即时分类;
- 自监督预训练 :利用无标签卫星图像(如https://ssl4eo.org/)通过对比学习预训练主干网络,减少对标注数据的依赖。