计算机视觉|超详细!Meta视觉大模型Segment Anything(SAM)源码解剖

一、引言

在计算机视觉领域,图像分割是一个核心且具有挑战性的任务,旨在将图像中的不同物体或区域进行划分和识别,广泛应用于自动驾驶、医学影像分析、安防监控等领域。Segment Anything Model(SAM)由 Meta AI 实验室发布,其引入了基于 Prompt 的交互式分割能力,显著提升了图像分割的灵活性和泛化能力。

SAM 通过在海量且多样化的数据集上训练,具备处理未见过对象类别和场景的能力。这一特性使其在学术界引发广泛研究,如模型轻量化、领域适应、多模态融合等方向;在工业界也迅速应用于医学图像分析中的肿瘤检测、遥感图像处理中的卫星图像分析以及视频处理中的目标跟踪等领域。

对于计算机视觉爱好者和开发者而言,深入剖析 SAM 的源码有助于理解其核心技术原理,为进一步创新和应用提供基础。本文将从原理到源码逐步解析 SAM 的工作机制,探索其实现高效图像编码、提示编码及掩码解码的过程。

二、SAM 原理速览

(一)核心概念

SAM 的核心在于其 "分割一切" 的能力,这一能力依赖于基于 Prompt 的分割策略。Prompt 可以是点(points)、框(boxes)、掩码(masks)或文本(text),为模型提供分割目标的关键信息。例如,点击图像中的一个点,SAM 能够识别并分割该点所在物体;给定一个框,SAM 则专注于框内物体的分割。

(二)模型架构

SAM 的架构包含三个核心组件:Image Encoder(图像编码器)Prompt Encoder(提示编码器)Mask Decoder(掩码解码器) ,其协作方式如下图所示:

  • Image Encoder :将输入图像转换为高维特征表示,通常使用预训练的 Vision Transformer(ViT),如 ViT-H、ViT-L、ViT-B。例如,对于分辨率为 1024 × 1024 1024 \times 1024 1024×1024 的图像,经过 Patch Embedding 操作划分为 16 × 16 16 \times 16 16×16 的 patches,特征图尺寸缩小为原来的 1 16 \frac{1}{16} 161,通道数从 3 映射到 768。
  • Prompt Encoder:处理不同类型的 Prompt,将其编码为与图像嵌入兼容的特征。对点和框使用位置编码(Positional Encoding);对文本使用 CLIP 文本编码器;对掩码则通过轻量级卷积网络编码。
  • Mask Decoder:融合图像嵌入和提示嵌入,通过 Transformer 结构解码为分割掩码,默认生成 3 个候选掩码并按置信度排序。

三、源码结构总览

(一)代码目录解析

SAM 的代码目录结构如下:

复制代码
segment-anything/
├── assets              # 示例图片等资源
├── demo                # 前端部署代码
├── notebooks           # Jupyter Notebook 示例
├── script              # 模型导出脚本
├── segment_anything    # 核心代码目录
│   ├── build_sam.py    # 模型构建脚本
│   ├── config.py       # 配置文件
│   ├── mask_decoder.py # 掩码解码器实现
│   ├── model_registry.py # 模型注册模块
│   ├── predictor.py    # 预测接口
│   ├── sam.py          # SAM 整体结构
│   ├── sam_arch.py     # 架构细节
│   ├── utils.py        # 工具函数
│   └── automatic_mask_generator.py # 自动掩码生成
└── setup.py            # 安装脚本

segment_anything 是核心目录,后续分析将聚焦于此。

(二)关键文件与模块

  • build_sam.py:定义模型构建函数,支持不同版本(如 vit_h、vit_l、vit_b)。示例代码:
python 复制代码
def build_sam_vit_h(checkpoint=None):
      # 构建 vit_h 版本的 SAM 模型
      return _build_sam(
          encoder_embed_dim=1280,           # 编码器嵌入维度
          encoder_depth=32,                 # 编码器层数
          encoder_num_heads=16,             # 注意力头数
          encoder_global_attn_indexes=[7, 15, 23, 31],  # 全局注意力层索引
          checkpoint=checkpoint,            # 预训练权重文件路径
      )
  • predictor.py :提供预测接口,set_image 处理图像预处理,predict 根据提示生成掩码。核心代码:
python 复制代码
def set_image(self, image: np.ndarray) -> None:
   # 检查输入图像维度和通道数是否符合要求
   if image.ndim != 3 or image.shape[2] not in [3, 4]:
       raise ValueError("Image must be 3D with 3 or 4 channels")
   # 应用图像变换(如缩放、归一化)
   input_image = self.transform.apply_image(image)
   # 转换为 PyTorch 张量并调整维度为 [1, C, H, W]
   input_image_torch = torch.as_tensor(input_image, device=self.device)
   self.set_torch_image(input_image_torch.permute(2, 0, 1)[None, :, :, :], image.shape[:2])
  • automatic_mask_generator.py:自动生成所有物体掩码,基于点提示网格。核心代码:
python 复制代码
def generate(self, image: np.ndarray) -> List[Dict[str, Any]]:
   # 预处理输入图像
   input_image = self.model.preprocess(image)
   # 使用无梯度计算图像嵌入
   with torch.no_grad():
       image_embedding = self.model.image_encoder(input_image)
   # 生成点提示网格
   points = self._generate_points(image.shape[:2])
   all_masks, all_scores = [], []
   # 按批次处理点提示并预测掩码
   for i in range(0, len(points), self.points_per_batch):
       batch_points = points[i:i + self.points_per_batch]
       batch_masks, batch_scores = self._predict_masks(image_embedding, batch_points)
       all_masks.extend(batch_masks)
       all_scores.extend(batch_scores)
   # 返回掩码和对应置信度列表
   return [{"segmentation": m, "score": s} for m, s in zip(all_masks, all_scores)]

四、核心代码深度剖析

(一)Image Encoder 源码解析

  • Patch Embedding:将图像划分为 patches 并映射为特征向量:
python 复制代码
class PatchEmbed(nn.Module):
   def __init__(self, kernel_size=(16, 16), stride=(16, 16), in_chans=3, embed_dim=768):
       # 初始化 Patch Embedding 模块
       super().__init__()
       # 定义卷积层,将图像划分为 patches 并映射到嵌入维度
       self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size, stride)
   
   def forward(self, x: torch.Tensor) -> torch.Tensor:
       # 对输入图像进行卷积操作,生成特征图
       x = self.proj(x)
       # 调整维度顺序为 [B, H/16, W/16, C],适配 Transformer 输入
       return x.permute(0, 2, 3, 1)

输入图像 [ B , 3 , H , W ] [B, 3, H, W] [B,3,H,W] 转换为 [ B , H 16 , W 16 , 768 ] [B, \frac{H}{16}, \frac{W}{16}, 768] [B,16H,16W,768]。

  • Transformer Encoder:堆叠多个 Transformer Block 提取特征:
python 复制代码
class Block(nn.Module):
    def __init__(self, dim, num_heads, mlp_ratio):
        # 初始化 Transformer Block
        super().__init__()
        # 第一层归一化
        self.norm1 = nn.LayerNorm(dim)
        # 多头注意力模块
        self.attn = Attention(dim, num_heads)
        # 第二层归一化
        self.norm2 = nn.LayerNorm(dim)
        # 多层感知机,隐藏层维度为 dim * mlp_ratio
        self.mlp = Mlp(dim, int(dim * mlp_ratio))
    
    def forward(self, x):
        # 自注意力计算并残差连接
        x = x + self.attn(self.norm1(x))
        # MLP 计算并残差连接
        x = x + self.mlp(self.norm2(x))
        return x

(二)Prompt Encoder 源码解析

编码不同类型提示:

python 复制代码
class PromptEncoder(nn.Module):
    def __init__(self, embed_dim, image_embedding_size, input_image_size):
        # 初始化 Prompt Encoder 模块
        super().__init__()
        # 位置编码层,使用随机高斯矩阵生成
        self.pe_layer = PositionEmbeddingRandom(embed_dim // 2)
        # 提示嵌入层,处理点和框提示
        self.prompt_embed_layer = PromptEmbedding(embed_dim, input_image_size, image_embedding_size)
    
    def forward(self, points=None, boxes=None, masks=None):
        # 初始化稀疏嵌入张量,维度为 [batch_size, 0, embed_dim]
        sparse_embed = torch.zeros(bs, 0, self.embed_dim, device=self.device)
        if points:
            # 提取点提示的坐标和标签
            coords, labels = points
            # 生成点嵌入并拼接到稀疏嵌入中
            sparse_embed = torch.cat([sparse_embed, self.prompt_embed_layer.point_embedding(coords, labels)], dim=1)
        # 返回稀疏嵌入和密集嵌入(未完全展示 masks 处理部分)
        return sparse_embed, dense_embed

五、实战演练

(一)环境搭建与配置

  1. 安装 Python:版本 ≥ 3.8。

  2. 安装 PyTorch

    bash 复制代码
    pip install torch==2.0.1+cu117 torchvision==0.15.2+cu117 -f https://download.pytorch.org/whl/cu117
  3. 安装 SAM

    bash 复制代码
    pip install -U "git+https://github.com/facebookresearch/segment-anything.git"

(二)代码运行与结果分析

示例代码:

python 复制代码
from segment_anything import sam_model_registry, SamPredictor
import cv2
import numpy as np
import matplotlib.pyplot as plt

# 加载 vit_b 版本的 SAM 模型并移动到 GPU
sam = sam_model_registry["vit_b"](checkpoint="sam_vit_b_01ec64.pth").to("cuda")
# 初始化预测器
predictor = SamPredictor(sam)
# 读取图像并转换为 RGB 格式
image = cv2.cvtColor(cv2.imread("test_image.jpg"), cv2.COLOR_BGR2RGB)
# 设置输入图像,进行预处理和特征提取
predictor.set_image(image)
# 定义点提示坐标和标签(1 表示前景)
masks, scores, _ = predictor.predict(point_coords=np.array([[500, 375]]), point_labels=np.array([1]), multimask_output=True)

# 遍历掩码和分数,显示结果
for i, (mask, score) in enumerate(zip(masks, scores)):
    plt.imshow(image)  # 显示原始图像
    plt.imshow(mask, alpha=0.6)  # 显示掩码,透明度为 0.6
    plt.title(f"Mask {i+1}, Score: {score:.3f}")  # 设置标题,显示掩码编号和置信度
    plt.show()  # 显示图像

结果分析 :SAM 根据点提示生成多个掩码,分数反映置信度。高分掩码通常更准确,低分掩码可能包含错误分割。

六、总结与展望

SAM 通过高效的图像编码、提示编码和掩码解码实现了灵活的图像分割。未来,其在医学影像、自动驾驶和视频处理中的应用潜力巨大。技术发展方向包括模型轻量化、多模态融合和领域适应,为计算机视觉带来更多可能性。


延伸阅读


相关推荐
cyong88819 分钟前
深度学习中的向量的样子-DCN
人工智能·深度学习
Python数据分析与机器学习28 分钟前
《基于深度学习的高分卫星图像配准模型研发与应用》开题报告
图像处理·人工智能·python·深度学习·神经网络·机器学习
BineHello2 小时前
强化学习 - PPO控制无人机
人工智能·算法·自动驾驶·动态规划·无人机·强化学习
牛不才2 小时前
ChatPromptTemplate的使用
人工智能·ai·语言模型·chatgpt·prompt·aigc·openai
从零开始学习人工智能2 小时前
深度学习模型压缩:非结构化剪枝与结构化剪枝的定义与对比
人工智能·深度学习·剪枝
訾博ZiBo2 小时前
AI日报 - 2025年3月18日
人工智能
新说一二3 小时前
AI技术学习笔记系列004:GPU常识
人工智能·笔记·学习
一个处女座的程序猿O(∩_∩)O3 小时前
人工智能中神经网络是如何进行预测的
人工智能·深度学习·神经网络
小白的高手之路3 小时前
如何安装旧版本的Pytorch
人工智能·pytorch·python·深度学习·机器学习·conda
虾球xz3 小时前
游戏引擎学习第161天
人工智能·学习·游戏引擎