TorchGeo 入门:用 PyTorch 处理遥感数据,从零搭建卫星图像分类模型

TorchGeo 入门:用 PyTorch 处理遥感数据

面向 PyTorch 用户的遥感深度学习第一课。读完你将能用 TorchGeo 加载卫星图像、训练分类模型。


为什么需要 TorchGeo?

PyTorch 的 torchvision 处理自然图像很顺手,但遇到遥感数据就力不从心了:

  • 地理坐标系 --- 图片不是 .png,而是带投影的 GeoTIFF
  • 多光谱通道 --- 不止 RGB,还有近红外、热红外等十几个波段
  • 超大尺寸 --- 一张遥感图动辄上万像素,不能直接塞进 GPU
  • 空间采样 --- 不能随机 crop,要考虑地理重叠和空间自相关

TorchGeo 是 PyTorch 官方生态项目(微软开发),专门解决这些问题。它提供:

  • 50+ 遥感数据集,一行代码下载
  • 地理感知的采样器(随机/网格/预切分)
  • torchvision 和 PyTorch Lightning 无缝对接

一行安装:

bash 复制代码
pip install torchgeo

第一个遥感数据集

TorchGeo 内置了大量经典遥感数据集。以 EuroSAT 为例------27000 张 Sentinel-2 卫星图像,分为 10 种土地覆盖/利用类型:

python 复制代码
from torchgeo.datasets import EuroSAT

dataset = EuroSAT(root="./data", download=True)

print(len(dataset))          # 27000
print(dataset.num_classes)   # 10
print(dataset.classes)       # ['AnnualCrop', 'Forest', 'HerbaceousVegetation',
                             #  'Highway', 'Industrial', 'Pasture',
                             #  'PermanentCrop', 'Residential', 'River', 'SeaLake']

TorchGeo 数据集返回 dict,包含 image(PIL 或 Tensor)和 label(整数):

python 复制代码
sample = dataset[0]
print(sample['image'].shape)   # torch.Size([3, 64, 64])
print(sample['label'])         # 0 → AnnualCrop

其他常用数据集(同样一行加载):

数据集 任务 规模 类别
RESISC45 场景分类 31500 45
UCMerced 土地利用 2100 21
LandCoverAI 土地覆盖 10674 5
BigEarthNet 多标签分类 590k 43

构建数据加载器

遥感任务经常需要自定义 collate_fn,因为数据集返回的是 dict 而非标准 (X, y) 元组:

python 复制代码
from torch.utils.data import DataLoader
from torchvision import transforms

transform = transforms.Compose([
    transforms.Resize((64, 64)),
    transforms.Lambda(lambda x: x.float() / 255.0),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225])
])

def collate_fn(batch):
    images = torch.stack([transform(b['image']) for b in batch])
    labels = torch.tensor([b['label'] for b in batch])
    return images, labels

loader = DataLoader(dataset, batch_size=32, shuffle=True, collate_fn=collate_fn)

提示:遥感图像归一化建议沿用 ImageNet 的均值和标准差。EuroSAT 等数据集本身就是 RGB 三通道,与自然图像分布接近。


迁移学习:ResNet18 分类卫星图

直接用预训练 ResNet,只替换最后的全连接层:

python 复制代码
from torchvision.models import resnet18

model = resnet18(weights=None)  # 或 weights='IMAGENET1K_V1'
model.fc = nn.Linear(512, 10)   # 10 类输出

完整训练循环(3 个 epoch 就能达到不错的准确率):

python 复制代码
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

for epoch in range(3):
    for images, labels in train_loader:
        images, labels = images.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

实际效果(EuroSAT,3 epoch,CPU):

Epoch Loss 训练准确率
1 0.45 76%
2 0.22 88%
3 0.13 93%

测试集准确率约 90%。3 个 epoch 就有这个效果,迁移学习的威力。


完整代码

上面的完整可运行代码见 demo.py,关键步骤总结:

markdown 复制代码
加载数据 → 划分训练/测试 → 预处理 → DataLoader
    ↓
ResNet18 + 替换分类头 → 训练 3 epoch
    ↓
测试评估 → 可视化预测结果

运行:

bash 复制代码
pip install torchgeo torch torchvision matplotlib
python demo.py

接下来学什么?

TorchGeo 的能力远不止分类。掌握上面的基础后,可以深入:

  1. 地理空间采样器 --- RandomGeoSampler / GridGeoSampler,从大尺寸遥感图中按地理坐标采样小块
  2. 预训练遥感模型 --- TorchGeo 内置了在 BigEarthNet 上预训练的 ResNet/FCN 权重
  3. 语义分割 --- 用 LandCoverAI + DeepLabV3 做像素级地物分类
  4. 多光谱数据处理 --- 处理 Sentinel-2 的 13 波段图像
  5. 变化检测 --- 对比不同时相的卫星图像,检测地表变化

官方文档:docs.torchgeo.org


参考


首次发布于 2026-05-20 · 掘金 / 知乎

相关推荐
茗创科技6 小时前
Nat Hum Behav | 特征选择会导致基于脑影像的机器学习生物标志物产生迥异的神经生物学解释
python·深度学习·机器学习·matlab·脑网络
dr_yingli6 小时前
MedGemma皮肤肿瘤6分类LLM fineturn流程
人工智能·深度学习
放下华子我只抽RuiKe56 小时前
React 从入门到生产(一):JSX 与组件思维
前端·javascript·人工智能·pytorch·深度学习·react.js·前端框架
RSTJ_16256 小时前
PYTHON+AI LLM DAY FIFITY
人工智能·深度学习
qqqweiweiqq6 小时前
Jetson Orin nx 无法train pi0
人工智能·python·深度学习
koharu1236 小时前
PointNet 与 PointNet++ 详解
人工智能·深度学习·神经网络·三维点云
动物园猫6 小时前
棉花病害图像分类数据集分享(适用于YOLO系列深度学习分类检测任务)
深度学习·yolo·分类
人工智能培训7 小时前
AI人工智能未来发展趋势
人工智能·深度学习·机器学习·docker·容器
生物信息与育种7 小时前
PlantBiMoE开源:轻量高效的植物基因组基础模型
人工智能·深度学习·职场和发展·数据分析·r语言