计算机视觉|超详细!Meta视觉大模型Segment Anything(SAM)源码解剖

一、引言

在计算机视觉领域,图像分割是一个核心且具有挑战性的任务,旨在将图像中的不同物体或区域进行划分和识别,广泛应用于自动驾驶、医学影像分析、安防监控等领域。Segment Anything Model(SAM)由 Meta AI 实验室发布,其引入了基于 Prompt 的交互式分割能力,显著提升了图像分割的灵活性和泛化能力。

SAM 通过在海量且多样化的数据集上训练,具备处理未见过对象类别和场景的能力。这一特性使其在学术界引发广泛研究,如模型轻量化、领域适应、多模态融合等方向;在工业界也迅速应用于医学图像分析中的肿瘤检测、遥感图像处理中的卫星图像分析以及视频处理中的目标跟踪等领域。

对于计算机视觉爱好者和开发者而言,深入剖析 SAM 的源码有助于理解其核心技术原理,为进一步创新和应用提供基础。本文将从原理到源码逐步解析 SAM 的工作机制,探索其实现高效图像编码、提示编码及掩码解码的过程。

二、SAM 原理速览

(一)核心概念

SAM 的核心在于其 "分割一切" 的能力,这一能力依赖于基于 Prompt 的分割策略。Prompt 可以是点(points)、框(boxes)、掩码(masks)或文本(text),为模型提供分割目标的关键信息。例如,点击图像中的一个点,SAM 能够识别并分割该点所在物体;给定一个框,SAM 则专注于框内物体的分割。

(二)模型架构

SAM 的架构包含三个核心组件:Image Encoder(图像编码器)Prompt Encoder(提示编码器)Mask Decoder(掩码解码器) ,其协作方式如下图所示:

  • Image Encoder :将输入图像转换为高维特征表示,通常使用预训练的 Vision Transformer(ViT),如 ViT-H、ViT-L、ViT-B。例如,对于分辨率为 1024 × 1024 1024 \times 1024 1024×1024 的图像,经过 Patch Embedding 操作划分为 16 × 16 16 \times 16 16×16 的 patches,特征图尺寸缩小为原来的 1 16 \frac{1}{16} 161,通道数从 3 映射到 768。
  • Prompt Encoder:处理不同类型的 Prompt,将其编码为与图像嵌入兼容的特征。对点和框使用位置编码(Positional Encoding);对文本使用 CLIP 文本编码器;对掩码则通过轻量级卷积网络编码。
  • Mask Decoder:融合图像嵌入和提示嵌入,通过 Transformer 结构解码为分割掩码,默认生成 3 个候选掩码并按置信度排序。

三、源码结构总览

(一)代码目录解析

SAM 的代码目录结构如下:

复制代码
segment-anything/
├── assets              # 示例图片等资源
├── demo                # 前端部署代码
├── notebooks           # Jupyter Notebook 示例
├── script              # 模型导出脚本
├── segment_anything    # 核心代码目录
│   ├── build_sam.py    # 模型构建脚本
│   ├── config.py       # 配置文件
│   ├── mask_decoder.py # 掩码解码器实现
│   ├── model_registry.py # 模型注册模块
│   ├── predictor.py    # 预测接口
│   ├── sam.py          # SAM 整体结构
│   ├── sam_arch.py     # 架构细节
│   ├── utils.py        # 工具函数
│   └── automatic_mask_generator.py # 自动掩码生成
└── setup.py            # 安装脚本

segment_anything 是核心目录,后续分析将聚焦于此。

(二)关键文件与模块

  • build_sam.py:定义模型构建函数,支持不同版本(如 vit_h、vit_l、vit_b)。示例代码:
python 复制代码
def build_sam_vit_h(checkpoint=None):
      # 构建 vit_h 版本的 SAM 模型
      return _build_sam(
          encoder_embed_dim=1280,           # 编码器嵌入维度
          encoder_depth=32,                 # 编码器层数
          encoder_num_heads=16,             # 注意力头数
          encoder_global_attn_indexes=[7, 15, 23, 31],  # 全局注意力层索引
          checkpoint=checkpoint,            # 预训练权重文件路径
      )
  • predictor.py :提供预测接口,set_image 处理图像预处理,predict 根据提示生成掩码。核心代码:
python 复制代码
def set_image(self, image: np.ndarray) -> None:
   # 检查输入图像维度和通道数是否符合要求
   if image.ndim != 3 or image.shape[2] not in [3, 4]:
       raise ValueError("Image must be 3D with 3 or 4 channels")
   # 应用图像变换(如缩放、归一化)
   input_image = self.transform.apply_image(image)
   # 转换为 PyTorch 张量并调整维度为 [1, C, H, W]
   input_image_torch = torch.as_tensor(input_image, device=self.device)
   self.set_torch_image(input_image_torch.permute(2, 0, 1)[None, :, :, :], image.shape[:2])
  • automatic_mask_generator.py:自动生成所有物体掩码,基于点提示网格。核心代码:
python 复制代码
def generate(self, image: np.ndarray) -> List[Dict[str, Any]]:
   # 预处理输入图像
   input_image = self.model.preprocess(image)
   # 使用无梯度计算图像嵌入
   with torch.no_grad():
       image_embedding = self.model.image_encoder(input_image)
   # 生成点提示网格
   points = self._generate_points(image.shape[:2])
   all_masks, all_scores = [], []
   # 按批次处理点提示并预测掩码
   for i in range(0, len(points), self.points_per_batch):
       batch_points = points[i:i + self.points_per_batch]
       batch_masks, batch_scores = self._predict_masks(image_embedding, batch_points)
       all_masks.extend(batch_masks)
       all_scores.extend(batch_scores)
   # 返回掩码和对应置信度列表
   return [{"segmentation": m, "score": s} for m, s in zip(all_masks, all_scores)]

四、核心代码深度剖析

(一)Image Encoder 源码解析

  • Patch Embedding:将图像划分为 patches 并映射为特征向量:
python 复制代码
class PatchEmbed(nn.Module):
   def __init__(self, kernel_size=(16, 16), stride=(16, 16), in_chans=3, embed_dim=768):
       # 初始化 Patch Embedding 模块
       super().__init__()
       # 定义卷积层,将图像划分为 patches 并映射到嵌入维度
       self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size, stride)
   
   def forward(self, x: torch.Tensor) -> torch.Tensor:
       # 对输入图像进行卷积操作,生成特征图
       x = self.proj(x)
       # 调整维度顺序为 [B, H/16, W/16, C],适配 Transformer 输入
       return x.permute(0, 2, 3, 1)

输入图像 [ B , 3 , H , W ] [B, 3, H, W] [B,3,H,W] 转换为 [ B , H 16 , W 16 , 768 ] [B, \frac{H}{16}, \frac{W}{16}, 768] [B,16H,16W,768]。

  • Transformer Encoder:堆叠多个 Transformer Block 提取特征:
python 复制代码
class Block(nn.Module):
    def __init__(self, dim, num_heads, mlp_ratio):
        # 初始化 Transformer Block
        super().__init__()
        # 第一层归一化
        self.norm1 = nn.LayerNorm(dim)
        # 多头注意力模块
        self.attn = Attention(dim, num_heads)
        # 第二层归一化
        self.norm2 = nn.LayerNorm(dim)
        # 多层感知机,隐藏层维度为 dim * mlp_ratio
        self.mlp = Mlp(dim, int(dim * mlp_ratio))
    
    def forward(self, x):
        # 自注意力计算并残差连接
        x = x + self.attn(self.norm1(x))
        # MLP 计算并残差连接
        x = x + self.mlp(self.norm2(x))
        return x

(二)Prompt Encoder 源码解析

编码不同类型提示:

python 复制代码
class PromptEncoder(nn.Module):
    def __init__(self, embed_dim, image_embedding_size, input_image_size):
        # 初始化 Prompt Encoder 模块
        super().__init__()
        # 位置编码层,使用随机高斯矩阵生成
        self.pe_layer = PositionEmbeddingRandom(embed_dim // 2)
        # 提示嵌入层,处理点和框提示
        self.prompt_embed_layer = PromptEmbedding(embed_dim, input_image_size, image_embedding_size)
    
    def forward(self, points=None, boxes=None, masks=None):
        # 初始化稀疏嵌入张量,维度为 [batch_size, 0, embed_dim]
        sparse_embed = torch.zeros(bs, 0, self.embed_dim, device=self.device)
        if points:
            # 提取点提示的坐标和标签
            coords, labels = points
            # 生成点嵌入并拼接到稀疏嵌入中
            sparse_embed = torch.cat([sparse_embed, self.prompt_embed_layer.point_embedding(coords, labels)], dim=1)
        # 返回稀疏嵌入和密集嵌入(未完全展示 masks 处理部分)
        return sparse_embed, dense_embed

五、实战演练

(一)环境搭建与配置

  1. 安装 Python:版本 ≥ 3.8。

  2. 安装 PyTorch

    bash 复制代码
    pip install torch==2.0.1+cu117 torchvision==0.15.2+cu117 -f https://download.pytorch.org/whl/cu117
  3. 安装 SAM

    bash 复制代码
    pip install -U "git+https://github.com/facebookresearch/segment-anything.git"

(二)代码运行与结果分析

示例代码:

python 复制代码
from segment_anything import sam_model_registry, SamPredictor
import cv2
import numpy as np
import matplotlib.pyplot as plt

# 加载 vit_b 版本的 SAM 模型并移动到 GPU
sam = sam_model_registry["vit_b"](checkpoint="sam_vit_b_01ec64.pth").to("cuda")
# 初始化预测器
predictor = SamPredictor(sam)
# 读取图像并转换为 RGB 格式
image = cv2.cvtColor(cv2.imread("test_image.jpg"), cv2.COLOR_BGR2RGB)
# 设置输入图像,进行预处理和特征提取
predictor.set_image(image)
# 定义点提示坐标和标签(1 表示前景)
masks, scores, _ = predictor.predict(point_coords=np.array([[500, 375]]), point_labels=np.array([1]), multimask_output=True)

# 遍历掩码和分数,显示结果
for i, (mask, score) in enumerate(zip(masks, scores)):
    plt.imshow(image)  # 显示原始图像
    plt.imshow(mask, alpha=0.6)  # 显示掩码,透明度为 0.6
    plt.title(f"Mask {i+1}, Score: {score:.3f}")  # 设置标题,显示掩码编号和置信度
    plt.show()  # 显示图像

结果分析 :SAM 根据点提示生成多个掩码,分数反映置信度。高分掩码通常更准确,低分掩码可能包含错误分割。

六、总结与展望

SAM 通过高效的图像编码、提示编码和掩码解码实现了灵活的图像分割。未来,其在医学影像、自动驾驶和视频处理中的应用潜力巨大。技术发展方向包括模型轻量化、多模态融合和领域适应,为计算机视觉带来更多可能性。


延伸阅读


相关推荐
@小匠17 小时前
Read Frog:一款开源的 AI 驱动浏览器语言学习扩展
人工智能·学习
山间小僧19 小时前
「AI学习笔记」RNN
机器学习·aigc·ai编程
网教盟人才服务平台20 小时前
“方班预备班盾立方人才培养计划”正式启动!
大数据·人工智能
芯智工坊20 小时前
第15章 Mosquitto生产环境部署实践
人工智能·mqtt·开源
菜菜艾20 小时前
基于llama.cpp部署私有大模型
linux·运维·服务器·人工智能·ai·云计算·ai编程
TDengine (老段)21 小时前
TDengine IDMP 可视化 —— 分享
大数据·数据库·人工智能·时序数据库·tdengine·涛思数据·时序数据
小真zzz21 小时前
搜极星:第三方多平台中立GEO洞察专家全面解析
人工智能·搜索引擎·seo·geo·中立·第三方平台
GreenTea1 天前
从 Claw-Code 看 AI 驱动的大型项目开发:2 人 + 10 个自治 Agent 如何产出 48K 行 Rust 代码
前端·人工智能·后端
火山引擎开发者社区1 天前
秒级创建实例,火山引擎 Milvus Serverless 让 AI Agent 开发更快更省
人工智能