使用sam进行零样本、零学习的分割实践

参照:利用SAM实现自动标注_sam标注-CSDN博客,以及SAM(分割一切模型)的简单调用_sam使用-CSDN博客

sam简介:

Segment Anything Model(SAM)是Meta公司于2023年发布的一种AI模型,它打破了现有的分割技术,通过解决任意输入图像的分割问题,开启了发现世界中对象和区域可能性的新时代。 SAM对人类数字世界和整个视觉应用生态系统有着深远的影响,推动软件与用户的互动方式达到了新的高度。

Segment Anything Model的主要特点包括:

**1. 卓越的准确性:**在最为复杂的COCO panoptic基准上,SAM以超过56%的mask quality领先于其他方法。模型的准确性与之前的最佳方法相比,提高了30%,并且比亚军模型高出10个百分点。

**2. 即插即用:**模型以一种即插即用的方式提供,不需要重新训练或微调,同时能够泛化到之前未见过的图像和任务。

**3. 遮罩调优能力:**用户可以使用单个点击(称为点提示)来校正或微调模型的初始分割结果。

**4. 多模式提示:**模型能够采用来自多种模式的输入,包括复选框、排除框和标记点,用于优化分割。

**5. 灵活性:**支持处理各种任务,从细粒度的边界到粗糙的分割。

**6. 速度:**仅需约0.3秒/帧的速度,在快速和准确之间取得了平衡,与现有的交互式分割方法相比,显著减少了时间需求。

**7. 强大的泛化能力:**SAM 在不同的数据集和应用场景下都表现出了良好的泛化能力,不仅可以处理各种类型的图像,对于一些它未见过或相对模糊的场景,同样能实现较好的图像分割效果。

总的来说,SAM提供了一种高效的分割解决方案,具有高度的灵活性和准确性,同时不需要像传统的分割方法那样进行大量的微调和训练。

本文通过一个简单的图片分割实践来体验一下sam在非标图片的零样本、零训练下的分割能力。

一、前期准备

1、下载或克隆项目,打开下面的网址,可以克隆或者下载zip文件,这里选择了下载zip文件:GitCode - 全球开发者的开源社区,开源代码托管平台

得到了文件:segment-anything-main.zip

在这个网页中,还可以看到环境的要求:

2、按照上面的要求在conda虚拟环境下安装pytorch和torchvision。这里要注意,虽然SAM没有直接对CUDA版本提出要求, 但是pytorch和CUDA的版本是紧密相关的。

选择了:

pytorch=1.7.1

torchvision=0.8.0

python=3.8

cuda=11.0

cudnn=8.0.5

具体的版本选择方法见:人工智能学习用的电脑安装cuda、torch、conda等软件,版本的选择以及多版本切换_cuda版本要求-CSDN博客

如果使用在线的方式:Previous PyTorch Versions | PyTorch

根据选定的版本下载和安装相关的torch、cuda、conda等工具和软件。

3、安装必要的依赖包

pip install opencv-python pycocotools matplotlib onnxruntime onnx

4、下载SAM-TOOL:GitCode - 全球开发者的开源社区,开源代码托管平台 ,得到文件SAM-Tool-main.zip。

5、 下载模型:https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth

得到文件:sam_vit_h_4b8939.pth

6、安装:将segment-anything-main.zip解压缩,进入SAM所在的文件夹目录后,在目录下鼠标右键-->"打开于终端",然后:

pip install -e .

二、牛刀小试

1、安装完成后,在segment-anything-main目录下新建目录imgs用于存放图片,新建models目录用以存放模型权重,并新建demo.py,作为测试的脚本。

2、将前面下载得到的阶段权重模型sam_vit_h_4b8939.pth,复制到models目录。

3、将测试用的图片复制到imgs目录。图片是一个工业产品的微观图,特征如下:

图片的大小为1200*850像素,将(330,390)与(480,550)框定的范围作为采样的样本。

4、demo.py的代码内容:

import torch
from segment_anything import SamPredictor, sam_model_registry
import cv2
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image

# 加载模型
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
sam_checkpoint = "models/sam_vit_h_4b8939.pth"  # 这里填写模型权重路径
model_type = "vit_h"   # 模型类型,可选vit_h, vit_l, vit_b
sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)   # 加载模型
sam.to(device=device)   # 将模型加载到GPU或CPU

# 初始化SAM预测器
predictor = SamPredictor(sam)

# 读取图片
image_path = "/home/dy/program/segment-anything-main/imgs/IMG_PP.jpg"  # 图片路径
image = cv2.imread(image_path)
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)   # 将图片转换为RGB格式

# 设置输入图片
predictor.set_image(image)

# 指定一些提示点,指导SAM识别特定对象
input_point = np.array([[330, 390], [480, 550]])  # 将列表转换为NumPy数组
input_label = np.array([1, 1])  # 也将标签转换为NumPy数组

# 获得掩膜
masks, scores, logits = predictor.predict(   # 返回掩膜、置信度、类别概率
    point_coords=input_point,   # 指定提示点
    point_labels=input_label,   # 指定提示点对应的标签
    multimask_output=True,     # 是否输出多类别掩膜
)

print("置信度:", scores,
      "类别概率:", logits)   # 显示置信度, 类别概率


def show_masks(masks):   # 显示掩膜
    for i, mask in enumerate(masks):
        mask_image = (mask * 255).astype("uint8")    # 将掩膜转换为8位无符号整数
        mask = Image.fromarray(mask_image)     # 将掩膜转换为PIL图像
        mask = mask.convert("L")    # 将掩膜转换为灰度图像
        plt.imshow(mask, cmap='jet', alpha=0.5)    # 显示掩膜


# 显示结果
plt.figure(figsize=(10, 10))     #  设置画布大小
plt.imshow(image)     # 显示原图
show_masks(masks)    # 显示掩膜
plt.axis('off')      #  隐藏坐标轴
plt.show()    # 显示图像

5、运行脚本,输出结果:

看得出,针对一个零样本、零训练的非标图片,sam的表现还是非常优异的。

相关推荐
__雨夜星辰__8 分钟前
Linux 学习笔记__Day2
linux·服务器·笔记·学习·centos 7
远洋录10 分钟前
构建一个数据分析Agent:提升分析效率的实践
人工智能·ai·ai agent
学问小小谢11 分钟前
第26节课:内容安全策略(CSP)—构建安全网页的防御盾
运维·服务器·前端·网络·学习·安全
IT古董1 小时前
【深度学习】常见模型-Transformer模型
人工智能·深度学习·transformer
沐雪架构师2 小时前
AI大模型开发原理篇-2:语言模型雏形之词袋模型
人工智能·语言模型·自然语言处理
摸鱼仙人~3 小时前
Attention Free Transformer (AFT)-2020论文笔记
论文阅读·深度学习·transformer
python算法(魔法师版)3 小时前
深度学习深度解析:从基础到前沿
人工智能·深度学习
kakaZhui3 小时前
【llm对话系统】大模型源码分析之 LLaMA 位置编码 RoPE
人工智能·深度学习·chatgpt·aigc·llama
charlie1145141914 小时前
从0开始使用面对对象C语言搭建一个基于OLED的图形显示框架(协议层封装)
c语言·驱动开发·单片机·学习·教程·oled
struggle20254 小时前
一个开源 GenBI AI 本地代理(确保本地数据安全),使数据驱动型团队能够与其数据进行互动,生成文本到 SQL、图表、电子表格、报告和 BI
人工智能·深度学习·目标检测·语言模型·自然语言处理·数据挖掘·集成学习