TorchGeo 入门:用 PyTorch 处理遥感数据,从零搭建卫星图像分类模型

TorchGeo 入门:用 PyTorch 处理遥感数据

面向 PyTorch 用户的遥感深度学习第一课。读完你将能用 TorchGeo 加载卫星图像、训练分类模型。


为什么需要 TorchGeo?

PyTorch 的 torchvision 处理自然图像很顺手,但遇到遥感数据就力不从心了:

  • 地理坐标系 --- 图片不是 .png,而是带投影的 GeoTIFF
  • 多光谱通道 --- 不止 RGB,还有近红外、热红外等十几个波段
  • 超大尺寸 --- 一张遥感图动辄上万像素,不能直接塞进 GPU
  • 空间采样 --- 不能随机 crop,要考虑地理重叠和空间自相关

TorchGeo 是 PyTorch 官方生态项目(微软开发),专门解决这些问题。它提供:

  • 50+ 遥感数据集,一行代码下载
  • 地理感知的采样器(随机/网格/预切分)
  • torchvision 和 PyTorch Lightning 无缝对接

一行安装:

bash 复制代码
pip install torchgeo

第一个遥感数据集

TorchGeo 内置了大量经典遥感数据集。以 EuroSAT 为例------27000 张 Sentinel-2 卫星图像,分为 10 种土地覆盖/利用类型:

python 复制代码
from torchgeo.datasets import EuroSAT

dataset = EuroSAT(root="./data", download=True)

print(len(dataset))          # 27000
print(dataset.num_classes)   # 10
print(dataset.classes)       # ['AnnualCrop', 'Forest', 'HerbaceousVegetation',
                             #  'Highway', 'Industrial', 'Pasture',
                             #  'PermanentCrop', 'Residential', 'River', 'SeaLake']

TorchGeo 数据集返回 dict,包含 image(PIL 或 Tensor)和 label(整数):

python 复制代码
sample = dataset[0]
print(sample['image'].shape)   # torch.Size([3, 64, 64])
print(sample['label'])         # 0 → AnnualCrop

其他常用数据集(同样一行加载):

数据集 任务 规模 类别
RESISC45 场景分类 31500 45
UCMerced 土地利用 2100 21
LandCoverAI 土地覆盖 10674 5
BigEarthNet 多标签分类 590k 43

构建数据加载器

遥感任务经常需要自定义 collate_fn,因为数据集返回的是 dict 而非标准 (X, y) 元组:

python 复制代码
from torch.utils.data import DataLoader
from torchvision import transforms

transform = transforms.Compose([
    transforms.Resize((64, 64)),
    transforms.Lambda(lambda x: x.float() / 255.0),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225])
])

def collate_fn(batch):
    images = torch.stack([transform(b['image']) for b in batch])
    labels = torch.tensor([b['label'] for b in batch])
    return images, labels

loader = DataLoader(dataset, batch_size=32, shuffle=True, collate_fn=collate_fn)

提示:遥感图像归一化建议沿用 ImageNet 的均值和标准差。EuroSAT 等数据集本身就是 RGB 三通道,与自然图像分布接近。


迁移学习:ResNet18 分类卫星图

直接用预训练 ResNet,只替换最后的全连接层:

python 复制代码
from torchvision.models import resnet18

model = resnet18(weights=None)  # 或 weights='IMAGENET1K_V1'
model.fc = nn.Linear(512, 10)   # 10 类输出

完整训练循环(3 个 epoch 就能达到不错的准确率):

python 复制代码
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

for epoch in range(3):
    for images, labels in train_loader:
        images, labels = images.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

实际效果(EuroSAT,3 epoch,CPU):

Epoch Loss 训练准确率
1 0.45 76%
2 0.22 88%
3 0.13 93%

测试集准确率约 90%。3 个 epoch 就有这个效果,迁移学习的威力。


完整代码

上面的完整可运行代码见 demo.py,关键步骤总结:

markdown 复制代码
加载数据 → 划分训练/测试 → 预处理 → DataLoader
    ↓
ResNet18 + 替换分类头 → 训练 3 epoch
    ↓
测试评估 → 可视化预测结果

运行:

bash 复制代码
pip install torchgeo torch torchvision matplotlib
python demo.py

接下来学什么?

TorchGeo 的能力远不止分类。掌握上面的基础后,可以深入:

  1. 地理空间采样器 --- RandomGeoSampler / GridGeoSampler,从大尺寸遥感图中按地理坐标采样小块
  2. 预训练遥感模型 --- TorchGeo 内置了在 BigEarthNet 上预训练的 ResNet/FCN 权重
  3. 语义分割 --- 用 LandCoverAI + DeepLabV3 做像素级地物分类
  4. 多光谱数据处理 --- 处理 Sentinel-2 的 13 波段图像
  5. 变化检测 --- 对比不同时相的卫星图像,检测地表变化

官方文档:docs.torchgeo.org


参考


首次发布于 2026-05-20 · 掘金 / 知乎

相关推荐
Lihua奏1 天前
从单核到多核:CPU为什么不能再只靠提频变快
深度学习
拾年2751 天前
大模型的"聪明"从哪来?聊聊 AI 数据集的那些事儿
人工智能·深度学习·机器学习
饼干哥哥6 天前
开源Skills|搭建亚马逊动态关键词库系统,每天抓SSS级机会词
人工智能·深度学习·数据分析
武子康7 天前
调查研究-191 SenseVoice 不只是 ASR:把语音从“转文字“升级成“理解状态“
人工智能·深度学习·openai
武子康9 天前
调查研究-189 Kronos 调研:金融 K 线基础模型,是真突破,还是量化圈的新玩具?
人工智能·深度学习·openai
xiao5kou4chang6kai414 天前
MATLAB机器学习、深度学习--从数据预处理到模型训练
深度学习·机器学习·matlab·数据预处理
renhongxia114 天前
世界模型作为AGI落地底层底座的作用
人工智能·深度学习·生成对抗网络·自然语言处理·知识图谱·agi
计算机科研狗@OUC14 天前
(cvpr26) AIMDepth: Asymmetric Image-Event Mamba for Monocular Depth Estimation
人工智能·深度学习·计算机视觉
β添砖java14 天前
深度学习(22)网络中的网络NiN
人工智能·深度学习
Kobebryant-Manba14 天前
深度学习时候d2l报错和使用问题
人工智能·深度学习