简述:地理深度学习全域训练PyTorch2.7+TorchGeo等基线

简述:地理深度学习全域训练PyTorch2.7+TorchGeo等基线

PyTorch 2.7.0 + TorchGeo + U-TAE + RF/XGBoost 基线)

技术栈:

PyTorch 2.7.0(官方稳定版)

Python 3.10+

CUDA 11.8/12.6/12.8

TorchGeo:遥感数据加载 / 采样 / 变换

传统 ML:RandomForest + XGBoost 做基线模型

主力模型:U-TAE(时序遥感 / 变化检测 SOTA,适配自然资源调查 + 地形检测)

输出可对接:PostGIS(空间入库)+ GeoServer(标准服务发布)

一、环境安装(一键复制)

bash 复制代码
# PyTorch 2.7.0 + CUDA 12.1 (兼容11.8/12.6/12.8)
pip3 install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu121

# 核心依赖
pip install torchgeo scikit-learn xgboost rasterio geopandas einops pytorch-lightning

二、完整训练代码(可直接运行)

python 复制代码
import os
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from pytorch_lightning import LightningModule, Trainer
from pytorch_lightning.callbacks import ModelCheckpoint

# 1) TorchGeo:遥感数据集 + 空间变换
from torchgeo.datasets import Landsat9, Sentinel2
from torchgeo.transforms import AugmentationSequential
import torchgeo.transforms as T

# 2) 传统基线模型:RF + XGBoost
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import accuracy_score, f1_score
import xgboost as xgb

# 3) U-TAE核心组件(时序/空间遥感专用)
from einops import rearrange

# ==============================================
# 【一、数据模块:基于TorchGeo加载自然资源/遥感数据】
# ==============================================
def get_torchgeo_dataloader(batch_size=8, num_workers=4):
    """加载Sentinel/Landsat数据,支持多时序、空间坐标、切片"""
    transform = AugmentationSequential(
        T.Resize((256, 256)),
        T.RandomHorizontalFlip(p=0.5),
        T.Normalize(mean=0.0, std=1.0),
        data_keys=["image", "mask"],
    )

    # 替换为你的本地自然资源调查/地形数据路径
    dataset = Sentinel2(
        root="./data/sentinel2",
        bands=["B2", "B3", "B4", "B8"],  # RGB+NIR
        masks="mask",
        transforms=transform,
    )

    loader = DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=num_workers,
        pin_memory=True
    )
    return loader, dataset


# ==============================================
# 【二、基线模型:RandomForest + XGBoost】
# ==============================================
def train_baseline_models(loader):
    """训练传统机器学习基线,用于对比精度"""
    X_list, y_list = [], []
    for batch in loader:
        img = batch["image"].numpy()
        mask = batch["mask"].numpy()

        # 展平为像素级特征
        b, c, h, w = img.shape
        img_flat = img.reshape(b, c, -1).transpose(0, 2, 1).reshape(-1, c)
        mask_flat = mask.reshape(-1)

        # 过滤无效标签
        valid = mask_flat != 0
        X_list.append(img_flat[valid])
        y_list.append(mask_flat[valid])

    X = np.concatenate(X_list, axis=0)
    y = np.concatenate(y_list, axis=0)

    # 1. Random Forest
    rf = RandomForestClassifier(n_estimators=100, n_jobs=-1)
    rf.fit(X, y)
    rf_pred = rf.predict(X)
    rf_acc = accuracy_score(y, rf_pred)
    rf_f1 = f1_score(y, rf_pred, average="macro")

    # 2. XGBoost
    xgb_model = xgb.XGBClassifier(n_estimators=100)
    xgb_model.fit(X, y)
    xgb_pred = xgb_model.predict(X)
    xgb_acc = accuracy_score(y, xgb_pred)
    xgb_f1 = f1_score(y, xgb_pred, average="macro")

    print("\n===== 基线模型结果 =====")
    print(f"RF  精度: {rf_acc:.4f}, F1: {rf_f1:.4f}")
    print(f"XGB 精度: {xgb_acc:.4f}, F1: {xgb_f1:.4f}")
    return rf, xgb_model


# ==============================================
# 【三、主力模型:U-TAE(遥感时序变化检测SOTA)】
# ==============================================
class TimeAttentionBlock(nn.Module):
    """时间注意力:适配多期自然资源调查时序数据"""
    def __init__(self, dim, heads=4):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.attn = nn.MultiheadAttention(dim, heads, batch_first=True)

    def forward(self, x):
        return x + self.attn(self.norm(x), self.norm(x), self.norm(x))[0]


class UTAE(LightningModule):
    """U-TAE:空间U-Net + 时间注意力,完美适配地形/地类变化检测"""
    def __init__(self, in_channels=4, num_classes=5, lr=1e-4):
        super().__init__()
        self.save_hyperparameters()

        # 编码器
        self.enc1 = self.conv_block(in_channels, 64)
        self.enc2 = self.conv_block(64, 128)
        self.enc3 = self.conv_block(128, 256)

        # 时间注意力(核心)
        self.time_attn = TimeAttentionBlock(256)

        # 解码器
        self.dec3 = self.conv_block(256 + 128, 128)
        self.dec2 = self.conv_block(128 + 64, 64)
        self.out = nn.Conv2d(64, num_classes, 1)

        self.loss_fn = nn.CrossEntropyLoss()

    @staticmethod
    def conv_block(in_ch, out_ch):
        return nn.Sequential(
            nn.Conv2d(in_ch, out_ch, 3, padding=1),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        # 编码
        e1 = self.enc1(x)
        e2 = self.enc2(nn.MaxPool2d(2)(e1))
        e3 = self.enc3(nn.MaxPool2d(2)(e2))

        # 时间注意力
        b, c, h, w = e3.shape
        feat = rearrange(e3, "b c h w -> b (h w) c")
        feat = self.time_attn(feat)
        feat = rearrange(feat, "b (h w) c -> b c h w", h=h, w=w)

        # 解码
        d3 = torch.cat([nn.Upsample(scale_factor=2)(feat), e2], dim=1)
        d3 = self.dec3(d3)
        d2 = torch.cat([nn.Upsample(scale_factor=2)(d3), e1], dim=1)
        d2 = self.dec2(d2)
        return self.out(d2)

    def training_step(self, batch, _):
        pred = self(batch["image"])
        loss = self.loss_fn(pred, batch["mask"].long())
        self.log("train_loss", loss)
        return loss

    def configure_optimizers(self):
        return optim.AdamW(self.parameters(), lr=self.hparams.lr)


# ==============================================
# 【四、主训练流程】
# ==============================================
if __name__ == "__main__":
    # 设备检查(PyTorch 2.7.0)
    print("PyTorch 版本:", torch.__version__)
    print("CUDA 可用:", torch.cuda.is_available())
    print("CUDA 版本:", torch.version.cuda)

    # 1. 加载数据
    train_loader, dataset = get_torchgeo_dataloader(batch_size=8)

    # 2. 训练基线模型(RF + XGBoost)
    train_baseline_models(train_loader)

    # 3. 训练主力模型:U-TAE
    model = UTAE(in_channels=4, num_classes=5)
    checkpoint = ModelCheckpoint(monitor="train_loss", save_top_k=1)

    trainer = Trainer(
        max_epochs=20,
        accelerator="gpu",
        devices=1,
        callbacks=[checkpoint],
        precision="16-mixed"
    )

    print("\n===== 开始训练 U-TAE 主力模型 =====")
    trainer.fit(model, train_loader)

    # 4. 保存模型(用于推理 + PostGIS/GeoServer发布)
    torch.save(model.state_dict(), "utae_terrain_change_model.pth")
    print("\n模型已保存:utae_terrain_change_model.pth")
    print("可对接:PostGIS空间入库 → GeoServer发布 → 二三维平台展示")

三、代码核心说明

1. 技术栈完全匹配你的要求

PyTorch 2.7.0 + CUDA 11.8/12.x

TorchGeo:负责自然资源调查 / 遥感影像加载、切片、空间变换

U-TAE:主力模型,专门做多时序影像 + 变化检测(国土业务最优)

RF/XGBoost:基线模型,用于精度对比、业务验收基准
2. 业务落地流程

训练:用 TorchGeo 读取自然资源调查数据 + 卫星影像

推理:U-TAE 输出地形 / 地类变化图

入库:结果转矢量 → PostGIS建空间索引

发布:GeoServer发布 WMS/WMTS 服务

展示:接入你的二维 / 三维实景平台
3. 直接适配你的项目

任务:自然资源调查 + 变化快速检测

输入:Sentinel/Landsat/DEM/ 自然资源图斑

输出:变化区域分类图(建设用地 / 耕地 / 地形起伏变化)

本blog地址:https://blog.csdn.net/hsg77

相关推荐
黎阳之光2 小时前
黎阳之光受邀出席上海口岸联合会2026智慧口岸研讨班 无感通关方案获盛赞
大数据·人工智能·算法·安全·数字孪生
有梦想的牛牛2 小时前
GPT-6 能力畅想:当 AI 跨越“理解”走向“共生”
人工智能·gpt
米猴设计师2 小时前
PS电商详情页高效制作:Nano Banana一键生成电商高转化套图(附实操教程)
大数据·图像处理·人工智能·ai·aigc·startai·banana修图
落羽的落羽2 小时前
【Linux系统】深入线程:多线程的互斥与同步原理,封装实现两种生产者消费者模型
java·linux·运维·服务器·c++·人工智能·python
财经资讯数据_灵砚智能2 小时前
基于全球经济类多源新闻的NLP情感分析与数据可视化(日间)2026年4月17日
人工智能·python·信息可视化·自然语言处理·ai编程
人工智能AI技术2 小时前
批量归一化基础:让模型训练更稳定
人工智能
PNP Robotics2 小时前
集智联机器人(PNP)亮相第三届中国具身智能大会,以“双臂+遥操作“多维方案定义具身交互新范式
大数据·人工智能·python·深度学习·机器人
电子科技圈4 小时前
SmartDV展示完整的边缘与连接IP解决方案,以高速和低功耗特性赋能移动、物联网和媒体处理设备创新
人工智能·嵌入式硬件·mcu·物联网·智能家居·智能硬件·iot