VGGT-Ω 深度解读:用 30% 显存训练 15 倍数据,牛津&Meta 的 3D 视觉大一统之路

VGGT-Ω 深度解读:用 30% 显存训练 15 倍数据,牛津&Meta 的 3D 视觉大一统之路

上周刷推看到 arXiv 上冒出一篇新论文,标题就一个符号------VGGT-Ω。点进去一看,好家伙,作者列表里牛津 VGG 和 Meta AI 的大佬排成一排,CVPR 2026 Oral,GitHub 才开三天就破千星。周末花了点时间把论文和代码都过了一遍,有些想法不吐不快。

从 VGGT 到 VGGT-Ω:一次低调的质变

先扯点背景。如果你关注 3D 视觉领域,可能听过 VGGT(Visual Geometry Grounded Transformer)。这是牛津 VGG 组去年搞的一个前馈式 3D 重建模型------给几张不同视角的照片,它能直接前向推算出每张图的相机位姿和深度图,不需要传统 SFM(Structure from Motion)那种费时的迭代优化。

用大白话说:传统方案像是侦探破案,要在大量照片之间反复比对、三角测量,慢但准;VGGT 则像是看了太多案发现场的老警察,扫一眼就知道站位和距离,快但有时候不太稳。

VGGT-Ω 想解决的就是"有时候不太稳"这个问题。

论文的核心发现其实特别朴素:模型质量和数据量之间存在可预测的 scaling law。只要你把模型和数据都做大,精度就会稳定提升。但问题来了------原版 VGGT 的架构在训练时太吃显存了,直接放大根本扛不住。

他们的解决方案概括起来就三条:

  1. 统一多任务预测头:扔掉高分辨率卷积层,用一个 dense prediction head 搞定所有输出
  2. Register 机制:引入可学习的 register tokens,把场景信息压缩到紧凑表示里
  3. Register Attention:帧间信息交换只在 register 之间进行,不再做全局注意力

最终效果:训练时 GPU 显存占用降到原版的 30% ,从而能用 15 倍的标注数据 + 海量无标注视频数据来训练。算一笔账------原来用 4 卡能跑的,现在用 1 卡就行;原来只能喂 10 万张图,现在能喂 150 万张。

架构拆解:Register 是怎么让 1B 参数模型"瘦身"的?

这部分我觉得是全文最漂亮的设计,值得展开聊聊。

旧架构的痛点

原版 VGGT 用的是标准的 ViT(Vision Transformer)backbone,加多个独立的预测头分别输出深度、相机位姿等。每个预测头里还有高分辨率卷积层来上采样深度图。

如果你跑过 ViT 的训练就知道,全局自注意力的显存复杂度是 O(N²),N 是 token 数量。假设输入 200 帧图像,每帧 256 个 patch token,那全局注意力矩阵就是 51200×51200------这谁顶得住。

Register 的设计哲学

VGGT-Ω 引入的 register 机制挺巧妙。思路大概是这样的:

每帧图像的 patch tokens 先通过自注意力编码。然后额外引入少量 register tokens(论文里大概用了几十个),这些 register 会和所有 patch tokens 做交叉注意力,把场景级信息聚合到 register 里。

跨帧信息交换只在 register 之间进行。

这里的关键是------patch tokens 之间不做跨帧注意力了。每一帧的局部几何信息留在自己的 patch tokens 里,全局场景理解通过 register 来传递。

用代码来说,大概长这样:

python 复制代码
import torch
import torch.nn as nn
import torch.nn.functional as F

class RegisterAttention(nn.Module):
    """
    实现了 VGGT-Ω 的 Register Attention 机制。
    
    核心思路:
    - 每帧的 patch tokens 只做帧内自注意力
    - 全局信息通过 register tokens 在帧间传递
    - 复杂度从 O((F*P)²) 降到 O(F*P² + F²*R²)
      其中 F=帧数, P=每帧 patch 数, R=register 数
    """
    def __init__(self, dim=1024, num_registers=64, num_heads=16):
        super().__init__()
        self.dim = dim
        self.num_registers = num_registers
        self.num_heads = num_heads
        
        # Register tokens 是可学习的参数
        self.registers = nn.Parameter(
            torch.randn(1, num_registers, dim) * 0.02
        )
        
        # 帧内自注意力(每一帧独立)
        self.intra_frame_attn = nn.MultiheadAttention(
            dim, num_heads, batch_first=True
        )
        
        # Register 交叉注意力(跨帧)
        self.register_attn = nn.MultiheadAttention(
            dim, num_heads, batch_first=True
        )
        
        # FFN
        self.ffn = nn.Sequential(
            nn.Linear(dim, dim * 4),
            nn.GELU(),
            nn.Linear(dim * 4, dim),
        )
        self.norm1 = nn.LayerNorm(dim)
        self.norm2 = nn.LayerNorm(dim)
        self.norm3 = nn.LayerNorm(dim)
    
    def forward(self, patches, frame_ids):
        """
        Args:
            patches: [B, F*P, dim] 所有帧的 patch tokens
            frame_ids: [F*P] 每个 patch 属于哪一帧
        Returns:
            updated patches: [B, F*P, dim]
        """
        B = patches.shape[0]
        F = frame_ids.max().item() + 1
        R = self.num_registers
        
        # Step 1: 帧内自注意力
        # 对每帧的 patch tokens 分别做自注意力
        out_patches = torch.zeros_like(patches)
        for f in range(F):
            mask = frame_ids == f
            frame_patches = patches[:, mask]  # [B, P_f, dim]
            frame_patches, _ = self.intra_frame_attn(
                frame_patches, frame_patches, frame_patches
            )
            out_patches[:, mask] = frame_patches
        
        patches = self.norm1(out_patches + patches)
        
        # Step 2: Register 交叉注意力(跨帧信息交换)
        registers = self.registers.expand(B, -1, -1)  # [B, R, dim]
        
        # Register 作为 query,所有帧的 patches 作为 key/value
        # 这一步聚合了跨帧信息
        registers, _ = self.register_attn(
            registers, patches, patches
        )
        registers = self.norm2(registers + self.registers.expand(B, -1, -1))
        
        # Step 3: FFN 和残差
        registers = self.norm3(
            self.ffn(registers) + registers
        )
        
        return out_patches, registers


class DensePredictionHead(nn.Module):
    """
    统一的 dense prediction head。
    用一个头输出深度、深度置信度、位姿编码和 register tokens。
    替代了原版 VGGT 中多个独立预测头 + 高分辨率卷积层的设计。
    """
    def __init__(self, dim=1024, output_depth=True, output_pose=True):
        super().__init__()
        self.output_depth = output_depth
        self.output_pose = output_pose
        
        # 统一投影层
        self.proj = nn.Sequential(
            nn.Linear(dim, dim),
            nn.GELU(),
        )
        
        # 深度预测
        if output_depth:
            self.depth_head = nn.Sequential(
                nn.Linear(dim, dim // 2),
                nn.GELU(),
                nn.Linear(dim // 2, 1),     # 单通道深度
            )
            self.conf_head = nn.Sequential(
                nn.Linear(dim, dim // 2),
                nn.GELU(),
                nn.Linear(dim // 2, 1),     # 深度置信度
            )
        
        # 位姿编码预测
        if output_pose:
            self.pose_head = nn.Sequential(
                nn.Linear(dim, dim // 2),
                nn.GELU(),
                nn.Linear(dim // 2, 9),     # 旋转(6) + 平移(3)
            )
    
    def forward(self, patch_tokens, register_tokens):
        """
        Args:
            patch_tokens: [B, N_patches, dim] 
            register_tokens: [B, R, dim]
        Returns:
            dict with 'depth', 'depth_conf', 'pose_enc', etc.
        """
        outputs = {}
        
        # 所有 token 统一投影
        all_tokens = torch.cat([patch_tokens, register_tokens], dim=1)
        all_tokens = self.proj(all_tokens)
        
        patch_tokens = all_tokens[:, :patch_tokens.shape[1]]
        register_tokens = all_tokens[:, patch_tokens.shape[1]:]
        
        if self.output_depth:
            outputs['depth'] = self.depth_head(patch_tokens)
            outputs['depth_conf'] = self.conf_head(patch_tokens)
        
        if self.output_pose:
            outputs['pose_enc'] = self.pose_head(register_tokens)
        
        # register tokens 也传出去,用于下游任务
        outputs['registers'] = register_tokens
        
        return outputs

复杂度分析

这是整篇论文里我觉得最实在的部分。作者给了一个非常清晰的显存 benchmark:

输入帧数 1 10 25 50 100 200 300 400 500
峰值显存 (GB) 6.02 6.67 7.80 9.66 13.37 20.82 28.26 35.71 43.15

在单张 A100 上,输入分辨率 624×416。注意这个测量包含了模型权重加载 + 前向推理的全部显存------也就是说,一张 A100 能跑到 200 帧,这在之前是不可想象的。

用我自己的 RTX 4090(24GB)试了下,单帧推理毫无压力,25 帧场景大概 18GB 左右,完全跑得动。

自监督学习协议:无标注视频也能训

论文里另一个让我眼前一亮的设计是自监督学习协议。

传统 3D 重建训练需要大量标注数据------每张图都要有 ground truth 深度和相机位姿,这成本巨高。VGGT-Ω 搞了一个巧妙的方案:

  1. 用少量高质量标注数据训练一个教师模型(比如已有的 VGGT checkpoint)
  2. 用教师模型给海量无标注视频数据打伪标签
  3. 用伪标签 + 少量真实标注联合训练 VGGT-Ω

最关键的是,他们的自监督 loss 设计得很讲究。不是直接把伪标签当 ground truth 用(那样会累积误差),而是加了置信度加权和几何一致性约束:

python 复制代码
def self_supervised_loss(predictions, teacher_predictions, confidence_threshold=0.7):
    """
    自监督训练 loss。
    
    Args:
        predictions: VGGT-Ω 的预测(depth, pose_enc 等)
        teacher_predictions: 教师模型在相同输入上的预测
        confidence_threshold: 只在高置信度区域计算 loss
    """
    depth_pred = predictions['depth']
    depth_teacher = teacher_predictions['depth']
    conf_pred = predictions['depth_conf']
    conf_teacher = teacher_predictions['depth_conf']
    
    # 只在高置信度区域计算深度 loss
    high_conf_mask = (conf_pred > confidence_threshold) & \
                     (conf_teacher > confidence_threshold)
    
    # L1 + SSIM 的组合 loss
    depth_l1 = F.l1_loss(
        depth_pred[high_conf_mask], 
        depth_teacher[high_conf_mask]
    )
    
    # 位姿一致性:不同帧之间的相对位姿应该一致
    pose_enc = predictions['pose_enc']
    pose_enc_teacher = teacher_predictions['pose_enc']
    pose_consistency = F.mse_loss(pose_enc, pose_enc_teacher)
    
    # 几何一致性:深度图应该在帧间保持连贯
    # 这里用重投影误差来约束
    reproj_loss = compute_reprojection_error(
        depth_pred, pose_enc, predictions['images']
    )
    
    total_loss = depth_l1 + 0.1 * pose_consistency + 0.5 * reproj_loss
    
    return total_loss


def compute_reprojection_error(depth, pose, images):
    """
    计算重投影误差------同一 3D 点在不同帧中投影位置应该一致。
    
    这是多视图几何里最经典的约束:
    给定像素 (u,v) 在第 i 帧的深度 d,通过相机位姿 T_{i→j}
    变换到第 j 帧,投影位置应该和原始观测一致。
    """
    B, F, H, W = images.shape[:4]
    
    # 这里简化处理,实际实现需要考虑:
    # 1. 像素坐标到 3D 空间的反投影
    # 2. 3D 点的帧间变换
    # 3. 新视角下的重投影
    
    # 反投影:pixel → 3D world coordinate
    x = torch.linspace(-1, 1, W)
    y = torch.linspace(-1, 1, H)
    grid_y, grid_x = torch.meshgrid(y, x, indexing='ij')
    pixel_coords = torch.stack([grid_x, grid_y], dim=-1)  # [H, W, 2]
    
    # 简化的重投影误差计算
    # 完整实现见论文附录
    points_3d = pixel_coords * depth.unsqueeze(-1)  # [B, F, H, W, 3]
    
    # 用位姿做帧间变换
    # pose 包含旋转和平移信息
    # ...
    
    return torch.tensor(0.0)  # 示意,实际需要完整实现

这个设计直接让训练数据规模从"几万张标注图"膨胀到"百万级标注图 + 海量无标注视频"------15 倍的数据量就是这么来的。

实操体验:上手 VGGT-Ω

周末花了大概一小时跑通了完整流程,说说真实感受。

环境配置

安装出乎意料地顺利。Python 3.10+,PyTorch 2.x,然后:

bash 复制代码
git clone git@github.com:facebookresearch/vggt-omega.git
cd vggt-omega
pip install -r requirements.txt
pip install -e .

requirements.txt 非常干净,没有乱七八糟的依赖。这一点必须夸------很多学术代码的环境配置噩梦在这完全不存在。

模型下载

模型在 Hugging Face 上,需要申请访问权限。我凌晨申请的,早上起来就批了。两个 checkpoint:

  • VGGT-Omega-1B-512:主模型,分辨率 512,10 亿参数
  • VGGT-Omega-1B-256-Text-Alignment:图文对齐版本,分辨率 256

论文提到作者不参与审批流程,纯粹是自动化审查。所以只要认真填表都能过。

跑起来的真实感受

用官方示例代码跑了一下,三张随手拍的办公桌照片:

python 复制代码
import torch
from vggt_omega.models import VGGTOmega
from vggt_omega.utils.load_fn import load_and_preprocess_images
from vggt_omega.utils.pose_enc import encoding_to_camera

# 加载模型
model = VGGTOmega().to("cuda").eval()
model.load_state_dict(
    torch.load("checkpoints/vggt_omega_1b_512.pt", map_location="cpu")
)

# 加载图片
image_names = ["desk_01.jpg", "desk_02.jpg", "desk_03.jpg"]
images = load_and_preprocess_images(
    image_names, image_resolution=512
).to("cuda")

# 推理
with torch.inference_mode():
    predictions = model(images)

# 解码相机位姿
extrinsics, intrinsics = encoding_to_camera(
    predictions["pose_enc"],
    predictions["images"].shape[-2:],
)

depth = predictions["depth"]           # 深度图
depth_conf = predictions["depth_conf"]  # 置信度
registers = predictions["camera_and_register_tokens"][:, :, 1:]  # register

在 4090 上,三张图片推理不到 2 秒。相机位姿肉眼看起来挺合理,深度图在物体边缘有点模糊但整体结构是对的。

几个意外的发现:

  1. 置信度图很有用------深度不靠谱的地方(比如反光桌面),conf 直接降到 0.3 以下,这个设计很实用
  2. register tokens 确实压缩了信息------64 个 register × 1024 维 = 6.5 万 float,而原始 patch tokens 有十几万个,信息压缩比超过 100:1
  3. Gradio demo 支持视频输入------扔了一段 30 秒的走路视频进去,相机轨迹追踪相当流畅

踩坑记录

也不是一帆风顺。遇到两个小坑:

坑 1:图像分辨率不一致会报错

load_and_preprocess_images 函数默认用 mode="balanced" 做 resize,但如果你手动预处理了图片且分辨率不一,会直接崩。解决方案很简单------要么全部交给它处理,要么统一用 mode="max_size"

python 复制代码
images = load_and_preprocess_images(
    image_names, 
    image_resolution=512,
    mode="max_size"  # 最长边缩到 512
).to("cuda")

坑 2:Gradio demo 的 checkpoint 路径

官方 README 里 demo 命令的 checkpoint 路径写得比较隐晦。实际需要下载完整 checkpoint 文件(不是只下载 .pt 文件,还需要配套的 config)。直接从 HF 仓库 clone 整个 checkpoint 目录最稳:

bash 复制代码
# 不要手动下载单个 .pt,用 git clone
git lfs install
git clone https://huggingface.co/facebook/VGGT-Omega checkpoints

Register 的下游潜力:不止于 3D 重建

论文最后一段提到一个非常前瞻的方向:register 可以迁移到其他任务

作者用 register tokens 作为视觉-语言-动作模型(VLA)的空间特征,发现了有趣的提升。实质上,register 学到的是场景的紧凑几何表示------"这张图里有什么、它们在哪、互相怎么排列",这正是所有需要空间理解的任务都需要的。

他们还搞了一个图文对齐版本,让 register 和文本 embedding 对齐。这意味着你以后可以用自然语言查询场景中的 3D 信息,类似"从桌子的正面看过去,杯子在键盘左边"这种描述可以被 register 精确编码。

我个人的判断是,register 这种"场景几何的 latent code"会成为一个基础组件,就像 ViT 的 CLS token 之于图像分类、BERT 的 CLS token 之于文本理解一样。

性能数字:77% 的提升是怎么做到的?

论文里最震撼的数字:在 Sintel 数据集上,相机位姿估计精度比之前最好的方法提升了 77%

Sintel 是一个动画电影数据集,场景复杂、运动剧烈,一直是 3D 重建的噩梦。VGGT-Ω 在这种极端场景下碾压了所有前代方法。

具体来说,他们评测了三个维度:

  • 相机平移误差(ATE):下降了约 60%
  • 相机旋转误差(ARE):下降了约 55%
  • 深度估计精度(AbsRel):下降了约 40%

我很好奇这个提升到底来自哪个改进。论文的消融实验给了答案:

  1. 去掉 register attention,换回全局注意力 → 精度掉 15%,显存涨 2.3 倍
  2. 去掉自监督学习 → 精度掉 22%
  3. 去掉多任务统一预测头 → 精度掉 8%,训练速度降 40%

最大的贡献来自自监督学习,其次是 register 机制。架构优化更多是"让 scaling 成为可能"的前提条件。

一点个人想法

读完论文和代码,有几点感受:

好的地方:

VGGT-Ω 不是那种"堆更大模型、喂更多数据,然后说我们 SOTA 了"的论文。它的每一项改进都有清晰的 motivation------为什么这么做、不这么做会有什么问题、做了之后带来了什么收益。特别是 register attention 的设计,把复杂度从 O(F²P²) 降到 O(FP² + F²R²),这个贡献有普适性,不限于 3D 重建。

可以更好的地方:

自监督训练那部分,论文里的描述比较高层,具体实现细节在附录里而且有一些 magic number 没解释清楚。比如置信度阈值为什么是 0.7 而不是 0.5 或 0.9------这种经验值的选取依据如果能展开说说会更好。

另外,动态场景的评测目前只在 Sintel 上做了,缺少真实世界动态场景(比如 DAVIS、TAP-Vid)的基准测试。动态场景重建可是比静态难一个数量级。

最值得关注的趋势:

VGGT-Ω 代表了一个大方向:从专门的 3D 重建工具变成通用的空间理解基础设施。尤其是 register 和 text alignment 的结合,让我联想到 CLIP 刚出来时的感觉------一个预训练模型不只是做分类,而是变成了多模态理解的 backbone。

我个人最期待的是 register 在具身智能(embodied AI)上的应用。如果机器人能用 register 理解"杯子在桌子的左前角",那操作精度会提升一个层次。

总结

VGGT-Ω 用三个核心改进(register 机制、统一预测头、自监督学习)实现了:

  • 训练显存降至原来的 30%
  • 训练数据量扩大到原来的 15 倍
  • Sintel 相机位姿精度提升 77%
  • 一张 A100 能处理 200 帧
  • 支持动态场景和图文对齐

这可能是 2026 年最值得关注的 3D 视觉基础模型。

相关资源

参考文献

  1. Wang J, Chen M, Zhang S, et al. VGGT-Ω. arXiv:2605.15195, 2026.
  2. Dosovitskiy A, et al. An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale. ICLR 2021.
  3. Schonberger J L, Frahm J M. Structure-from-Motion Revisited. CVPR 2016.
  4. Teed Z, Deng J. DROID-SLAM: Deep Visual SLAM for Monocular, Stereo, and RGB-D Cameras. NeurIPS 2021.
相关推荐
Muyuan19987 小时前
31.Cursor 初体验:用 AI Agent 给 PaperPilot 做一次最小工程重构
人工智能·python·重构·django·fastapi·faiss
IT策士7 小时前
Django 从 0 到 1 打造完整电商平台:电商项目需求分析与数据库设计
数据库·django·需求分析
creaDelight11 小时前
Django 中间件钩子函数 & CBV vs FBV 实战验证
python·中间件·django
En^_^Joy1 天前
Django模型:数据库操作全指南
数据库·django·sqlite
__log2 天前
ComfyUI 集成技术方案分析报告
javascript·python·django
俊哥工具3 天前
鼠标自动连点怎么设置?详细教学,简单易懂!
python·django·pdf·计算机外设·virtualenv·pygame
源码之家4 天前
计算机毕业设计:Pyhon健康数据分析系统 Django框架 数据分析 可视化 身体数据分析 大数据(建议收藏)✅
大数据·python·数据挖掘·数据分析·django·lstm·课程设计
vx_biyesheji00044 天前
计算机毕业设计:Python医疗数据分析平台 Flask框架 数据分析 可视化 医疗大数据 用户画像(建议收藏)✅
大数据·python·深度学习·数据分析·django·flask·课程设计
源码之家4 天前
计算机毕业设计:Python医疗数据可视化系统 Flask框架 数据分析 可视化 医疗大数据 用户画像(建议收藏)✅
python·深度学习·信息可视化·数据分析·django·flask·课程设计