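SAM2 (Segment Anything Model 2) is an advanced image and video segmentation model developed by Meta and the successor to the Segment Anything Model (SAM). Compared with the first generation, SAM2 brings significant improvements in several areas, including: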
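- Video segmentation support: SAM2 extends segmentation from images to video, so it can handle objects in video and not only in static images.
- Real-time processing of arbitrarily long videos: SAM2 can process videos of any length in real time, which is very useful in practice, especially where a fast response is needed.
- Zero-shot generalization: SAM2 can segment and track objects it has never seen before, which shows strong generalization.
- Better segmentation and tracking accuracy: compared with the first-generation model, SAM2 segments and tracks objects noticeably more accurately.
- Unified architecture: SAM2 combines image and video segmentation in a single model, which simplifies deployment and gives consistent performance across media types.
- Real-time performance: the model reaches real-time inference speed, about 44 frames per second as reported by Meta (actual speed depends on the model size and GPU), which suits applications that need immediate feedback, such as video editing and augmented reality.
- Interactive refinement: users can iteratively refine the segmentation by providing additional prompts, which gives precise control over the output.
- Handling of hard visual cases: SAM2 copes with common video segmentation challenges such as object occlusion and objects reappearing.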
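These features make SAM2 a powerful tool for a wide range of image and video segmentation tasks, especially where real-time processing and high accuracy are required. SAM2 is also open source, which has helped its wide adoption in research and industry.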
Full code download link: full code
In it, segment-anything-2-main is the original SAM2 code, and fine-tune-train_segment_anything_2_in_60_lines_of_code-main is the code for training on your own dataset.
Setting up the SAM2 environment
First, create a conda environment:
```bash
conda create --name SAM python=3.11
```
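After creating the environment, activate it (conda activate SAM) and install PyTorch. The torch and torchvision versions must satisfy "torch>=2.3.1" and "torchvision>=0.18.1", and the PyTorch build you install must match the CUDA version installed on your machine.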
My setup is an RTX 4070 GPU, Ubuntu 20.04 and CUDA 11.8, so the PyTorch install command I use is:
```bash
pip install torch==2.5.0 torchvision==0.20.0 torchaudio==2.5.0 --index-url https://download.pytorch.org/whl/cu118
```
This command comes from the official PyTorch website (pytorch.org).
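To confirm that the installed packages meet the minimum versions above, a small standard-library check like the following can help. It is a sketch: parse_version and meets_min_version are hypothetical helper names, and the comparison is a simplified numeric one (it drops local tags such as +cu118 and does not order pre-releases the way pip does).

```python
# Sketch: check installed torch/torchvision against SAM2's minimum versions.
# Standard library only; the version comparison is deliberately simplified.
from importlib import metadata


def parse_version(version: str) -> tuple[int, ...]:
    """Turn '2.5.0+cu118' into (2, 5, 0); the local '+cu118' tag is ignored."""
    public = version.split("+", 1)[0]
    parts = []
    for piece in public.split("."):
        digits = ""
        for ch in piece:  # keep the leading digits of each component
            if not ch.isdigit():
                break
            digits += ch
        parts.append(int(digits) if digits else 0)
    return tuple(parts)


def meets_min_version(installed: str, required: str) -> bool:
    """True if `installed` >= `required`, padding the shorter tuple with zeros."""
    a, b = parse_version(installed), parse_version(required)
    width = max(len(a), len(b))
    return a + (0,) * (width - len(a)) >= b + (0,) * (width - len(b))


REQUIREMENTS = {"torch": "2.3.1", "torchvision": "0.18.1"}

if __name__ == "__main__":
    for package, minimum in REQUIREMENTS.items():
        try:
            installed = metadata.version(package)
        except metadata.PackageNotFoundError:
            print(f"{package}: not installed")
            continue
        status = "OK" if meets_min_version(installed, minimum) else "too old"
        print(f"{package} {installed} (need >= {minimum}): {status}")
```

Run it inside the activated SAM environment. For the CUDA side, torch.version.cuda reports the CUDA version PyTorch was built with, and torch.cuda.is_available() tells you whether a GPU is usable.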
Next, install the packages SAM2 needs: download the code I provided, open a terminal in the segment-anything-2-main directory, and run the following command:
```bash
pip install --no-build-isolation -e .
```
Preparing the dataset
Prepare the SAM2 dataset by following this tutorial: Dataset preparation
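The training script builds its file list from data_dir/Train/Image and data_dir/Train/Instance, pairing each image with a PNG of the same base name. A quick pre-flight check along those lines can report images without an annotation before training starts. This is a sketch: pair_dataset is a hypothetical helper, not part of the provided code, and unlike the script's name[:-4] it uses os.path.splitext, so extensions such as .jpeg also work.

```python
# Sketch: pair images with annotations using the same folder layout as TRAIN.py
# (Train/Image/<stem>.<ext> -> Train/Instance/<stem>.png) and list missing ones.
import os


def pair_dataset(data_dir: str) -> tuple[list[dict], list[str]]:
    image_dir = os.path.join(data_dir, "Train", "Image")
    ann_dir = os.path.join(data_dir, "Train", "Instance")
    pairs, missing = [], []
    for name in sorted(os.listdir(image_dir)):
        stem = os.path.splitext(name)[0]  # file name without its extension
        ann_path = os.path.join(ann_dir, stem + ".png")
        if os.path.isfile(ann_path):
            pairs.append({"image": os.path.join(image_dir, name), "annotation": ann_path})
        else:
            missing.append(name)  # image has no matching label PNG
    return pairs, missing
```

For example, pairs, missing = pair_dataset("VOC2007") should leave missing empty for a correctly prepared dataset.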
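Training on the dataset
Open the training code I provided, in the folder fine-tune-train_segment_anything_2_in_60_lines_of_code-main.
Run TRAIN.py. In this script you only need to set data_dir and the model name; here the tiny model is used for training.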
```python
sam2_checkpoint = "checkpoints/sam2_hiera_tiny.pt"  # path to the tiny model weights (the small variant is at https://dl.fbaipublicfiles.com/segment_anything_2/072824/sam2_hiera_small.pt)
model_cfg = "sam2_hiera_t.yaml"  # model config matching the tiny checkpoint
```
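The checkpoint and the config must describe the same model size, and a mismatched or misspelled config name only fails once build_sam2 runs. One way to keep them in sync in both TRAIN.py and TEST_Net.py is a small lookup table. This is a sketch: the mapping follows the original SAM2 release naming (sam2_hiera_tiny.pt with sam2_hiera_t.yaml, and so on), model_paths is a hypothetical helper, and the file names should be checked against your copy of the repository.

```python
# Sketch: one table that pairs each SAM2 checkpoint with its config file.
# Names follow the original SAM2 release; verify them in your checkout.
SAM2_VARIANTS = {
    "tiny": ("checkpoints/sam2_hiera_tiny.pt", "sam2_hiera_t.yaml"),
    "small": ("checkpoints/sam2_hiera_small.pt", "sam2_hiera_s.yaml"),
    "base_plus": ("checkpoints/sam2_hiera_base_plus.pt", "sam2_hiera_b+.yaml"),
    "large": ("checkpoints/sam2_hiera_large.pt", "sam2_hiera_l.yaml"),
}


def model_paths(variant: str) -> tuple[str, str]:
    """Return (sam2_checkpoint, model_cfg) for a model size."""
    try:
        return SAM2_VARIANTS[variant]
    except KeyError:
        raise ValueError(
            f"unknown variant {variant!r}; choose from {sorted(SAM2_VARIANTS)}"
        ) from None


sam2_checkpoint, model_cfg = model_paths("tiny")
```

Using the same call in both scripts means training and inference cannot silently disagree on the model size.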
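Note: the training loop saves the model to model.torch every 10,000 iterations (including iteration 0), overwriting the previous file.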
```python
import numpy as np
import torch
import cv2
import os
from sam2.build_sam import build_sam2
from sam2.sam2_image_predictor import SAM2ImagePredictor

# Read data
data_dir = r"VOC2007/"  # path to the dataset (expects Train/Image and Train/Instance)
data = []  # list of files in the dataset
for name in os.listdir(data_dir + "Train/Image/"):  # go over all images
    data.append({"image": data_dir + "Train/Image/" + name,
                 "annotation": data_dir + "Train/Instance/" + name[:-4] + ".png"})

def read_batch(data):  # read a random image and its annotation from the dataset
    # select image
    ent = data[np.random.randint(len(data))]  # choose a random entry
    Img = cv2.imread(ent["image"])[..., ::-1]  # OpenCV loads BGR, SAM expects RGB: [..., ::-1] converts BGR -> RGB
    ann_map = cv2.imread(ent["annotation"])  # read annotation
    # resize so that the longer side becomes 1024 pixels
    r = np.min([1024 / Img.shape[1], 1024 / Img.shape[0]])  # scaling factor
    Img = cv2.resize(Img, (int(Img.shape[1] * r), int(Img.shape[0] * r)))
    ann_map = cv2.resize(ann_map, (int(ann_map.shape[1] * r), int(ann_map.shape[0] * r)), interpolation=cv2.INTER_NEAREST)
    # merge vessel and material annotations (from the LabPics dataset; for a grayscale
    # label PNG the three channels are identical and this step changes nothing)
    mat_map = ann_map[:, :, 0]  # material annotation map
    ves_map = ann_map[:, :, 2]  # vessel annotation map
    mat_map[mat_map == 0] = ves_map[mat_map == 0] * (mat_map.max() + 1)  # merge maps
    # get binary masks and points
    inds = np.unique(mat_map)[1:]  # all label indices in the map (index 0 = background is skipped)
    points = []  # one sampled point per mask
    masks = []  # all binary masks
    for ind in inds:
        mask = (mat_map == ind).astype(np.uint8)  # binary mask for index ind
        masks.append(mask)
        coords = np.argwhere(mask > 0)  # all (y, x) coordinates inside the mask
        yx = np.array(coords[np.random.randint(len(coords))])  # choose a random point
        points.append([[yx[1], yx[0]]])  # store as (x, y)
    return Img, np.array(masks), np.array(points), np.ones([len(masks), 1])

# Load model
sam2_checkpoint = "checkpoints/sam2_hiera_tiny.pt"  # tiny model weights (the small variant is at https://dl.fbaipublicfiles.com/segment_anything_2/072824/sam2_hiera_small.pt)
model_cfg = "sam2_hiera_t.yaml"  # config matching the tiny checkpoint
sam2_model = build_sam2(model_cfg, sam2_checkpoint, device="cuda")  # load model
predictor = SAM2ImagePredictor(sam2_model)

# Set training parameters
predictor.model.sam_mask_decoder.train(True)  # enable training of the mask decoder
predictor.model.sam_prompt_encoder.train(True)  # enable training of the prompt encoder
# The main part of the net is the image encoder. With a strong GPU you can train it too:
#   predictor.model.image_encoder.train(True)
# In that case you also need to search the SAM2 code for "no_grad" and remove those calls
# (no_grad blocks gradient collection, which saves memory but prevents training).
optimizer = torch.optim.AdamW(params=predictor.model.parameters(), lr=1e-5, weight_decay=4e-5)
scaler = torch.amp.GradScaler("cuda")  # gradient scaling for mixed precision

# Training loop
mean_iou = 0  # running IoU average, set before the loop so a skipped first batch cannot leave it undefined
for itr in range(100000):
    with torch.autocast(device_type="cuda"):  # forward pass in mixed precision
        image, mask, input_point, input_label = read_batch(data)  # load data batch
        if mask.shape[0] == 0:
            continue  # ignore empty batches
        predictor.set_image(image)  # apply the SAM image encoder to the image
        # prompt encoding
        mask_input, unnorm_coords, labels, unnorm_box = predictor._prep_prompts(input_point, input_label, box=None, mask_logits=None, normalize_coords=True)
        sparse_embeddings, dense_embeddings = predictor.model.sam_prompt_encoder(points=(unnorm_coords, labels), boxes=None, masks=None)
        # mask decoder
        batched_mode = unnorm_coords.shape[0] > 1  # multi-object prediction
        high_res_features = [feat_level[-1].unsqueeze(0) for feat_level in predictor._features["high_res_feats"]]
        low_res_masks, prd_scores, _, _ = predictor.model.sam_mask_decoder(
            image_embeddings=predictor._features["image_embed"][-1].unsqueeze(0),
            image_pe=predictor.model.sam_prompt_encoder.get_dense_pe(),
            sparse_prompt_embeddings=sparse_embeddings,
            dense_prompt_embeddings=dense_embeddings,
            multimask_output=True,
            repeat_image=batched_mode,
            high_res_features=high_res_features,
        )
        prd_masks = predictor._transforms.postprocess_masks(low_res_masks, predictor._orig_hw[-1])  # upscale masks to the original image resolution
        # segmentation loss
        gt_mask = torch.tensor(mask.astype(np.float32)).cuda()
        prd_mask = torch.sigmoid(prd_masks[:, 0])  # turn the logit map into a probability map
        seg_loss = (-gt_mask * torch.log(prd_mask + 0.00001) - (1 - gt_mask) * torch.log((1 - prd_mask) + 0.00001)).mean()  # cross-entropy loss
        # score loss: the predicted score should match the mask's IoU
        inter = (gt_mask * (prd_mask > 0.5)).sum(1).sum(1)
        iou = inter / (gt_mask.sum(1).sum(1) + (prd_mask > 0.5).sum(1).sum(1) - inter)
        score_loss = torch.abs(prd_scores[:, 0] - iou).mean()
        loss = seg_loss + score_loss * 0.05  # mix losses
    # backpropagation (done outside the autocast block)
    predictor.model.zero_grad()  # clear old gradients
    scaler.scale(loss).backward()  # backpropagate the scaled loss
    scaler.step(optimizer)
    scaler.update()  # update the mixed-precision scale factor
    if itr % 10000 == 0:
        torch.save(predictor.model.state_dict(), "model.torch")
        print("save model")
    # display results
    mean_iou = mean_iou * 0.99 + 0.01 * np.mean(iou.cpu().detach().numpy())
    print("step:", itr, "Accuracy (IoU) =", mean_iou)
```
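The loss in the loop combines two terms: a per-pixel binary cross-entropy between the predicted probability map and the ground-truth mask, and an L1 term that trains the predicted score to match the mask's actual IoU, weighted by 0.05. The NumPy sketch below reproduces that arithmetic on plain arrays so it can be checked without a GPU. sam2_finetune_loss is a hypothetical name, and it assumes the first of the three multimask outputs has already been selected, as prd_masks[:, 0] and prd_scores[:, 0] do in the loop.

```python
# NumPy sketch of the fine-tuning loss: BCE + 0.05 * |score - IoU|.
# Shapes: gt_mask and prd_logits are (N, H, W), prd_scores is (N,).
import numpy as np


def sam2_finetune_loss(gt_mask, prd_logits, prd_scores, eps=1e-5, score_weight=0.05):
    prd_mask = 1.0 / (1.0 + np.exp(-prd_logits))  # sigmoid: logits -> probabilities
    seg_loss = (-gt_mask * np.log(prd_mask + eps)
                - (1 - gt_mask) * np.log(1 - prd_mask + eps)).mean()  # BCE, eps avoids log(0)
    hard = prd_mask > 0.5  # binarized prediction
    inter = (gt_mask * hard).sum(axis=(1, 2))
    union = gt_mask.sum(axis=(1, 2)) + hard.sum(axis=(1, 2)) - inter
    iou = inter / union  # per-mask IoU
    score_loss = np.abs(prd_scores - iou).mean()  # predicted score should equal IoU
    return seg_loss + score_weight * score_loss, iou
```

Because IoU uses the thresholded mask, it carries no gradient to the mask itself; it only serves as the target for the score head.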
Inference
Run TEST_Net.py after filling in the paths of the image and its label (mask).
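Why does inference need a label path?
The label is only used to pick random points that tell the model which region to segment. SAM needs a prompt: a point, a box, or the full "segment everything" mode; see the SAM paper if this is unclear.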
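Concretely, the prompt is a set of foreground clicks sampled from the label mask. Here is a sketch of that sampling with the shapes the test script passes to predictor.predict: point_coords of shape (N, 1, 2) in (x, y) order and point_labels of shape (N, 1), where 1 means foreground. sample_point_prompts is a hypothetical helper that mirrors get_points in TEST_Net.py. Note that np.argwhere returns (row, column), i.e. (y, x), so the coordinates are flipped.

```python
# Sketch: sample N single-point foreground prompts from a binary mask.
import numpy as np


def sample_point_prompts(mask, num_points, rng=None):
    """Return point_coords (num_points, 1, 2) as (x, y) and point_labels (num_points, 1)."""
    rng = np.random.default_rng() if rng is None else rng
    coords = np.argwhere(mask > 0)  # (row, col) == (y, x) of foreground pixels
    if len(coords) == 0:
        raise ValueError("mask has no foreground pixels to sample from")
    picks = coords[rng.integers(len(coords), size=num_points)]
    point_coords = picks[:, None, ::-1]  # flip to (x, y) and add the per-prompt axis
    point_labels = np.ones((num_points, 1), dtype=np.int64)  # 1 = foreground click
    return point_coords, point_labels
```

Each of the N prompts yields its own predicted mask, which is why the script later has to merge overlapping results.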
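Comment out the following line to run inference with the original pretrained weights; leave it in to use the fine-tuned weights saved in model.torch:
```python
predictor.model.load_state_dict(torch.load("model.torch"))  # comment out this line to use the original model
```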
```python
# Segment an image region with the fine-tuned model
# (see TRAIN.py for how the model is fine-tuned)
import numpy as np
import torch
import cv2
from sam2.build_sam import build_sam2
from sam2.sam2_image_predictor import SAM2ImagePredictor

# use bfloat16 for the entire script (memory efficient)
torch.autocast(device_type="cuda", dtype=torch.bfloat16).__enter__()

# Load image
image_path = r"/home/cdg/test/SAM2/fine-tune-train_segment_anything_2_in_60_lines_of_code-main/VOC2007/Train/Image/ISIC_0000000.jpg"  # path to image
mask_path = r"/home/cdg/test/SAM2/fine-tune-train_segment_anything_2_in_60_lines_of_code-main/VOC2007/Train/Instance/ISIC_0000000.png"  # path to mask; the mask defines the image region to segment

def read_image(image_path, mask_path):  # read and resize image and mask
    img = cv2.imread(image_path)[..., ::-1]  # read image as RGB
    mask = cv2.imread(mask_path, 0)  # grayscale mask of the region we want to segment
    # resize so that the longer side becomes 1024 pixels
    r = np.min([1024 / img.shape[1], 1024 / img.shape[0]])
    img = cv2.resize(img, (int(img.shape[1] * r), int(img.shape[0] * r)))
    mask = cv2.resize(mask, (int(mask.shape[1] * r), int(mask.shape[0] * r)), interpolation=cv2.INTER_NEAREST)
    return img, mask

image, mask = read_image(image_path, mask_path)
num_samples = 30  # number of points/segments to sample

def get_points(mask, num_points):  # sample points inside the input mask
    coords = np.argwhere(mask > 0)  # (y, x) of all foreground pixels
    points = []
    for i in range(num_points):
        yx = coords[np.random.randint(len(coords))]
        points.append([[yx[1], yx[0]]])  # store as (x, y)
    return np.array(points)

input_points = get_points(mask, num_samples)

# Load model (the pretrained checkpoint must already be downloaded)
sam2_checkpoint = "checkpoints/sam2_hiera_tiny.pt"  # or "sam2_hiera_large.pt"
model_cfg = "sam2_hiera_t.yaml"  # or "sam2_hiera_l.yaml"; must match the checkpoint
sam2_model = build_sam2(model_cfg, sam2_checkpoint, device="cuda")

# Build net and load weights
predictor = SAM2ImagePredictor(sam2_model)
predictor.model.load_state_dict(torch.load("model.torch"))  # comment out this line to use the original model

# Predict masks
with torch.no_grad():
    predictor.set_image(image)
    masks, scores, logits = predictor.predict(
        point_coords=input_points,
        point_labels=np.ones([input_points.shape[0], 1])
    )

# Sort predicted masks from high to low score
masks = masks[:, 0].astype(bool)
sorted_masks = masks[np.argsort(scores[:, 0])][::-1].astype(bool)

# Stitch predicted masks into one segmentation map
seg_map = np.zeros_like(sorted_masks[0], dtype=np.uint8)
occupancy_mask = np.zeros_like(sorted_masks[0], dtype=bool)
for i in range(sorted_masks.shape[0]):
    mask = sorted_masks[i]
    if mask.sum() == 0 or (mask * occupancy_mask).sum() / mask.sum() > 0.15:
        continue  # skip empty masks and masks mostly covered by higher-scoring ones
    mask[occupancy_mask] = 0
    seg_map[mask] = i + 1
    occupancy_mask[mask] = 1

# Create a colored annotation map: an empty RGB image, one random color per segment
rgb_image = np.zeros((seg_map.shape[0], seg_map.shape[1], 3), dtype=np.uint8)
for id_class in range(1, seg_map.max() + 1):
    rgb_image[seg_map == id_class] = [np.random.randint(255), np.random.randint(255), np.random.randint(255)]

# Save and display (OpenCV expects BGR, so convert the RGB input image first)
image_bgr = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
mix = (rgb_image / 2 + image_bgr / 2).astype(np.uint8)
cv2.imwrite("annotation.png", rgb_image)
cv2.imwrite("mix.png", mix)
cv2.imshow("annotation", rgb_image)
cv2.imshow("mix", mix)
cv2.imshow("image", image_bgr)
cv2.waitKey()
```
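The stitching step in the inference script resolves overlaps greedily: masks are visited from highest to lowest score, a mask is dropped if more than 15% of it is already claimed, and otherwise it only takes the pixels that are still free. The same rule as a standalone, testable function follows; stitch_masks is a hypothetical name, and it skips empty masks instead of dividing by zero.

```python
# Sketch: greedy, score-ordered merge of overlapping binary masks into one label map.
import numpy as np


def stitch_masks(masks, scores, max_overlap=0.15):
    """masks: (N, H, W) bool; scores: (N,). Returns a uint8 map where 0 = background."""
    order = np.argsort(scores)[::-1]  # highest score first
    seg_map = np.zeros(masks.shape[1:], dtype=np.uint8)
    occupied = np.zeros(masks.shape[1:], dtype=bool)
    for rank, idx in enumerate(order):
        mask = masks[idx].copy()
        area = mask.sum()
        if area == 0 or (mask & occupied).sum() / area > max_overlap:
            continue  # empty, or mostly covered by a better mask
        mask &= ~occupied  # keep only pixels not yet claimed
        seg_map[mask] = rank + 1  # label = position in the score ranking
        occupied |= mask
    return seg_map
```

Raising max_overlap keeps more overlapping predictions as separate segments; lowering it merges the output toward fewer, higher-scoring regions.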
Reference:
Train/fine-tune Segment Anything 2 (SAM 2) in 60 lines of code (CSDN blog)