GeoSeg 框架解析

总的来看:

config 配置文件
geoseg 定义了核心组件:models datasets losses
tools 辅助工具箱

config :


里面包含多个遥感数据集,内部包含不同的模型

看一个例子
复制代码
导入
from torch.utils.data import DataLoader
from geoseg.losses import *
from geoseg.datasets.vaihingen_dataset import *
from geoseg.models.UNetFormer import UNetFormer
from tools.utils import Lookahead, process_model_params

训练超参数
from torch.utils.data import DataLoader
from geoseg.losses import *
from geoseg.datasets.vaihingen_dataset import *
from geoseg.models.UNetFormer import UNetFormer
from tools.utils import Lookahead, process_model_params

网络定义
net = UNetFormer(num_classes=num_classes)

损失函数
loss = UnetFormerLoss(ignore_index=ignore_index)
use_aux_loss = True

数据加载器
train_dataset = VaihingenDataset(data_root='data/vaihingen/train', ...)
val_dataset = VaihingenDataset(...)
...
train_loader = DataLoader(dataset=train_dataset, ...)
val_loader = DataLoader(dataset=val_dataset, ...)

优化器
layerwise_params = {"backbone.*": dict(lr=backbone_lr, ...)}
net_params = process_model_params(net, ...)
base_optimizer = torch.optim.AdamW(net_params, lr=lr, ...)
optimizer = Lookahead(base_optimizer)
lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(...)

geoseg

models--->所有神经网络骨架库

datasets-->数据集库

losses --->损失函数库

datasets

除了每个数据集的datasets 之外,有个transform是数据增强策略

transform: 各种各样的数据增强策略
看一个dataset例子:
复制代码
全局常量
类别
CLASSES = ('ImSurf', 'Building', 'LowVeg', 'Tree', 'Car', 'Clutter')
可视化颜色
PALETTE = [[255, 255, 255], [0, 0, 255], [0, 255, 255], [0, 255, 0], [255, 204, 0], [255, 0, 0]]

ORIGIN_IMG_SIZE = (1024, 1024)
INPUT_IMG_SIZE = (1024, 1024)
TEST_IMG_SIZE = (1024, 1024)

数据增强流水线
def get_training_transform(): ...
def train_aug(img, mask): ...
def get_val_transform(): ...
def val_aug(img, mask): ...
这些都是对数据进行数据增强

核心数据集类
class VaihingenDataset(Dataset):
    def __init__(...): ...构造函数,保存传入参数,调用get_img_ids读取images_1024目录下的所有文件名,并且保存在self.img_ids 当中
    def __getitem__(...): ...
    def __len__(...): ...
    def get_img_ids(...): ...用来确保img_dir 和mask_dir能对应的上
    def load_img_and_mask(...): ...加载数据
    def load_mosaic_img_and_mask(...): ...高级数据增强
    
可视化工具
def show_img_mask_seg(...): ...
def show_seg(...): ...
def show_mask(...): ...

losses

各种各样loss 函数

models:定义了不同的models 神经网络

看一个例子

这个例子没有encoder 模块是因为,他的encoder 模块就只是backbone 特征提取,如果特征提取后加入混合注意力之类的东西,就应该写一个encoder类

此外,models 一般都是把组件,基础架构模块定义好,核心模块定义好,在encoder 和 decoder 内部,按照论文结构排列组合
encoder 和 decoder 会在最后的模型封装中,在forward 中应用

复制代码
基础架构模块
class ConvBNReLU(nn.Sequential): ...
class ConvBN(nn.Sequential): ...
class Conv(nn.Sequential): ...
class SeparableConvBNReLU(nn.Sequential): ...
class Mlp(nn.Module): ...

核心注意力模块
class GlobalLocalAttention(nn.Module): ...

Transformer 模块
class GlobalLocalAttention(nn.Module): ...

解码辅助模块
class WF(nn.Module): ...
class FeatureRefinementHead(nn.Module): ...
class AuxHead(nn.Module): ...

解码器
class Decoder(nn.Module): ...

最终模型
class UNetFormer(nn.Module): ...
构造函数和前向传播

tools

数据预处理脚本

模型配置脚本

模型评估脚本

相关推荐
BruceWooCoder1 天前
从零打造云端AI视频生成服务:基于CogVideoX和MCP协议的完整实践
人工智能·音视频
大千AI助手1 天前
汉明距离:度量差异的基石与AI应用
人工智能·机器学习·距离度量·汉明距离·大千ai助手·hammingdistance·纠错码
我很哇塞耶1 天前
AWS AgentCore重磅升级,三大新功能重塑AI代理开发体验
人工智能·ai·大模型
说私域1 天前
社群媒体时代下“开源AI智能名片链动2+1模式S2B2C商城小程序”对社群运营的重要性研究
人工智能·开源·媒体
Akamai中国1 天前
加速采用安全的企业级 Kubernetes 环境
人工智能·云计算·云服务·云存储
AI科技星1 天前
时空的几何之歌:论统一场论动量公式 P = m(C - V) 的完备重构、量化哲学诠释与终极验证
数据结构·人工智能·算法·机器学习·计算机视觉·重构
子午1 天前
【农作物谷物识别系统】Python+TensorFlow+Django+人工智能+深度学习+卷积神经网络算法
人工智能·python·深度学习
断春风1 天前
Java 集成 AI 大模型最佳实践:从零到一打造智能化后端
java·人工智能·ai
大千AI助手1 天前
基于实例的学习:最近邻算法及其现代演进
人工智能·算法·机器学习·近邻算法·knn·大千ai助手·基于实例的学习
淘源码d1 天前
智慧工地企项一体化平台,Spring Cloud +UniApp 智慧工地源码,BIM+AI+物联网,施工全过程数字化智慧工地管理平台
java·人工智能·物联网·智慧工地·智慧工地源码·智慧工地app·数字工地