【3D AIGC Series-6】A Walkthrough of the OmniPart Training Pipeline

OmniPart Training Code: A Technical Deep Dive

This post walks through the training code of OmniPart in detail, covering the data-processing pipeline, the model architecture, the training loop, and the key design decisions.

Contents

  1. Training Pipeline Overview
  2. Data Processing Pipeline
  3. Model Architecture
  4. Training Loop
  5. Key Code Walkthrough

Training Pipeline Overview

Overall Architecture

OmniPart is a part-aware 3D generative model trained with Flow Matching. Its training pipeline is as follows:
  • Data loading: Part Mesh → Voxelize → SLat Encoder → SLat Features; Render Images → DINO Encoder → Image Features; Mask Files → Mask Processing → Ordered Mask; everything feeds Dataset → Collate Batch
  • Model forward: SparseTensor → Mask Embedding → Part-wise Batch → Transformer Blocks → Velocity Prediction
  • Training step: Flow Matching Loss → Backward Pass → Update Model

Training Stages

OmniPart is trained in two main stages:

  1. Stage 1: Sparse Structure Flow

    • Learns the mapping from an image to a sparse 3D structure
    • Output: a sparse 3D structure representation
  2. Stage 2: Structured Latent Flow (part-aware)

    • Learns part-level SLat generation
    • Input: image + mask
    • Output: part-level SLat features

This post focuses on the Stage 2 training pipeline.


Data Processing Pipeline

The Dataset Class: StructuredLatentPartDataset

Location: training/datasets/structured_latent_part.py

Core data structure
python
# File layout: {uuid[:2]}/{uuid}/all_latent.npz
{
    'coords': [N, 3],      # voxel coordinates
    'feats': [N, 8],       # SLat features (8-dim)
    'offsets': [num_parts+1]  # start offset of each part
}
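
Before diving into get_instance, here is a minimal sketch (with a hypothetical uuid path) of how the offsets array slices the file into the overall shape and its parts:

python
import numpy as np

# Hypothetical uuid; real files live at {uuid[:2]}/{uuid}/all_latent.npz
data = np.load('ab/ab12cd34/all_latent.npz')
coords, feats, offsets = data['coords'], data['feats'], data['offsets']

# offsets is a prefix array: slice i covers voxels offsets[i]:offsets[i+1].
# Slice 0 is the overall shape; slices 1..num_parts are the individual parts.
for i in range(len(offsets) - 1):
    part_coords = coords[offsets[i]:offsets[i+1]]  # [n_i, 3]
    part_feats = feats[offsets[i]:offsets[i+1]]    # [n_i, 8]
    print(i, part_coords.shape, part_feats.shape)
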
The get_instance method in detail

This is the core data-loading method; it assembles the part-level data:

python
def get_instance(self, uuid):
    """
    Process all parts plus the overall model and merge them into one output.

    Key steps:
    1. Load all_latent.npz (the overall shape + all parts)
    2. For each part, compute its bbox and attach a noise mask
    3. Tag part and overall voxels with a noise_mask_score
    4. Return the merged coords, feats, and part_layouts
    """
    # 1. Load the data
    data = np.load(os.path.join(self.data_root, uuid[:2], uuid, 'all_latent.npz'))
    all_coords = data['coords'].astype(np.int32)
    all_feats = data['feats'].astype(np.float32)
    offsets = data['offsets'].astype(np.int32)
    
    all_coords_wnoise = []
    all_feats_wnoise = []
    part_layouts = []
    start_idx = 0
    
    # 2. Process each slice (i == 0 is the overall shape, i > 0 are parts)
    for i in range(len(offsets) - 1):
        if i == 0:
            # Overall shape: record it directly; it is tagged positive below
            overall_coords = all_coords[offsets[i]:offsets[i+1]]
            overall_feats = all_feats[offsets[i]:offsets[i+1]]
            overall_ids = self.coords_to_ids(overall_coords)
            part_layouts.append(slice(start_idx, start_idx + overall_coords.shape[0]))
            start_idx += overall_coords.shape[0]
        else:
            # Part: compute its bbox and attach a noise mask
            part_coords = all_coords[offsets[i]:offsets[i+1]]
            part_feats = all_feats[offsets[i]:offsets[i+1]]
            
            # Compute the part's bbox
            part_bbox_min = part_coords.min(axis=0)
            part_bbox_max = part_coords.max(axis=0)
            
            # Bbox augmentation (optional)
            if self.aug_bbox_range is not None:
                # Randomly expand the bbox to make training more robust
                aug_bbox_min = np.random.randint(...)
                aug_bbox_max = np.random.randint(...)
                part_bbox_min = np.clip(part_bbox_min - aug_bbox_min, 0, 63)
                part_bbox_max = np.clip(part_bbox_max + aug_bbox_max, 0, 63)
            
            # Key step: the noise-mask mechanism.
            # Find voxels inside the part's bbox that do not belong to the
            # part itself (these are the "noise" voxels).
            bbox_mask = np.all((overall_coords >= part_bbox_min) & 
                              (overall_coords <= part_bbox_max), axis=1)
            part_ids = self.coords_to_ids(part_coords)
            part_in_overall = np.isin(overall_ids, part_ids)
            noise_mask = np.logical_and(bbox_mask, np.logical_not(part_in_overall))
            
            # Extract the noise voxels
            noise_coords = overall_coords[noise_mask]
            noise_feats = overall_feats[noise_mask]
            
            # Append the noise_mask_score channel:
            # - part voxels:  +noise_mask_score (positive)
            # - noise voxels: -noise_mask_score (negative)
            noise_feats = np.concatenate((noise_feats, 
                                        np.full((noise_feats.shape[0], 1), 
                                               -self.noise_mask_score)), axis=1)
            part_feats = np.concatenate((part_feats, 
                                        np.full((part_feats.shape[0], 1), 
                                               self.noise_mask_score)), axis=1)
            
            # Merge the part and noise voxels
            part_coords = np.concatenate((part_coords, noise_coords), axis=0)
            part_feats = np.concatenate((part_feats, noise_feats), axis=0)
            
            part_layouts.append(slice(start_idx, start_idx + part_coords.shape[0]))
            start_idx += part_coords.shape[0]
            
            all_coords_wnoise.append(torch.from_numpy(part_coords))
            all_feats_wnoise.append(torch.from_numpy(part_feats))
    
    # 3. Prepend the overall shape (tagged positive)
    overall_feats = np.concatenate((overall_feats, 
                                   np.full((overall_feats.shape[0], 1), 
                                          self.noise_mask_score)), axis=1)
    all_coords_wnoise.insert(0, torch.from_numpy(overall_coords))
    all_feats_wnoise.insert(0, torch.from_numpy(overall_feats))
    
    # 4. Merge everything
    combined_coords = torch.cat(all_coords_wnoise, dim=0).int()
    combined_feats = torch.cat(all_feats_wnoise, dim=0).float()
    
    # 5. Normalize (if configured)
    if self.normalization is not None:
        combined_feats = (combined_feats - self.mean) / self.std
    
    return {
        'coords': combined_coords,
        'feats': combined_feats,
        'part_layouts': part_layouts,  # one slice per part, used for batching
    }
The Role of the Noise Mask

The noise mask is a key design in OmniPart: it acts as an implicit coverage constraint:

  1. Prevents missing parts: including the overall shape forces the model to learn complete geometric coverage
  2. Separates parts from background: via the noise_mask_score tag, the model learns to tell real part voxels from background noise
  3. Bbox restriction: only voxels inside the part's bbox are considered, which keeps global noise out of the picture
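
The coords_to_ids helper used in get_instance is not shown in the excerpt; the sketch below assumes a simple flattening over the 64³ grid and replays the noise-mask filter on toy data:

python
import numpy as np

def coords_to_ids(coords, res=64):
    # Flatten (x, y, z) voxel coordinates into unique scalar IDs (assumed scheme)
    return coords[:, 0] * res * res + coords[:, 1] * res + coords[:, 2]

# Toy data: the overall shape has 4 voxels, the part owns 2 of them
overall = np.array([[1, 1, 1], [1, 2, 2], [1, 1, 2], [5, 5, 5]])
part = np.array([[1, 1, 1], [1, 2, 2]])

bbox_mask = np.all((overall >= part.min(0)) & (overall <= part.max(0)), axis=1)
in_part = np.isin(coords_to_ids(overall), coords_to_ids(part))
noise_mask = bbox_mask & ~in_part
print(overall[noise_mask])  # [[1 1 2]] -- inside the bbox but not part of the part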

Image and Mask Conditioning

Location: training/datasets/components.py

ImageConditionedMixin.get_instance
python
def get_instance(self, uuid, select_method='random_seg'):
    """
    Load an image and its mask as the condition.

    Flow:
    1. Randomly pick a view (0-25)
    2. Load the image: color_{index:04d}.webp
    3. Load the mask:  mask_{index:04d}.exr
    4. Process the mask: order the parts bottom-up, downsample to the DINO patch grid
    """
    pack = super().get_instance(uuid)
    
    # Randomly pick a view
    data_dir = os.path.join(self.img_data_root, uuid[:2], uuid)
    random_index = random.choice(range(26))  # views 0-25
    
    image_path = os.path.join(data_dir, f"color_{random_index:04d}.webp")
    mask_path = os.path.join(data_dir, f"mask_{random_index:04d}.exr")
    
    # Load and process the mask
    mask, mask_vis, ordered_mask_dino = self.load_bottom_up_mask(mask_path, size=(518, 518))
    pack['ordered_mask_dino'] = ordered_mask_dino  # [37, 37], for DINO
    
    # Load the image
    image, image_rgb = self.load_image(image_path)
    pack['cond'] = image  # [3, 518, 518]
    
    return pack
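
The load_image helper is not shown above; here is a plausible minimal sketch (the exact preprocessing, e.g. background/alpha handling, is an assumption):

python
import numpy as np
import torch
from PIL import Image

def load_image(image_path, size=(518, 518)):
    # Read the rendered view, resize to the DINOv2 input resolution,
    # and convert to a [3, H, W] float tensor in [0, 1]
    image_rgb = Image.open(image_path).convert('RGB').resize(size, Image.BILINEAR)
    image = torch.from_numpy(np.array(image_rgb)).permute(2, 0, 1).float() / 255.0
    return image, image_rgb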
The load_bottom_up_mask method
python
def load_bottom_up_mask(self, mask_path, size=(518, 518)):
    """
    Load the mask and order its parts bottom-up.

    Key steps:
    1. Load the raw mask (EXR format)
    2. Downsample to the DINO patch grid (37x37)
    3. Sort the parts by their y coordinate (bottom to top)
    4. Remap the part IDs
    """
    # 1. Load the mask
    mask = cv2.imread(mask_path, cv2.IMREAD_UNCHANGED)
    mask = cv2.resize(mask, size, interpolation=cv2.INTER_NEAREST)
    
    # 2. Downsample to the DINO patch grid (37x37)
    mask_dino = cv2.resize(mask, (37, 37), interpolation=cv2.INTER_NEAREST)
    mask_dino = np.array(mask_dino, dtype=np.int32)[..., 0]
    mask_org = np.array(mask, dtype=np.int32)[..., 0]
    
    # 3. Collect all part IDs
    unique_indices = np.unique(mask_org)
    unique_indices = unique_indices[unique_indices > 0]  # drop the background (0)
    
    # 4. Compute each part's lowest pixel (largest image-space y)
    part_positions = {}
    for idx in unique_indices:
        y_coords, _ = np.where(mask_org == idx)
        if len(y_coords) > 0:
            part_positions[idx] = np.max(y_coords)  # lowest pixel of the part
    
    # 5. Sort by y, descending (i.e. bottom to top)
    sorted_parts = sorted(part_positions.items(), key=lambda x: -x[1])
    
    # 6. Remap the part IDs to 1, 2, 3, ...
    index_map = {}
    for new_idx, (old_idx, _) in enumerate(sorted_parts, 1):
        index_map[old_idx] = new_idx
    
    # 7. Build the ordered masks
    ordered_mask = np.zeros_like(mask_org)
    ordered_mask_dino = np.zeros_like(mask_dino)
    for old_idx, new_idx in index_map.items():
        ordered_mask[mask_org == old_idx] = new_idx
        ordered_mask_dino[mask_dino == old_idx] = new_idx
    
    ordered_mask_dino = torch.from_numpy(ordered_mask_dino).long()
    
    mask_vis = ordered_mask.copy()  # stand-in: the full code builds a debug visualization here
    return mask, mask_vis, ordered_mask_dino  # [37, 37] for DINO

Why order the parts?

  • It keeps the part ordering consistent (bottom to top)
  • It makes it easier for the model to learn spatial relationships between parts
  • It matches the part order in part_layouts (a toy example follows below)
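
A toy example of the bottom-up remapping on a 4×4 mask, where original IDs 7 (bottom) and 3 (top) become 1 and 2:

python
import numpy as np

mask = np.array([[0, 3, 3, 0],
                 [0, 3, 0, 0],
                 [7, 7, 0, 0],
                 [7, 7, 0, 0]])

ids = np.unique(mask)[1:]  # drop the background (0) -> [3, 7]
lowest_y = {i: np.where(mask == i)[0].max() for i in ids}  # {3: 1, 7: 3}
order = sorted(lowest_y, key=lambda i: -lowest_y[i])       # [7, 3], bottom first
remap = {old: new for new, old in enumerate(order, 1)}     # {7: 1, 3: 2}

ordered = np.zeros_like(mask)
for old, new in remap.items():
    ordered[mask == old] = new
print(ordered)  # part 7 (bottom) is now 1, part 3 (top) is now 2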

Model Architecture

The StructuredLatentFlow Model

Location: training/models/structured_latent_flow.py

Overall architecture
Input SparseTensor → Input Layer (with Mask Embedding and Part-wise Batch IDs) → Input Blocks (downsampling) → Transformer Blocks (cross-attention over Image Features + Mask Features) → Output Blocks (upsampling) → Output Layer → Velocity Prediction
The forward method in detail
python
def forward(self, x: sp.SparseTensor, t: torch.Tensor, cond: torch.Tensor, **kwargs):
    """
    Forward pass.

    Args:
        x: SparseTensor [N, 9] (8-dim SLat + 1-dim noise_mask_score)
        t: timestep [B]
        cond: DINO features [B, 1374, 1024]
        kwargs:
            - ordered_mask_dino: [B, 37, 37]
            - part_layouts: List[List[slice]]

    Returns:
        Output SparseTensor [N, 8] (velocity prediction)
    """
    input_dtype = x.dtype
    
    # ========== 1. Mask Embedding ==========
    masks = kwargs['ordered_mask_dino']  # [B, 37, 37]
    masks = masks.long()
    masks = rearrange(masks, 'b h w -> b (h w)')  # [B, 1369]
    
    # Embedding: part ID -> 128-dim -> 1024-dim
    masks_emb = self.group_embedding(masks)  # [B, 1369, 128]
    masks_emb = self.group_emb_proj(masks_emb)  # [B, 1369, 1024]
    
    # Pad to cond's length (1374 = 1369 patches + 5 special tokens)
    group_emb = torch.zeros((cond.shape[0], cond.shape[1], masks_emb.shape[2]), 
                           device=cond.device, dtype=cond.dtype)
    group_emb[:, :masks_emb.shape[1], :] = masks_emb
    
    # Fuse the mask embedding into the DINO features
    cond = cond + group_emb  # [B, 1374, 1024]
    cond = cond.type(self.dtype)
    
    # ========== 2. Part-wise Batch Processing ==========
    original_batch_ids = x.coords[:, 0].clone()  # original batch IDs
    
    # Build new batch IDs: each part gets its own batch ID
    new_batch_ids = torch.zeros_like(original_batch_ids)
    part_layouts = kwargs['part_layouts']
    
    part_id = 0
    len_before = 0
    batch_last_partid = []
    
    for batch_idx, part_layout in enumerate(part_layouts):
        # part_layout is a List[slice]; each slice covers one part
        for layout_idx, layout in enumerate(part_layout):
            adjusted_layout = slice(layout.start + len_before, 
                                  layout.stop + len_before, 
                                  layout.step)
            new_batch_ids[adjusted_layout] = part_id
            part_id += 1
        
        batch_last_partid.append(part_id)
        len_before += part_layout[-1].stop
    
    # Update the coordinates to use part-wise batch IDs
    x = sp.SparseTensor(
        feats=x.feats,
        coords=torch.cat([new_batch_ids.view(-1, 1), x.coords[:, 1:]], dim=1)
    )
    
    # ========== 3. Input Processing ==========
    x = self.input_layer(x).type(self.dtype)  # project to the model dim
    
    # Timestep embedding
    t_emb = self.t_embedder(t)  # [B, hidden_dim]
    if self.share_mod:
        t_emb = self.adaLN_modulation(t_emb)
    t_emb = t_emb.type(self.dtype)
    
    # Replicate the timestep embedding for every part
    t_emb_updown = []
    for batch_idx, part_layout in enumerate(part_layouts):
        t_emb_updown_batch = t_emb[batch_idx:batch_idx+1].repeat(len(part_layout), 1)
        t_emb_updown.append(t_emb_updown_batch)
    t_emb_updown = torch.cat(t_emb_updown, dim=0).type(self.dtype)
    
    # ========== 4. Downsampling (Input Blocks) ==========
    skips = []
    for block in self.input_blocks:
        x = block(x, t_emb_updown)
        skips.append(x.feats)  # keep the skip connections
    
    # ========== 5. Transformer Blocks ==========
    part_wise_batch_ids = x.coords[:, 0].clone()
    
    # Convert back to batch-wise IDs (for the transformer)
    new_transformer_batch_ids = torch.zeros_like(part_wise_batch_ids)
    part_ids_in_each_object = torch.zeros_like(part_wise_batch_ids)
    
    start_reform = 0
    last_part_id = 0
    for part_id in batch_last_partid:
        mask = (part_wise_batch_ids >= last_part_id) & (part_wise_batch_ids < part_id)
        new_transformer_batch_ids[mask] = start_reform
        part_ids_in_each_object[mask] = part_wise_batch_ids[mask] - last_part_id
        last_part_id = part_id
        start_reform += 1
    
    # Update the coordinates to use batch-wise IDs
    h = sp.SparseTensor(
        feats=x.feats,
        coords=torch.cat([new_transformer_batch_ids.view(-1, 1), x.coords[:, 1:]], dim=1)
    )
    
    # Add positional encodings
    if self.pe_mode == "ape":
        # Absolute positional encoding
        h = h + self.pos_embedder(h.coords[:, 1:]).type(self.dtype)
        
        # Part-wise positional encoding (overall is 0; parts are 1, 2, 3, ...)
        part_pe = self.layer_pe(part_ids_in_each_object)
        part_pe = self.layer_pe_proj(part_pe)
        h = h + part_pe.type(self.dtype)
    
    # Transformer blocks with cross-attention
    for block in self.blocks:
        h = block(h, t_emb, cond)  # cond feeds the cross-attention
    
    # ========== 6. Upsampling (Output Blocks) ==========
    # Restore the part-wise batch IDs
    h = x.replace(feats=h.feats, 
                 coords=torch.cat([part_wise_batch_ids.view(-1, 1), h.coords[:, 1:]], dim=1))
    
    # Upsampling with skip connections
    for block, skip in zip(self.out_blocks, reversed(skips)):
        if self.use_skip_connection:
            h = block(h.replace(torch.cat([h.feats, skip], dim=1)), t_emb_updown)
        else:
            h = block(h, t_emb_updown)
    
    # ========== 7. Output ==========
    h = h.replace(F.layer_norm(h.feats, h.feats.shape[-1:]))
    h = self.out_layer(h.type(input_dtype))
    
    # Restore the original batch IDs
    h = sp.SparseTensor(
        feats=h.feats,
        coords=torch.cat([original_batch_ids.view(-1, 1), h.coords[:, 1:]], dim=1)
    )
    
    return h
Key design decisions
  1. Mask embedding

    • Converts the mask's part IDs into embeddings
    • Added onto the DINO features to inject part information
  2. Part-wise batch processing

    • Each part is processed independently, with its own batch ID (see the toy example after this list)
    • Converted back to batch-wise IDs for the transformer stage
    • Makes a variable number of parts straightforward to handle
  3. Part positional encoding

    • Overall shape: 0
    • Parts: 1, 2, 3, ... (ordered bottom-up)
    • Helps the model capture the hierarchy among parts
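
A toy illustration of the batch-ID bookkeeping from design point 2: two objects with 2 and 3 parts become five part-wise "batches" along the concatenated voxel axis (the slices, as produced by get_instance, index each sample's own voxel axis, so forward shifts them by len_before):

python
import torch

part_layouts = [[slice(0, 4), slice(4, 6)],               # object 0: overall + 1 part
                [slice(0, 3), slice(3, 5), slice(5, 9)]]  # object 1: overall + 2 parts

new_batch_ids = torch.zeros(15, dtype=torch.long)
part_id, len_before = 0, 0
for part_layout in part_layouts:
    for layout in part_layout:
        new_batch_ids[layout.start + len_before:layout.stop + len_before] = part_id
        part_id += 1
    len_before += part_layout[-1].stop
print(new_batch_ids)  # tensor([0, 0, 0, 0, 1, 1, 2, 2, 2, 3, 3, 4, 4, 4, 4])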

Training Loop

Trainer Class Structure

Location: training/trainers/flow_matching/sparse_flow_matching.py

SparseFlowMatchingTrainer

Inheritance:

Trainer (base.py)
  └── FlowMatchingTrainer (flow_matching.py)
      └── SparseFlowMatchingTrainer (sparse_flow_matching.py)
          └── ImageConditionedSparseFlowMatchingTrainer (with ImageConditionedMixin)
The training_losses method
python
def training_losses(
    self,
    x_0: sp.SparseTensor,
    cond=None,
    **kwargs
) -> Tuple[Dict, Dict]:
    """
    Compute the Flow Matching training loss.

    Flow:
    1. Sample random noise (with the same sparsity pattern as x_0)
    2. Sample a random timestep t
    3. Diffuse: x_t = (1 - t) * x_0 + t * noise
    4. Predict the velocity field
    5. Compute the MSE loss
    """
    # 1. Sample noise (keeping the same sparsity pattern)
    noise = x_0.replace(torch.randn_like(x_0.feats))
    
    # 2. Sample a random timestep t ~ U(0, 1)
    t = self.sample_t(x_0.shape[0]).to(x_0.device).float()
    
    # 3. Apply the diffusion process
    # x_t = (1 - t) * x_0 + t * noise
    x_t = self.diffuse(x_0, t, noise=noise)
    
    # 4. Process the conditioning input
    cond, ordered_mask_dino = self.get_cond(cond, **kwargs)
    kwargs['ordered_mask_dino'] = ordered_mask_dino
    
    # 5. Let the model predict the velocity field
    # Note: the timestep is scaled by 1000 (the model expects [0, 1000])
    pred = self.training_models['denoiser'](x_t, t * 1000, cond, **kwargs)
    
    assert pred.shape == noise.shape == x_0.shape
    
    # 6. Compute the target velocity field
    # v = noise - x_0 (the Flow Matching target)
    target = self.get_v(x_0, noise, t)
    
    # 7. MSE loss
    terms = edict()
    terms["mse"] = F.mse_loss(pred.feats, target.feats)
    terms["loss"] = terms["mse"]
    
    # 8. Log the loss per timestep bin (for monitoring)
    mse_per_instance = np.array([
        F.mse_loss(pred.feats[x_0.layout[i]], target.feats[x_0.layout[i]]).item()
        for i in range(x_0.shape[0])
    ])
    time_bin = np.digitize(t.cpu().numpy(), np.linspace(0, 1, 11)) - 1
    for i in range(10):
        if (time_bin == i).sum() != 0:
            terms[f"bin_{i}"] = {"mse": mse_per_instance[time_bin == i].mean()}
    
    return terms, {}
The Flow Matching Objective

Flow Matching is a training method for generative models whose core ideas are:

  1. Define a probability path: a continuous path from the noise distribution to the data distribution
  2. Learn a velocity field: the model predicts the velocity field v(x_t, t) instead of predicting x_0 directly
  3. Training objective: an MSE loss between the predicted and true velocity fields

In formulas:

  • Path: x_t = (1 - t) * x_0 + t * noise, t ∈ [0, 1]
  • Velocity field: v(x_t, t) = noise - x_0
  • Loss: L = ||v_pred(x_t, t) - v(x_t, t)||²
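
To make the objective concrete, here is a self-contained dense-tensor sketch of one flow-matching training step (the MLP, the shapes, and the t-conditioning via concatenation are placeholders, not OmniPart's architecture):

python
import torch
import torch.nn.functional as F

model = torch.nn.Sequential(torch.nn.Linear(9, 64), torch.nn.ReLU(),
                            torch.nn.Linear(64, 8))
opt = torch.optim.AdamW(model.parameters(), lr=1e-4)

x_0 = torch.randn(32, 8)                    # data batch (stand-in for SLat feats)
noise = torch.randn_like(x_0)
t = torch.rand(32, 1)                       # t ~ U(0, 1)

x_t = (1 - t) * x_0 + t * noise             # linear probability path
v_target = noise - x_0                      # target velocity field
v_pred = model(torch.cat([x_t, t], dim=1))  # toy t-conditioning via concat

loss = F.mse_loss(v_pred, v_target)
loss.backward()
opt.step()
opt.zero_grad()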
The get_cond method (ImageConditionedMixin)

Location: training/trainers/flow_matching/mixins/image_conditioned.py

python
def get_cond(self, cond, **kwargs):
    """
    Process the image conditioning input.

    Flow:
    1. Encode the image with DINOv2 -> [B, 1374, 1024]
    2. Build the negative condition (for classifier-free guidance)
    3. Call the parent get_cond (which handles CFG)
    """
    # 1. Encode the image
    cond = self.encode_image(cond)  # [B, 1374, 1024]
    
    # 2. Build the negative condition (all zeros)
    kwargs['neg_cond'] = torch.zeros_like(cond)
    kwargs['neg_mask'] = torch.zeros_like(kwargs['ordered_mask_dino'])
    
    # 3. Let the parent class handle CFG
    cond = super().get_cond(cond, **kwargs)
    
    return cond
The encode_image method
python
@torch.no_grad()
def encode_image(self, image: Union[torch.Tensor, List[Image.Image]]) -> torch.Tensor:
    """
    Encode the image with DINOv2.

    Input:  images [B, 3, 518, 518]
    Output: patch tokens [B, 1374, 1024]

    1374 = 37 * 37 (patches) + 5 (special tokens)
    """
    if self.image_cond_model is None:
        self._init_image_cond_model()
    
    # Normalize
    image = image.cuda()
    image = self.image_cond_model['transform'](image)
    
    # DINOv2 forward
    features = self.image_cond_model['model'](image, is_training=True)['x_prenorm']
    patchtokens = F.layer_norm(features, features.shape[-1:])
    
    return patchtokens  # [B, 1374, 1024]
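
The 1374-token layout matches DINOv2 ViT-L/14 with registers at 518×518 input: 37×37 = 1369 patches plus 1 CLS and 4 register tokens. A sketch of loading that backbone via the public torch.hub entry point (assuming this is the variant OmniPart uses):

python
import torch

# 'reg' variants of DINOv2 carry 4 register tokens in addition to the CLS token
model = torch.hub.load('facebookresearch/dinov2', 'dinov2_vitl14_reg').cuda().eval()

image = torch.randn(1, 3, 518, 518).cuda()  # 518 / 14 = 37 patches per side
out = model(image, is_training=True)['x_prenorm']
print(out.shape)  # torch.Size([1, 1374, 1024])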

Key Code Walkthrough

Data Flow at a Glance

  1. Trainer → Dataset: get_instance(uuid) loads the SLat data (all_latent.npz), processes the parts plus the noise mask, loads the image and mask, orders the mask's parts, and returns the batch data.
  2. Trainer: training_losses(x_0, cond) samples noise and a timestep t, diffuses x_t = (1-t)*x_0 + t*noise, and encodes the condition (DINO + mask embedding).
  3. Model: forward(x_t, t, cond) applies the mask embedding, assigns part-wise batch IDs, runs the input blocks (downsample), transformer blocks (cross-attention), and output blocks (upsample), then returns the velocity prediction.
  4. Trainer → Loss: compute the target velocity and the MSE loss, then backward + update.

Annotated Key Functions

1. collate_fn (Dataset)
python
def collate_fn(self, items, split_size=1):
    """
    Collate multiple samples into one batch.

    Key steps:
    1. Concatenate all samples' coords and feats
    2. Update the batch IDs (coords[:, 0])
    3. Keep part_layouts around (needed in the model's forward pass)
    """
    # Merge all samples
    all_coords = []
    all_feats = []
    all_part_layouts = []
    
    batch_id = 0
    for item in items:
        coords = item['coords'].clone()
        feats = item['feats'].clone()
        
        # Update the batch ID
        coords[:, 0] = batch_id
        
        all_coords.append(coords)
        all_feats.append(feats)
        
        # Keep part_layouts (indices shifted by the voxels seen so far)
        adjusted_layouts = []
        offset = sum(c.shape[0] for c in all_coords[:-1])  # voxel count of previous samples
        for layout in item['part_layouts']:
            adjusted_layouts.append(slice(layout.start + offset, 
                                        layout.stop + offset))
        all_part_layouts.append(adjusted_layouts)
        
        batch_id += 1
    
    # Build the SparseTensor
    coords = torch.cat(all_coords, dim=0)
    feats = torch.cat(all_feats, dim=0)
    x_0 = sp.SparseTensor(feats=feats, coords=coords)
    
    return {
        'x_0': x_0,
        'cond': torch.stack([item['cond'] for item in items]),
        'ordered_mask_dino': torch.stack([item['ordered_mask_dino'] for item in items]),
        'part_layouts': all_part_layouts,
    }
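
Wiring this collate_fn into a standard DataLoader might look as follows (the constructor arguments are hypothetical, mirroring the config example later in this post):

python
from torch.utils.data import DataLoader

dataset = StructuredLatentPartDataset(
    data_root='/path/to/data',
    img_data_root='/path/to/images',
    noise_mask_score=1.0,
)  # hypothetical constructor signature

loader = DataLoader(dataset, batch_size=8, shuffle=True,
                    num_workers=4, collate_fn=dataset.collate_fn)
batch = next(iter(loader))
# batch['x_0'] is a SparseTensor; batch['part_layouts'] keeps per-sample slices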
2. diffuse (FlowMatchingTrainer)
python
def diffuse(self, x_0, t, noise=None):
    """
    Apply the diffusion process:

        x_t = (1 - t) * x_0 + t * noise

    This is the heart of Flow Matching: it defines the path from data to noise.
    """
    if noise is None:
        noise = x_0.replace(torch.randn_like(x_0.feats))
    
    # t must broadcast over each sample
    t = t.view(-1, *([1] * (x_0.feats.ndim - 1)))
    
    # Linear interpolation
    x_t = x_0.replace((1 - t) * x_0.feats + t * noise.feats)
    
    return x_t
3. get_v (FlowMatchingTrainer)
python
def get_v(self, x_0, noise, t):
    """
    Compute the target velocity field:

        v = noise - x_0

    This is the Flow Matching target that the model has to predict.
    """
    return x_0.replace(noise.feats - x_0.feats)
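
At inference time the learned velocity is integrated from noise (t = 1) back to data (t = 0). A minimal Euler sketch, assuming a model(x, t, cond) signature like the trainer's and elementwise arithmetic on the input (as sp.SparseTensor provides):

python
import torch

@torch.no_grad()
def euler_sample(model, x_t, cond, steps=50):
    # Integrate dx/dt = v(x_t, t) from t = 1 (noise) down to t = 0 (data)
    dt = 1.0 / steps
    for i in range(steps, 0, -1):
        t = torch.full((x_t.shape[0],), i / steps, device=x_t.device)
        v = model(x_t, t * 1000, cond)  # the model expects t scaled to [0, 1000]
        x_t = x_t - v * dt              # Euler step toward t = 0
    return x_t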

Example Training Config

json
{
  "dataset": {
    "name": "StructuredLatentPartDataset",
    "args": {
      "data_root": "/path/to/data",
      "img_data_root": "/path/to/images",
      "noise_mask_score": 1.0,
      "aug_bbox_range": [0, 2],
      "normalization": "standard"
    }
  },
  "models": {
    "denoiser": {
      "name": "StructuredLatentFlow",
      "args": {
        "input_dim": 9,  // 8 (SLat) + 1 (noise_mask_score)
        "hidden_dim": 1024,
        "depth": 24,
        "num_heads": 16,
        "cond_dim": 1024
      }
    }
  },
  "trainer": {
    "name": "ImageConditionedSparseFlowMatchingTrainer",
    "args": {
      "t_schedule": "uniform",
      "sigma_min": 0.0
    }
  }
}

Summary

Core Design Ideas

  1. Part-aware generation

    • Data is organized through part_layouts
    • Part-wise batch processing handles a variable number of parts
  2. Coverage constraint

    • The noise-mask mechanism enforces complete coverage
    • The overall shape and the parts are trained jointly
  3. Conditional generation

    • Image condition (DINO features)
    • Mask condition (part-ID embeddings)
  4. Flow Matching

    • Learns a velocity field instead of predicting the target directly
    • Yields a more stable training process

Key File Index

  • Dataset: training/datasets/structured_latent_part.py
  • Components: training/datasets/components.py
  • Model: training/models/structured_latent_flow.py
  • Trainer: training/trainers/flow_matching/sparse_flow_matching.py
  • Image condition: training/trainers/flow_matching/mixins/image_conditioned.py

Training Command

bash
python train.py \
    --config configs/training_part_synthesis.json \
    --output_dir ./outputs/part_synthesis \
    --data_dir /path/to/data

This post is based on the OmniPart codebase and walks through the key designs and technical details of its training pipeline.
