sam2环境安装

仓库地址:

https://github.com/facebookresearch/sam2

Segment Anything Model 2(SAM 2)是一个面向图像和视频中可提示的可视分割的基础模型。我们将SAM扩展到视频,将图像视为单帧视频。模型设计采用简单的transformer架构,具有流式内存以实现实时视频处理。我们构建了一个模型闭环数据引擎,通过用户交互改进模型和数据,以收集我们的SA-V数据集,这是迄今为止最大的视频分割数据集。在我们在数据上训练的SAM 2提供了在广泛任务和视觉领域中的强大性能。

注意:

建议您通过Anaconda创建一个新的Python环境,然后按照https://pytorch.org/中的说明使用pip安装PyTorch 2.5.1(或更高版本)。如果您当前环境中PyTorch的版本低于2.5.1,上述安装命令将尝试使用pip将其升级到最新的PyTorch版本。

上述步骤需要使用nvcc编译器编译自定义CUDA内核。如果您的机器上还没有,请安装与您的PyTorch CUDA版本匹配的CUDA工具包。

如果您在安装过程中看到类似"在安装过程中构建SAM 2 CUDA扩展失败"的消息,您可以忽略它并仍然使用SAM 2(某些后处理功能可能有限,但在大多数情况下不会影响结果)。

环境配置:

如果已经有Anaconda 或者miniconda,以管理员身份运行Anaconda Prompt

可先创建一个python3.10或更新的虚拟环境:

conda create -n sam2gpu python=3.10

激活环境

activate sam2gpu

建议创建一个英文目录,例如d盘YoloTest下面有一个ForSam,

cd D:\YoloTest\ForSam

再输入d:

此时再按顺序分别输入以下命令,sma2项目就会被下载到ForSam文件夹中

git clone https://github.com/facebookresearch/sam2.git && cd sam2

pip install -e .

官网的示例代码直接运行可能得不出什么结果,以下代码是通过ai完善了一下,可运行

python 复制代码
import torch
from sam2.build_sam import build_sam2
from sam2.sam2_image_predictor import SAM2ImagePredictor
from PIL import Image

import numpy as np
import matplotlib.pyplot as plt

checkpoint = "./checkpoints/sam2.1_hiera_large.pt"
model_cfg = "configs/sam2.1/sam2.1_hiera_l.yaml"
predictor = SAM2ImagePredictor(build_sam2(model_cfg, checkpoint))

image_path = "truck.jpg"
image = Image.open(image_path)

with torch.inference_mode(), torch.autocast("cuda", dtype=torch.bfloat16):
    predictor.set_image(image) # 直接传入PIL Image对象
    # masks, _, _ = predictor.predict(multimask_output=False)
    masks, _, _ = predictor.predict(point_coords=[[500, 375]],
                            point_labels=[1],
                            box=None,
                            mask_input=None,
                            multimask_output=False)

# 将PIL图像转换为numpy数组
image_np = np.array(image)

# 可视化结果
plt.figure(figsize=(10, 10))
plt.imshow(image_np)
plt.imshow(masks[0], alpha=0.5)  # 将mask叠加在原图上
plt.title("Segmentation Result")
plt.axis("off")
plt.show()

# 或者保存结果
result = masks[0].astype(np.uint8) * 255  # 转换为0-255的格式
result_image = Image.fromarray(result)
result_image.save("segmentation_result.png")