简述:地理深度学习全域训练PyTorch2.7+TorchGeo等基线

(PyTorch 2.7.0 + TorchGeo + U-TAE + RF/XGBoost 基线)
技术栈:
PyTorch 2.7.0(官方稳定版)
Python 3.10+
CUDA 11.8/12.6/12.8
TorchGeo:遥感数据加载 / 采样 / 变换
传统 ML:RandomForest + XGBoost 做基线模型
主力模型:U-TAE(时序遥感 / 变化检测 SOTA,适配自然资源调查 + 地形检测)
输出可对接:PostGIS(空间入库)+ GeoServer(标准服务发布)
一、环境安装(一键复制)
bash
# PyTorch 2.7.0 built against CUDA 12.6.
# PyTorch 2.7 publishes cu118 / cu126 / cu128 wheels only; the cu121 index has no
# 2.7.0 build, so an unpinned install from it silently resolves to an older torch.
# Swap cu126 for cu118 or cu128 to match the local driver.
pip3 install torch==2.7.0 torchvision==0.22.0 torchaudio==2.7.0 --index-url https://download.pytorch.org/whl/cu126
# Core dependencies
pip install torchgeo scikit-learn xgboost rasterio geopandas einops pytorch-lightning
二、完整训练代码(可直接运行)
python
import os

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from pytorch_lightning import LightningModule, Trainer
from pytorch_lightning.callbacks import ModelCheckpoint
# 1) TorchGeo: remote-sensing datasets, samplers and spatial transforms
from torchgeo.datasets import (
    IntersectionDataset,
    Landsat9,
    RasterDataset,
    Sentinel2,
    stack_samples,
)
from torchgeo.samplers import RandomGeoSampler
from torchgeo.transforms import AugmentationSequential
import torchgeo.transforms as T
# 2) Classical baseline models: RF + XGBoost
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import accuracy_score, f1_score
from sklearn.model_selection import train_test_split
import xgboost as xgb
# 3) Tensor reshaping used by the attention bottleneck
from einops import rearrange
# ==============================================
# 【一、数据模块:基于TorchGeo加载自然资源/遥感数据】
# ==============================================
class _MaskRaster(RasterDataset):
    """Label rasters that TorchGeo should return under ``sample["mask"]``."""

    # NOTE(review): assumes the labels are GeoTIFF files -- adjust the glob if not.
    filename_glob = "*.tif"
    is_image = False


def _random_hflip(sample, p=0.5):
    """Mirror image and mask together along the width axis with probability *p*."""
    if torch.rand(1).item() < p:
        sample["image"] = sample["image"].flip(-1)
        sample["mask"] = sample["mask"].flip(-1)
    return sample


def _collate_image_mask(samples):
    """Stack patches into a batch and keep only the tensors the models read.

    The geo-referencing entries TorchGeo attaches to each sample are not
    tensors; they are dropped here so the batch can be moved to a device as-is.
    """
    batch = stack_samples(samples)
    return {"image": batch["image"], "mask": batch["mask"]}


def get_torchgeo_dataloader(
    batch_size=8,
    num_workers=4,
    image_root="./data/sentinel2",
    mask_root=None,
    patch_size=256,
):
    """Build a DataLoader of co-registered Sentinel-2 / label patches.

    Args:
        batch_size: Number of patches per batch.
        num_workers: DataLoader worker processes.
        image_root: Directory holding the Sentinel-2 band files.
        mask_root: Directory holding the label rasters. ``None`` means the
            ``mask`` sub-folder of ``image_root`` (assumption carried over from
            the original ``masks="mask"`` argument -- confirm against the
            actual data layout).
        patch_size: Side length, in pixels, of the randomly sampled patches.

    Returns:
        ``(loader, dataset)`` where each batch is a dict with ``"image"`` and
        ``"mask"`` tensors and ``dataset`` is the image/label intersection.
    """
    if mask_root is None:
        mask_root = os.path.join(image_root, "mask")

    # Sentinel-2 band identifiers are zero-padded: blue, green, red, NIR.
    images = Sentinel2(paths=image_root, bands=["B02", "B03", "B04", "B08"])
    labels = _MaskRaster(paths=mask_root)

    # Only areas covered by both imagery and labels can be sampled.
    # The original Normalize(mean=0, std=1) was an identity and is omitted.
    dataset = IntersectionDataset(images, labels, transforms=_random_hflip)

    # Geo datasets are indexed by bounding boxes, not integers, so a geo
    # sampler replaces ``shuffle=True``; fixed-size patches replace the resize.
    sampler = RandomGeoSampler(dataset, size=patch_size)
    loader = DataLoader(
        dataset,
        batch_size=batch_size,
        sampler=sampler,
        num_workers=num_workers,
        collate_fn=_collate_image_mask,
        pin_memory=True,
    )
    return loader, dataset
# ==============================================
# 【二、基线模型:RandomForest + XGBoost】
# ==============================================
def train_baseline_models(loader, test_size=0.2, random_state=42):
    """Fit per-pixel RandomForest and XGBoost baselines and report held-out scores.

    Every pixel whose label is non-zero becomes one training sample whose
    features are the band values at that pixel.

    Args:
        loader: Iterable of batches; each batch is a mapping with an ``"image"``
            tensor of shape (B, C, H, W) and a ``"mask"`` tensor holding
            B * H * W labels.
        test_size: Fraction of the valid pixels held out for evaluation.
        random_state: Seed for the split and for both models.

    Returns:
        ``(rf, xgb_model)``. The random forest predicts the raw mask labels.
        XGBoost is fitted on labels re-indexed to ``0..K-1`` (position in the
        sorted unique valid labels), so its predictions are indices, not raw
        mask values.

    Raises:
        ValueError: If the loader yields no batch or no pixel has a non-zero label.
    """
    feature_chunks, label_chunks = [], []
    for batch in loader:
        img = batch["image"].numpy()
        mask = batch["mask"].numpy()
        # (B, C, H, W) -> (B*H*W, C): one row of band values per pixel.
        pixels = np.moveaxis(img, 1, -1).reshape(-1, img.shape[1])
        pixel_labels = mask.reshape(-1)
        # Label 0 is treated as "no data" and excluded.
        valid = pixel_labels != 0
        feature_chunks.append(pixels[valid])
        label_chunks.append(pixel_labels[valid])

    if not feature_chunks:
        raise ValueError("loader yielded no batches")
    X = np.concatenate(feature_chunks, axis=0)
    y = np.concatenate(label_chunks, axis=0)
    if y.size == 0:
        raise ValueError("no pixels with a non-zero label were found")

    # XGBoost requires class ids 0..K-1; after dropping label 0 the raw labels
    # start at 1, so re-index them and keep the lookup table.
    classes, y_idx, counts = np.unique(y, return_inverse=True, return_counts=True)
    y_idx = y_idx.reshape(-1)

    # Score on pixels the models never saw; scoring the training pixels (as the
    # original did) reports near-perfect numbers regardless of model quality.
    # NOTE: pixels from the same patch can land on both sides of the split.
    stratify = y_idx if counts.min() >= 2 else None
    X_train, X_test, y_train, y_test = train_test_split(
        X, y_idx, test_size=test_size, random_state=random_state, stratify=stratify
    )

    # 1. Random Forest (trained on the raw labels)
    rf = RandomForestClassifier(n_estimators=100, n_jobs=-1, random_state=random_state)
    rf.fit(X_train, classes[y_train])
    rf_pred = rf.predict(X_test)
    rf_acc = accuracy_score(classes[y_test], rf_pred)
    rf_f1 = f1_score(classes[y_test], rf_pred, average="macro")

    # 2. XGBoost (trained on the re-indexed labels)
    xgb_model = xgb.XGBClassifier(n_estimators=100, random_state=random_state)
    xgb_model.fit(X_train, y_train)
    xgb_pred = xgb_model.predict(X_test)
    xgb_acc = accuracy_score(y_test, xgb_pred)
    xgb_f1 = f1_score(y_test, xgb_pred, average="macro")

    print("\n===== 基线模型结果 =====")
    print(f"RF 精度: {rf_acc:.4f}, F1: {rf_f1:.4f}")
    print(f"XGB 精度: {xgb_acc:.4f}, F1: {xgb_f1:.4f}")
    return rf, xgb_model
# ==============================================
# 【三、主力模型:U-TAE(遥感时序变化检测SOTA)】
# ==============================================
class TimeAttentionBlock(nn.Module):
    """Pre-norm multi-head self-attention with a residual connection.

    The block attends along the second axis of a ``(batch, tokens, dim)``
    input; what the tokens represent (time steps, spatial positions, ...) is
    decided by the caller.

    Args:
        dim: Embedding size of each token.
        heads: Number of attention heads.
    """

    def __init__(self, dim, heads=4):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.attn = nn.MultiheadAttention(dim, heads, batch_first=True)

    def forward(self, x):
        """Return ``x + SelfAttention(LayerNorm(x))`` with the same shape as ``x``."""
        # Normalise once and reuse the result as query, key and value.
        normed = self.norm(x)
        # The attention map is never used, so do not materialise it: it is
        # (tokens x tokens) per sample and dominates memory for long sequences.
        attended, _ = self.attn(normed, normed, normed, need_weights=False)
        return x + attended
class UTAE(LightningModule):
    """Small U-Net with a self-attention bottleneck for semantic segmentation.

    The network has three single-convolution encoder stages, two decoder
    stages with skip connections and a 1x1 classification head. At the
    bottleneck every spatial position of the feature map becomes one token for
    ``TimeAttentionBlock``. The input is a single image ``(B, C, H, W)``; this
    implementation has no temporal axis.

    Args:
        in_channels: Number of input bands.
        num_classes: Number of output classes (channels of the logits).
        lr: Learning rate for AdamW.
    """

    def __init__(self, in_channels=4, num_classes=5, lr=1e-4):
        super().__init__()
        self.save_hyperparameters()
        # Encoder
        self.enc1 = self.conv_block(in_channels, 64)
        self.enc2 = self.conv_block(64, 128)
        self.enc3 = self.conv_block(128, 256)
        # Attention over bottleneck tokens
        self.time_attn = TimeAttentionBlock(256)
        # Decoder (input widths include the concatenated skip connection)
        self.dec3 = self.conv_block(256 + 128, 128)
        self.dec2 = self.conv_block(128 + 64, 64)
        self.out = nn.Conv2d(64, num_classes, 1)
        self.loss_fn = nn.CrossEntropyLoss()

    @staticmethod
    def conv_block(in_ch, out_ch):
        """Return a 3x3 convolution -> batch norm -> ReLU stage."""
        return nn.Sequential(
            nn.Conv2d(in_ch, out_ch, 3, padding=1),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        """Map ``(B, in_channels, H, W)`` to logits ``(B, num_classes, H, W)``."""
        # Encoder: halve the resolution between stages.
        e1 = self.enc1(x)
        e2 = self.enc2(nn.functional.max_pool2d(e1, 2))
        e3 = self.enc3(nn.functional.max_pool2d(e2, 2))

        # Bottleneck: flatten the spatial grid into tokens, attend, restore.
        h, w = e3.shape[-2:]
        tokens = rearrange(e3, "b c h w -> b (h w) c")
        tokens = self.time_attn(tokens)
        feat = rearrange(tokens, "b (h w) c -> b c h w", h=h, w=w)

        # Decoder: upsample to the skip tensor's exact size rather than by a
        # fixed factor of 2, so inputs whose sides are not multiples of 4 work.
        up3 = nn.functional.interpolate(feat, size=e2.shape[-2:], mode="nearest")
        d3 = self.dec3(torch.cat([up3, e2], dim=1))
        up2 = nn.functional.interpolate(d3, size=e1.shape[-2:], mode="nearest")
        d2 = self.dec2(torch.cat([up2, e1], dim=1))
        return self.out(d2)

    def training_step(self, batch, _):
        """Compute and log the cross-entropy loss for one batch."""
        logits = self(batch["image"])
        target = batch["mask"]
        # CrossEntropyLoss wants (B, H, W) class indices; drop a singleton
        # channel axis if the loader delivers (B, 1, H, W).
        if target.ndim == logits.ndim:
            target = target.squeeze(1)
        loss = self.loss_fn(logits, target.long())
        self.log("train_loss", loss)
        return loss

    def configure_optimizers(self):
        """Return AdamW over all parameters with the configured learning rate."""
        return optim.AdamW(self.parameters(), lr=self.hparams.lr)
# ==============================================
# 【四、主训练流程】
# ==============================================
def main():
    """Run the full pipeline: load data, fit baselines, train and save the network."""
    # Environment report
    print("PyTorch 版本:", torch.__version__)
    print("CUDA 可用:", torch.cuda.is_available())
    print("CUDA 版本:", torch.version.cuda)

    # 1. Load data
    train_loader, _ = get_torchgeo_dataloader(batch_size=8)

    # 2. Classical baselines (RF + XGBoost)
    train_baseline_models(train_loader)

    # 3. Main model
    model = UTAE(in_channels=4, num_classes=5)
    checkpoint = ModelCheckpoint(monitor="train_loss", save_top_k=1)

    # Requesting a GPU unconditionally aborts on CPU-only machines, and fp16
    # autocast is a CUDA feature, so fall back to CPU / full precision there.
    use_gpu = torch.cuda.is_available()
    trainer = Trainer(
        max_epochs=20,
        accelerator="gpu" if use_gpu else "cpu",
        devices=1,
        callbacks=[checkpoint],
        precision="16-mixed" if use_gpu else "32-true",
    )
    print("\n===== 开始训练 U-TAE 主力模型 =====")
    trainer.fit(model, train_loader)

    # 4. Save the weights for later inference / publishing
    torch.save(model.state_dict(), "utae_terrain_change_model.pth")
    print("\n模型已保存:utae_terrain_change_model.pth")
    print("可对接:PostGIS空间入库 → GeoServer发布 → 二三维平台展示")


if __name__ == "__main__":
    main()
三、代码核心说明
1. 技术栈完全匹配你的要求
PyTorch 2.7.0 + CUDA 11.8/12.x
TorchGeo:负责自然资源调查 / 遥感影像加载、切片、空间变换
U-TAE:主力模型,面向多时序影像 + 变化检测(国土业务最优)。注意:上文示例代码是简化实现,输入为单期影像 (B, C, H, W),注意力作用于瓶颈层特征图的空间位置;若要处理多期时序输入,需要自行增加时间维。
RF/XGBoost:基线模型,用于精度对比、业务验收基准
2. 业务落地流程
训练:用 TorchGeo 读取自然资源调查数据 + 卫星影像
推理:U-TAE 输出地形 / 地类变化图
入库:结果转矢量 → PostGIS建空间索引
发布:GeoServer发布 WMS/WMTS 服务
展示:接入你的二维 / 三维实景平台
3. 直接适配你的项目
任务:自然资源调查 + 变化快速检测
输入:Sentinel/Landsat/DEM/ 自然资源图斑
输出:变化区域分类图(建设用地 / 耕地 / 地形起伏变化)
本blog地址:https://blog.csdn.net/hsg77