系列文章目录
- 【3D AICG 系列-1】Trellis v1 和 Trellis v2 的区别和改进
- 【3D AICG 系列-2】Trellis 2 的 O-voxel (上) Shape: Flexible Dual Grid
- 【3D AICG 系列-3】Trellis 2 的 O-voxel (下) Material: Volumetric Surface Attributes
- 【3D AICG 系列-4】Trellis 2 的 Shape SLAT Flow Matching DiT 训练流程
- 【3D AICG 系列-5】Trellis 2 的 Pipeline 推理流程的各个中间结果和形状
文章目录
- 系列文章目录
- OmniPart 训练代码技术详解
- 训练流程概述
- 数据处理流程
  - Dataset 类:`StructuredLatentPartDataset`
    - 核心数据结构
    - `get_instance` 方法详解
    - Noise Mask 机制的作用
  - 图像和 Mask 条件处理
    - `ImageConditionedMixin.get_instance`
    - `load_bottom_up_mask` 方法
- 模型架构
- 训练循环
  - Trainer 类结构
    - `SparseFlowMatchingTrainer`
    - `training_losses` 方法
    - Flow Matching 目标
    - `get_cond` 方法(ImageConditionedMixin)
    - `encode_image` 方法
- 关键代码解析
- 总结
OmniPart 训练代码技术详解
本文档详细解析 OmniPart 的训练代码逻辑,包括数据处理流程、模型架构、训练循环和关键设计点。
训练流程概述
整体架构
OmniPart 是一个基于 Flow Matching 的 part-aware 3D 生成模型,其训练流程如下:
整体流程可以概括为三个阶段:

- 数据加载阶段:
  - Part Mesh → Voxelize → SLat Encoder → SLat Features
  - Render Images → DINO Encoder → Image Features
  - Mask Files → Mask Processing → Ordered Mask
  - 三路结果汇入 Dataset,经 Collate Batch 组成训练批次
- 模型前向传播:Sparse Tensor → Mask Embedding → Part-wise Batch → Transformer Blocks → Velocity Prediction
- 训练阶段:Flow Matching Loss → Backward Pass → Update Model
训练阶段说明
OmniPart 的训练分为两个主要阶段:
- Stage 1: Sparse Structure Flow
  - 学习从图像到稀疏结构的映射
  - 输出:稀疏的 3D 结构表示
- Stage 2: Structured Latent Flow (Part-aware)
  - 学习 part-level 的 SLat 生成
  - 输入:图像 + Mask
  - 输出:Part-level SLat features

本文档主要关注 Stage 2 的训练流程。
数据处理流程
Dataset 类:StructuredLatentPartDataset
位置:training/datasets/structured_latent_part.py
核心数据结构
```python
# 数据文件结构:{uuid[:2]}/{uuid}/all_latent.npz
{
    'coords': [N, 3],          # Voxel 坐标
    'feats': [N, 8],           # SLat features(8 维)
    'offsets': [num_parts+1],  # 每个 part 的起始索引
}
```
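下面给出一个最小的读取示意(假设性示例:数据布局与上面一致,uuid 与路径仅为演示):

```python
import numpy as np

# 假设的示例路径:{uuid[:2]}/{uuid}/all_latent.npz
data = np.load("ab/abcd1234efgh/all_latent.npz")
coords, feats, offsets = data["coords"], data["feats"], data["offsets"]

# 按 offsets 切分:i=0 为 overall,i>0 为各个 part
for i in range(len(offsets) - 1):
    part_coords = coords[offsets[i]:offsets[i + 1]]  # [Ni, 3]
    part_feats = feats[offsets[i]:offsets[i + 1]]    # [Ni, 8]
    name = "overall" if i == 0 else f"part_{i}"
    print(name, part_coords.shape, part_feats.shape)
```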
get_instance 方法详解
这是数据加载的核心方法,负责处理 part-level 数据:
```python
def get_instance(self, uuid):
    """
    处理所有 parts 和整体模型数据,合并为单一输出。

    关键步骤:
    1. 加载 all_latent.npz(包含 overall + 所有 parts)
    2. 对每个 part,计算 bbox 并添加 noise mask
    3. 为 part 和 overall 添加 noise_mask_score 标记
    4. 返回合并后的坐标、特征和 part_layouts
    """
    # 1. 加载数据
    data = np.load(os.path.join(self.data_root, uuid[:2], uuid, 'all_latent.npz'))
    all_coords = data['coords'].astype(np.int32)
    all_feats = data['feats'].astype(np.float32)
    offsets = data['offsets'].astype(np.int32)

    all_coords_wnoise = []
    all_feats_wnoise = []
    part_layouts = []
    start_idx = 0

    # 2. 处理每个 part(i=0 是 overall,i>0 是 parts)
    for i in range(len(offsets) - 1):
        if i == 0:
            # Overall shape:先记录 layout,数据在循环结束后插入到列表最前面
            overall_coords = all_coords[offsets[i]:offsets[i+1]]
            overall_feats = all_feats[offsets[i]:offsets[i+1]]
            overall_ids = self.coords_to_ids(overall_coords)
            part_layouts.append(slice(start_idx, start_idx + overall_coords.shape[0]))
            start_idx += overall_coords.shape[0]
        else:
            # Part:计算 bbox,添加 noise mask
            part_coords = all_coords[offsets[i]:offsets[i+1]]
            part_feats = all_feats[offsets[i]:offsets[i+1]]

            # 计算 part 的 bbox
            part_bbox_min = part_coords.min(axis=0)
            part_bbox_max = part_coords.max(axis=0)

            # Bbox augmentation(可选)
            if self.aug_bbox_range is not None:
                # 随机扩展 bbox 范围,增加训练鲁棒性
                aug_bbox_min = np.random.randint(...)
                aug_bbox_max = np.random.randint(...)
                part_bbox_min = np.clip(part_bbox_min - aug_bbox_min, 0, 63)
                part_bbox_max = np.clip(part_bbox_max + aug_bbox_max, 0, 63)

            # 关键:Noise Mask 机制
            # 找到在 part bbox 内但不属于该 part 的 voxels(这些是 noise)
            bbox_mask = np.all((overall_coords >= part_bbox_min) &
                               (overall_coords <= part_bbox_max), axis=1)
            part_ids = self.coords_to_ids(part_coords)
            part_in_overall = np.isin(overall_ids, part_ids)
            noise_mask = np.logical_and(bbox_mask, np.logical_not(part_in_overall))

            # 提取 noise voxels
            noise_coords = overall_coords[noise_mask]
            noise_feats = overall_feats[noise_mask]

            # 添加 noise_mask_score:
            # - Part voxels: +noise_mask_score (positive)
            # - Noise voxels: -noise_mask_score (negative)
            noise_feats = np.concatenate((noise_feats,
                                          np.full((noise_feats.shape[0], 1),
                                                  -self.noise_mask_score)), axis=1)
            part_feats = np.concatenate((part_feats,
                                         np.full((part_feats.shape[0], 1),
                                                 self.noise_mask_score)), axis=1)

            # 合并 part 和 noise
            part_coords = np.concatenate((part_coords, noise_coords), axis=0)
            part_feats = np.concatenate((part_feats, noise_feats), axis=0)

            part_layouts.append(slice(start_idx, start_idx + part_coords.shape[0]))
            start_idx += part_coords.shape[0]
            all_coords_wnoise.append(torch.from_numpy(part_coords))
            all_feats_wnoise.append(torch.from_numpy(part_feats))

    # 3. 添加 overall shape(标记为 positive),插入到列表最前面
    overall_feats = np.concatenate((overall_feats,
                                    np.full((overall_feats.shape[0], 1),
                                            self.noise_mask_score)), axis=1)
    all_coords_wnoise.insert(0, torch.from_numpy(overall_coords))
    all_feats_wnoise.insert(0, torch.from_numpy(overall_feats))

    # 4. 合并所有数据
    combined_coords = torch.cat(all_coords_wnoise, dim=0).int()
    combined_feats = torch.cat(all_feats_wnoise, dim=0).float()

    # 5. 归一化(如果配置了)
    if self.normalization is not None:
        combined_feats = (combined_feats - self.mean) / self.std

    return {
        'coords': combined_coords,
        'feats': combined_feats,
        'part_layouts': part_layouts,  # 每个 part 的 slice,用于 batch 处理
    }
```
Noise Mask 机制的作用
Noise Mask 是 OmniPart 的关键设计,用于实现 Coverage Loss 的隐式约束:
- 防止 Part 缺失:通过添加 overall shape,确保模型学习完整的几何覆盖
- 区分 Part 和背景:通过 noise_mask_score 标记,模型学习区分真实 part 和背景噪声
- Bbox 约束:只考虑 part bbox 内的 voxels,避免全局噪声干扰
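下面用一个玩具例子演示 noise mask 的计算过程(纯 NumPy 示意;坐标数据与 `coords_to_ids` 的实现均为假设,仅用于说明逻辑):

```python
import numpy as np

noise_mask_score = 1.0

# 玩具数据:overall 有 5 个 voxel,其中 2 个属于当前 part
overall_coords = np.array([[0, 0, 0], [1, 0, 0], [2, 0, 0], [1, 1, 0], [2, 1, 0]])
part_coords = np.array([[1, 0, 0], [2, 1, 0]])

# part 的 bbox:[1,0,0] ~ [2,1,0]
bbox_min, bbox_max = part_coords.min(axis=0), part_coords.max(axis=0)
bbox_mask = np.all((overall_coords >= bbox_min) & (overall_coords <= bbox_max), axis=1)

def coords_to_ids(c, res=64):  # 假设 64^3 网格:把三维坐标压成一维 id
    return (c[:, 0] * res + c[:, 1]) * res + c[:, 2]

in_part = np.isin(coords_to_ids(overall_coords), coords_to_ids(part_coords))
noise_mask = bbox_mask & ~in_part  # 在 bbox 内但不属于该 part 的 voxel

print(overall_coords[noise_mask])  # [[2 0 0] [1 1 0]]:这两个 voxel 被当作 noise
# 随后:part voxel 的特征拼上 +noise_mask_score,noise voxel 拼上 -noise_mask_score
```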
图像和 Mask 条件处理
位置:training/datasets/components.py
ImageConditionedMixin.get_instance
```python
def get_instance(self, uuid, select_method='random_seg'):
    """
    加载图像和 mask 作为条件。

    流程:
    1. 随机选择一个 view(0-25)
    2. 加载图像:color_{index:04d}.webp
    3. 加载 mask:mask_{index:04d}.exr
    4. 处理 mask:排序 parts(从下到上),下采样到 DINO patch size
    """
    pack = super().get_instance(uuid)

    # 随机选择 view
    data_dir = os.path.join(self.img_data_root, uuid[:2], uuid)
    random_index = random.choice(range(26))  # 0-25
    image_path = os.path.join(data_dir, f"color_{random_index:04d}.webp")
    mask_path = os.path.join(data_dir, f"mask_{random_index:04d}.exr")

    # 加载并处理 mask
    mask, mask_vis, ordered_mask_dino = self.load_bottom_up_mask(mask_path, size=(518, 518))
    pack['ordered_mask_dino'] = ordered_mask_dino  # [37, 37],用于 DINO

    # 加载图像
    image, image_rgb = self.load_image(image_path)
    pack['cond'] = image  # [3, 518, 518]
    return pack
```
load_bottom_up_mask 方法
```python
def load_bottom_up_mask(self, mask_path, size=(518, 518)):
    """
    加载 mask 并按从下到上的顺序排序 parts。

    关键步骤:
    1. 加载原始 mask(EXR 格式)
    2. 下采样到 DINO patch size (37x37)
    3. 根据 y 坐标排序 parts(从下到上)
    4. 重新映射 part IDs
    """
    # 1. 加载 mask
    mask = cv2.imread(mask_path, cv2.IMREAD_UNCHANGED)
    mask = cv2.resize(mask, size, interpolation=cv2.INTER_NEAREST)

    # 2. 下采样到 DINO patch size (37x37)
    mask_dino = cv2.resize(mask, (37, 37), interpolation=cv2.INTER_NEAREST)
    mask_dino = np.array(mask_dino, dtype=np.int32)[..., 0]
    mask_org = np.array(mask, dtype=np.int32)[..., 0]

    # 3. 找到所有 part IDs
    unique_indices = np.unique(mask_org)
    unique_indices = unique_indices[unique_indices > 0]  # 排除背景 (0)

    # 4. 计算每个 part 的 y 坐标(最下方的 y 值;图像坐标系中 y 向下增大)
    part_positions = {}
    for idx in unique_indices:
        y_coords, _ = np.where(mask_org == idx)
        if len(y_coords) > 0:
            part_positions[idx] = np.max(y_coords)  # 最下方的 y

    # 5. 按 y 坐标从大到小排序(从下到上)
    sorted_parts = sorted(part_positions.items(), key=lambda x: -x[1])

    # 6. 重新映射 part IDs(1, 2, 3, ...)
    index_map = {}
    for new_idx, (old_idx, _) in enumerate(sorted_parts, 1):
        index_map[old_idx] = new_idx

    # 7. 创建有序的 mask
    ordered_mask = np.zeros_like(mask_org)
    ordered_mask_dino = np.zeros_like(mask_dino)
    for old_idx, new_idx in index_map.items():
        ordered_mask[mask_org == old_idx] = new_idx
        ordered_mask_dino[mask_dino == old_idx] = new_idx
    ordered_mask_dino = torch.from_numpy(ordered_mask_dino).long()

    # mask_vis 为可视化用的 mask(生成代码此处省略)
    return mask, mask_vis, ordered_mask_dino  # ordered_mask_dino: [37, 37] for DINO
```
为什么需要排序?
- 确保 part 的顺序一致性(从下到上)
- 便于模型学习 part 之间的空间关系
- 与 part_layouts 中的 part 顺序对应
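可以用一个 4×4 的合成 mask 验证这段排序逻辑(示意代码,只保留与原实现等价的核心步骤):

```python
import numpy as np

# 合成 mask:part 3 在上方,part 7 在下方(图像坐标系中 y 向下增大)
mask = np.zeros((4, 4), dtype=np.int32)
mask[0:2, :] = 3  # 上半部分
mask[2:4, :] = 7  # 下半部分

ids = np.unique(mask)
ids = ids[ids > 0]  # 排除背景
bottom_y = {i: np.where(mask == i)[0].max() for i in ids}  # 每个 part 最下方的 y
order = sorted(bottom_y.items(), key=lambda kv: -kv[1])    # y 越大越靠前(从下到上)

remap = {old: new for new, (old, _) in enumerate(order, 1)}
print(remap)  # {7: 1, 3: 2}:最下方的 part 7 被重新映射为 1
```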
模型架构
StructuredLatentFlow 模型
位置:training/models/structured_latent_flow.py
整体架构
Input SparseTensor → Input Layer(融入 Mask Embedding、Part-wise Batch IDs)
→ Input Blocks(Downsampling)
→ Transformer Blocks(Cross-Attention,注入 Image Features + Mask Features)
→ Output Blocks(Upsampling)
→ Output Layer → Velocity Prediction
Forward 方法详解
```python
def forward(self, x: sp.SparseTensor, t: torch.Tensor, cond: torch.Tensor, **kwargs):
    """
    前向传播流程。

    Args:
        x: SparseTensor [N, 9](8 维 SLat + 1 维 noise_mask_score)
        t: Timestep [B]
        cond: DINO features [B, 1374, 1024]
        kwargs:
            - ordered_mask_dino: [B, 37, 37]
            - part_layouts: List[List[slice]]
    Returns:
        Output SparseTensor [N, 8] (velocity prediction)
    """
    input_dtype = x.dtype

    # ========== 1. Mask Embedding ==========
    masks = kwargs['ordered_mask_dino']  # [B, 37, 37]
    masks = masks.long()
    masks = rearrange(masks, 'b h w -> b (h w)')  # [B, 1369]

    # Embedding: part ID -> 128 维 -> 1024 维
    masks_emb = self.group_embedding(masks)     # [B, 1369, 128]
    masks_emb = self.group_emb_proj(masks_emb)  # [B, 1369, 1024]

    # 对齐到 cond 的长度(1374 = 1369 + 5 special tokens)
    group_emb = torch.zeros((cond.shape[0], cond.shape[1], masks_emb.shape[2]),
                            device=cond.device, dtype=cond.dtype)
    group_emb[:, :masks_emb.shape[1], :] = masks_emb

    # 融合 mask embedding 到 DINO features
    cond = cond + group_emb  # [B, 1374, 1024]
    cond = cond.type(self.dtype)

    # ========== 2. Part-wise Batch Processing ==========
    original_batch_ids = x.coords[:, 0].clone()  # 原始 batch IDs

    # 创建新的 batch IDs:每个 part 一个独立的 batch ID
    new_batch_ids = torch.zeros_like(original_batch_ids)
    part_layouts = kwargs['part_layouts']
    part_id = 0
    len_before = 0
    batch_last_partid = []
    for batch_idx, part_layout in enumerate(part_layouts):
        # part_layout 是 List[slice],每个 slice 对应一个 part
        for layout_idx, layout in enumerate(part_layout):
            adjusted_layout = slice(layout.start + len_before,
                                    layout.stop + len_before,
                                    layout.step)
            new_batch_ids[adjusted_layout] = part_id
            part_id += 1
        batch_last_partid.append(part_id)
        len_before += part_layout[-1].stop

    # 更新 coordinates:使用 part-wise batch IDs
    x = sp.SparseTensor(
        feats=x.feats,
        coords=torch.cat([new_batch_ids.view(-1, 1), x.coords[:, 1:]], dim=1)
    )

    # ========== 3. Input Processing ==========
    x = self.input_layer(x).type(self.dtype)  # Project to model dim

    # Timestep embedding
    t_emb = self.t_embedder(t)  # [B, hidden_dim]
    if self.share_mod:
        t_emb = self.adaLN_modulation(t_emb)
    t_emb = t_emb.type(self.dtype)

    # 为每个 part 复制 timestep embedding
    t_emb_updown = []
    for batch_idx, part_layout in enumerate(part_layouts):
        t_emb_updown_batch = t_emb[batch_idx:batch_idx+1].repeat(len(part_layout), 1)
        t_emb_updown.append(t_emb_updown_batch)
    t_emb_updown = torch.cat(t_emb_updown, dim=0).type(self.dtype)

    # ========== 4. Downsampling (Input Blocks) ==========
    skips = []
    for block in self.input_blocks:
        x = block(x, t_emb_updown)
        skips.append(x.feats)  # 保存 skip connections

    # ========== 5. Transformer Blocks ==========
    part_wise_batch_ids = x.coords[:, 0].clone()

    # 转换回 batch-wise IDs(用于 transformer)
    new_transformer_batch_ids = torch.zeros_like(part_wise_batch_ids)
    part_ids_in_each_object = torch.zeros_like(part_wise_batch_ids)
    start_reform = 0
    last_part_id = 0
    for part_id in batch_last_partid:
        mask = (part_wise_batch_ids >= last_part_id) & (part_wise_batch_ids < part_id)
        new_transformer_batch_ids[mask] = start_reform
        part_ids_in_each_object[mask] = part_wise_batch_ids[mask] - last_part_id
        last_part_id = part_id
        start_reform += 1

    # 更新 coordinates:使用 batch-wise IDs
    h = sp.SparseTensor(
        feats=x.feats,
        coords=torch.cat([new_transformer_batch_ids.view(-1, 1), x.coords[:, 1:]], dim=1)
    )

    # 添加位置编码
    if self.pe_mode == "ape":
        # Absolute positional encoding
        h = h + self.pos_embedder(h.coords[:, 1:]).type(self.dtype)

    # Part-wise positional encoding(overall 是 0,parts 是 1, 2, 3, ...)
    part_pe = self.layer_pe(part_ids_in_each_object)
    part_pe = self.layer_pe_proj(part_pe)
    h = h + part_pe.type(self.dtype)

    # Transformer blocks with cross-attention
    for block in self.blocks:
        h = block(h, t_emb, cond)  # cond 用于 cross-attention

    # ========== 6. Upsampling (Output Blocks) ==========
    # 恢复 part-wise batch IDs
    h = x.replace(feats=h.feats,
                  coords=torch.cat([part_wise_batch_ids.view(-1, 1), h.coords[:, 1:]], dim=1))

    # Upsampling with skip connections
    for block, skip in zip(self.out_blocks, reversed(skips)):
        if self.use_skip_connection:
            h = block(h.replace(torch.cat([h.feats, skip], dim=1)), t_emb_updown)
        else:
            h = block(h, t_emb_updown)

    # ========== 7. Output ==========
    h = h.replace(F.layer_norm(h.feats, h.feats.shape[-1:]))
    h = self.out_layer(h.type(input_dtype))

    # 恢复原始 batch IDs
    h = sp.SparseTensor(
        feats=h.feats,
        coords=torch.cat([original_batch_ids.view(-1, 1), h.coords[:, 1:]], dim=1)
    )
    return h
```
关键设计点
- Mask Embedding
  - 将 mask 的 part IDs 转换为 embedding
  - 与 DINO features 相加,提供 part 信息
- Part-wise Batch Processing(batch ID 重映射的玩具示例见本节末尾)
  - 每个 part 独立处理(独立的 batch ID)
  - 在 transformer 阶段转换回 batch-wise
  - 便于处理不同数量的 parts
- Part Positional Encoding
  - Overall shape: 0
  - Parts: 1, 2, 3, ...(按从下到上排序)
  - 帮助模型理解 part 的层次关系
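下面用玩具尺寸演示第一次 batch ID 重映射(与 forward 第 2 步的循环逻辑一致,数据均为虚构):

```python
import torch

# 玩具例子:batch 内 2 个样本;样本 0 为 overall + 2 个 part,样本 1 为 overall + 1 个 part
part_layouts = [
    [slice(0, 4), slice(4, 6), slice(6, 9)],  # 样本 0:3 段,共 9 个 voxel
    [slice(0, 5), slice(5, 7)],               # 样本 1:2 段,共 7 个 voxel
]
new_batch_ids = torch.zeros(9 + 7, dtype=torch.long)

part_id, len_before, batch_last_partid = 0, 0, []
for part_layout in part_layouts:
    for layout in part_layout:
        adjusted = slice(layout.start + len_before, layout.stop + len_before)
        new_batch_ids[adjusted] = part_id  # 每个 part(含 overall)独立一个 batch ID
        part_id += 1
    batch_last_partid.append(part_id)
    len_before += part_layout[-1].stop

print(new_batch_ids.tolist())
# [0, 0, 0, 0, 1, 1, 2, 2, 2, 3, 3, 3, 3, 3, 4, 4]:样本 0 → part 0/1/2,样本 1 → part 3/4
print(batch_last_partid)  # [3, 5]:transformer 阶段据此还原 batch-wise IDs
```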
训练循环
Trainer 类结构
位置:training/trainers/flow_matching/sparse_flow_matching.py
SparseFlowMatchingTrainer
继承关系:
Trainer (base.py)
└── FlowMatchingTrainer (flow_matching.py)
└── SparseFlowMatchingTrainer (sparse_flow_matching.py)
└── ImageConditionedSparseFlowMatchingTrainer (with ImageConditionedMixin)
training_losses 方法
```python
def training_losses(
    self,
    x_0: sp.SparseTensor,
    cond=None,
    **kwargs
) -> Tuple[Dict, Dict]:
    """
    计算 Flow Matching 训练损失。

    流程:
    1. 生成随机噪声(与 x_0 相同的 sparsity pattern)
    2. 采样随机 timestep t
    3. 应用 diffusion:x_t = (1-t) * x_0 + t * noise
    4. 预测 velocity field
    5. 计算 MSE loss
    """
    # 1. 生成噪声(保持相同的 sparsity pattern)
    noise = x_0.replace(torch.randn_like(x_0.feats))

    # 2. 采样随机 timestep t ~ U(0, 1)
    t = self.sample_t(x_0.shape[0]).to(x_0.device).float()

    # 3. 应用 diffusion process:x_t = (1 - t) * x_0 + t * noise
    x_t = self.diffuse(x_0, t, noise=noise)

    # 4. 处理条件输入
    cond, ordered_mask_dino = self.get_cond(cond, **kwargs)
    kwargs['ordered_mask_dino'] = ordered_mask_dino

    # 5. 模型预测 velocity field
    # 注意:timestep 乘以 1000(因为模型期望 [0, 1000] 范围)
    pred = self.training_models['denoiser'](x_t, t * 1000, cond, **kwargs)
    assert pred.shape == noise.shape == x_0.shape

    # 6. 计算 target velocity field:v = noise - x_0(Flow Matching 的目标)
    target = self.get_v(x_0, noise, t)

    # 7. 计算 MSE loss
    terms = edict()
    terms["mse"] = F.mse_loss(pred.feats, target.feats)
    terms["loss"] = terms["mse"]

    # 8. 按 timestep 分 bin 记录损失(用于监控)
    mse_per_instance = np.array([
        F.mse_loss(pred.feats[x_0.layout[i]], target.feats[x_0.layout[i]]).item()
        for i in range(x_0.shape[0])
    ])
    time_bin = np.digitize(t.cpu().numpy(), np.linspace(0, 1, 11)) - 1
    for i in range(10):
        if (time_bin == i).sum() != 0:
            terms[f"bin_{i}"] = {"mse": mse_per_instance[time_bin == i].mean()}

    return terms, {}
```
Flow Matching 目标
Flow Matching 是一种生成模型训练方法,其核心思想是:
- 定义概率流:从噪声分布到数据分布的连续路径
- 学习速度场:模型预测速度场 v(x_t, t),而不是直接预测 x_0
- 训练目标:MSE loss between predicted and true velocity field
数学形式:
- Path: x_t = (1 - t) * x_0 + t * noise, t ∈ [0, 1]
- Velocity field: v(x_t, t) = noise - x_0
- Loss: L = ||v_pred(x_t, t) - v(x_t, t)||²
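把上面三条公式直接翻译成代码,就是一个最小的 Flow Matching 训练步骤(稠密张量示意;`model` 是任意接受 `(x, t)` 的网络,属于假设的接口):

```python
import torch
import torch.nn.functional as F

def flow_matching_step(model, x_0):
    """最小示意:x_0 为一个 batch 的干净样本 [B, D]。"""
    B = x_0.shape[0]
    noise = torch.randn_like(x_0)          # 路径终点:标准高斯噪声
    t = torch.rand(B, device=x_0.device)   # t ~ U(0, 1)
    t_ = t.view(B, *([1] * (x_0.ndim - 1)))
    x_t = (1 - t_) * x_0 + t_ * noise      # Path: x_t = (1-t)*x_0 + t*noise
    v_target = noise - x_0                 # Velocity field(与 t 无关)
    v_pred = model(x_t, t)                 # 模型预测速度场
    return F.mse_loss(v_pred, v_target)    # Loss
```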
get_cond 方法(ImageConditionedMixin)
位置:training/trainers/flow_matching/mixins/image_conditioned.py
```python
def get_cond(self, cond, **kwargs):
    """
    处理图像条件输入。

    流程:
    1. 使用 DINOv2 编码图像 -> [B, 1374, 1024]
    2. 创建 negative condition(用于 classifier-free guidance)
    3. 调用父类的 get_cond(处理 CFG)
    """
    # 1. 编码图像
    cond = self.encode_image(cond)  # [B, 1374, 1024]

    # 2. 创建 negative condition(全零)
    kwargs['neg_cond'] = torch.zeros_like(cond)
    kwargs['neg_mask'] = torch.zeros_like(kwargs['ordered_mask_dino'])

    # 3. 调用父类处理 CFG
    cond = super().get_cond(cond, **kwargs)
    return cond
```
encode_image 方法
```python
@torch.no_grad()
def encode_image(self, image: Union[torch.Tensor, List[Image.Image]]) -> torch.Tensor:
    """
    使用 DINOv2 编码图像。

    输入:图像 [B, 3, 518, 518]
    输出:Patch tokens [B, 1374, 1024]
    1374 = 37 * 37 (patches) + 5 (special tokens)
    """
    if self.image_cond_model is None:
        self._init_image_cond_model()

    # 归一化
    image = image.cuda()
    image = self.image_cond_model['transform'](image)

    # DINOv2 forward
    features = self.image_cond_model['model'](image, is_training=True)['x_prenorm']
    patchtokens = F.layer_norm(features, features.shape[-1:])
    return patchtokens  # [B, 1374, 1024]
```
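1374 这个 token 数可以直接算出来(假设其中 5 个 special tokens 为 1 个 CLS + 4 个 register,与带 registers 的 DINOv2 ViT-L/14 一致):

```python
image_size, patch_size = 518, 14
patches_per_side = image_size // patch_size  # 518 // 14 = 37
num_patches = patches_per_side ** 2          # 37 * 37 = 1369
num_special = 1 + 4                          # CLS + register tokens(假设)
print(num_patches + num_special)             # 1374;ViT-L 的特征维度为 1024
```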
关键代码解析
数据流可视化
一次训练迭代中,Dataset、Trainer、Model、Loss 之间的交互顺序如下:
1. Trainer → Dataset:`get_instance(uuid)`
   - 加载 SLat 数据(all_latent.npz),处理 parts + noise mask
   - 加载图像和 mask,对 mask 排序 parts
   - 返回 batch 数据
2. Trainer:`training_losses(x_0, cond)`
   - 采样 noise 与 timestep t;diffuse:x_t = (1-t)*x_0 + t*noise
   - 编码条件(DINO features + mask embedding)
3. Trainer → Model:`forward(x_t, t, cond)`
   - Mask embedding → Part-wise batch IDs → Input blocks(downsample)
   - Transformer blocks(cross-attn)→ Output blocks(upsample)
   - 返回 velocity prediction
4. Trainer → Loss:计算 target velocity,MSE loss
5. Backward + Update
关键函数注释
1. collate_fn(Dataset)
```python
def collate_fn(self, items, split_size=1):
    """
    将多个样本 collate 成一个 batch。

    关键处理:
    1. 合并所有样本的 coords 和 feats
    2. 更新 batch IDs(coords[:, 0])
    3. 保存 part_layouts(用于模型前向传播)
    """
    all_coords = []
    all_feats = []
    all_part_layouts = []
    batch_id = 0
    for item in items:
        coords = item['coords'].clone()
        feats = item['feats'].clone()

        # 更新 batch ID
        coords[:, 0] = batch_id
        all_coords.append(coords)
        all_feats.append(feats)

        # 保存 part_layouts(索引平移到拼接后的全局坐标)
        adjusted_layouts = []
        offset = sum(c.shape[0] for c in all_coords[:-1])  # 之前样本累计的 voxel 数量
        for layout in item['part_layouts']:
            adjusted_layouts.append(slice(layout.start + offset,
                                          layout.stop + offset))
        all_part_layouts.append(adjusted_layouts)
        batch_id += 1

    # 创建 SparseTensor
    coords = torch.cat(all_coords, dim=0)
    feats = torch.cat(all_feats, dim=0)
    x_0 = sp.SparseTensor(feats=feats, coords=coords)

    return {
        'x_0': x_0,
        'cond': torch.stack([item['cond'] for item in items]),
        'ordered_mask_dino': torch.stack([item['ordered_mask_dino'] for item in items]),
        'part_layouts': all_part_layouts,
    }
```
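一个玩具例子验证 offset 的平移逻辑(两个虚构样本,各自 part_layouts 的索引被平移到拼接后的全局坐标):

```python
import torch

items = [
    {"coords": torch.zeros(6, 4, dtype=torch.int32), "part_layouts": [slice(0, 4), slice(4, 6)]},
    {"coords": torch.zeros(3, 4, dtype=torch.int32), "part_layouts": [slice(0, 3)]},
]

all_coords, all_layouts = [], []
for item in items:
    all_coords.append(item["coords"])
    offset = sum(c.shape[0] for c in all_coords[:-1])  # 之前样本累计的 voxel 数
    all_layouts.append([slice(l.start + offset, l.stop + offset)
                        for l in item["part_layouts"]])

print(all_layouts)
# [[slice(0, 4, None), slice(4, 6, None)], [slice(6, 9, None)]]
```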
2. diffuse 方法(FlowMatchingTrainer)
```python
def diffuse(self, x_0, t, noise=None):
    """
    应用 diffusion process:x_t = (1 - t) * x_0 + t * noise。
    这是 Flow Matching 的核心:定义从数据到噪声的线性路径。
    """
    if noise is None:
        noise = x_0.replace(torch.randn_like(x_0.feats))

    # t 需要 broadcast 到每个样本
    t = t.view(-1, *([1] * (x_0.feats.ndim - 1)))

    # Linear interpolation
    x_t = x_0.replace((1 - t) * x_0.feats + t * noise.feats)
    return x_t
```
3. get_v 方法(FlowMatchingTrainer)
```python
def get_v(self, x_0, noise, t):
    """
    计算 target velocity field:v = noise - x_0。
    这是 Flow Matching 的目标:模型需要预测这个速度场。
    """
    return x_0.replace(noise.feats - x_0.feats)
```
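训练学到速度场之后,推理时从纯噪声出发、沿速度场反向积分即可得到样本。下面是一个最小的 Euler 采样示意(`model` 接口为假设,真实 pipeline 中还有 CFG 等细节):

```python
import torch

@torch.no_grad()
def euler_sample(model, shape, steps=50, device="cuda"):
    """从 t=1(纯噪声)积分到 t=0(数据):dx/dt = v(x, t)。"""
    x = torch.randn(shape, device=device)  # t = 1 的起点
    ts = torch.linspace(1.0, 0.0, steps + 1, device=device)
    for i in range(steps):
        t, t_next = ts[i], ts[i + 1]
        v = model(x, t.expand(shape[0]))   # 预测速度场
        x = x + (t_next - t) * v           # Euler 步(步长为负,即反向积分)
    return x
```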
训练配置示例
```json
{
    "dataset": {
        "name": "StructuredLatentPartDataset",
        "args": {
            "data_root": "/path/to/data",
            "img_data_root": "/path/to/images",
            "noise_mask_score": 1.0,
            "aug_bbox_range": [0, 2],
            "normalization": "standard"
        }
    },
    "models": {
        "denoiser": {
            "name": "StructuredLatentFlow",
            "args": {
                "input_dim": 9,
                "hidden_dim": 1024,
                "depth": 24,
                "num_heads": 16,
                "cond_dim": 1024
            }
        }
    },
    "trainer": {
        "name": "ImageConditionedSparseFlowMatchingTrainer",
        "args": {
            "t_schedule": "uniform",
            "sigma_min": 0.0
        }
    }
}
```
其中 `input_dim = 9` 对应 8 维 SLat 特征 + 1 维 noise_mask_score。
总结
核心设计理念
- Part-aware Generation
  - 通过 part_layouts 组织数据
  - Part-wise batch processing 处理不同数量的 parts
- Coverage Constraint
  - Noise mask 机制确保完整覆盖
  - Overall shape + Parts 的组合训练
- Conditional Generation
  - 图像条件(DINO features)
  - Mask 条件(Part IDs embedding)
- Flow Matching
  - 学习速度场而非直接预测 x_0
  - 训练过程更稳定
关键文件索引
- Dataset:`training/datasets/structured_latent_part.py`
- Components:`training/datasets/components.py`
- Model:`training/models/structured_latent_flow.py`
- Trainer:`training/trainers/flow_matching/sparse_flow_matching.py`
- Image Condition:`training/trainers/flow_matching/mixins/image_conditioned.py`
训练命令
```bash
python train.py \
    --config configs/training_part_synthesis.json \
    --output_dir ./outputs/part_synthesis \
    --data_dir /path/to/data
```
本文档基于 OmniPart 代码库编写,详细解释了训练流程的关键设计和技术细节。