基于 Python 和 HuggingFace Transformers 进行目标检测

YOLO!如果你对机器学习感兴趣,这个术语应该听过。确实,You Only Look Once(你只需看一次)在过去几年中一直是目标检测的默认方式之一。在卷积神经网络取得进展的推动下,已经创建了许多版本的目标检测方法。然而,近来出现了一位竞争对手------即在计算机视觉中使用基于Transformer的模型,更具体地说是在目标检测中使用Transformers。

在今天的文章中,你将学到有关这种类型的Transformer模型的知识。你还将学会使用Python、默认的Transformer模型和HuggingFace Transformers库创建自己的目标检测流程。实际上,这将非常容易,让我们一起来看看吧!

阅读完本文后,你将会:

  • 了解目标检测可用于什么。

  • 了解在目标检测中使用Transformer模型的工作原理。

  • 已经使用Python和HuggingFace Transformers实现了基于Transformer模型的(图像)目标检测流程。

什么是目标检测?

看看你周围。很可能,你会看到很多东西------可能是计算机显示器、键盘和鼠标,或者当你在移动浏览器中浏览时,可能是智能手机。这些都是目标,是特定类别的实例。例如,在下面的图像中,我们看到了一个属于"人类"类的实例。我们还看到了许多属于"瓶子"类的实例。虽然类是一个蓝图,但对象是真实的东西,它具有许多独特的特征,同时因为共享特征而属于类的成员。

在图片和视频中,我们看到许多这样的对象。例如,当你制作交通视频时,很可能会看到许多"行人"、"汽车"和"自行车"等的实例。知道它们在图像中可以是非常有益的!为什么?因为你可以数它们,这只是一个例子。这使你能够对社区的拥挤程度有所了解。另一个例子是在繁忙地区检测停车位,让你可以停车。这就是目标检测的用途!

目标检测和Transformers

传统上,目标检测是使用卷积神经网络来执行的。通常,它们的架构专门针对目标检测进行了定制,因为它们以图像作为输入,并输出图像的边界框。如果你熟悉神经网络,你会知道在学习图像中的重要特征方面,ConvNets非常有用,并且它们在空间上是不变的------换句话说,学习的对象在图像中的位置或大小是无关紧要的。如果网络能够看到对象的特征并将其与特定类别关联起来,那么它就能够识别它。例如,可以将许多不同的猫识别为"猫"类的实例。

然而,最近,Transformer架构在深度学习领域,特别是在自然语言处理领域,引起了极大的关注。Transformers通过将输入编码为高维状态,然后将其解码回所需的输出来工作。通过巧妙地使用自注意力的概念,Transformers不仅学会检测特定模式,还学会将这些模式与其他模式关联起来。在上面的猫的例子中,举个例子,Transformers可以学会将猫与其特征性的斑点关联起来------比如沙发,只是一个想法。

如果Transformers可以用于图像分类,那么将它们用于目标检测就只是更进一步。Carion等人(2020)表明实际上可以使用基于Transformer的架构来进行这样的操作。在他们的工作"使用Transformers进行端到端目标检测"中,他们引入了Detection Transformer或DETR,我们将在今天创建目标检测流程中使用它。

它的工作原理如下,并且甚至不完全放弃了卷积神经网络:

  1. 使用卷积神经网络从输入图像中提取重要特征。这些特征被位置编码,就像语言Transformers中一样,以帮助神经网络学习这些特征在图像中的位置。

  2. 输入被扁平化,然后使用变压器编码器和注意力编码为中间状态。

  3. 变压器解码器的输入是这个状态和在训练过程中获得的一组学习到的对象查询。你可以将它们想象成问题,问:"这里是否有一个对象,因为在许多情况下我曾经见过一个?" 这将通过使用中间状态来回答。

  4. 实际上,解码器的输出是通过多个预测头进行的一组预测:每个查询一个。由于DETR中默认设置了100个查询,它只能在一张图像中预测100个对象,除非你进行其他配置。

Transformers如何用于目标检测

HuggingFace Transformers和ObjectDetectionPipeline

现在你了解了DETR的工作原理,是时候用它来创建一个实际的目标检测流程了!我们将使用HuggingFace Transformers,这是为了使NLP和计算机视觉Transformers易于使用而构建的。实际上,它非常容易使用,只需加载ObjectDetectionPipeline即可------默认情况下加载一个使用ResNet-50骨干训练的DETR Transformer的管道,用于生成图像特征。

ObjectDetectionPipeline可以轻松地作为管道实例进行初始化...换句话说,通过`pipeline("object-detection")`,我们将在下面的示例中看到这一点。当你没有提供其他输入时,根据GitHub(n.d.)的说法,这就是管道的初始化方式。

javascript 复制代码
"object-detection": {
        "impl": ObjectDetectionPipeline,
        "tf": (),
        "pt": (AutoModelForObjectDetection,) if is_torch_available() else (),
        "default": {"model": {"pt": "facebook/detr-resnet-50"}},
        "type": "image",
    },

毫不奇怪,使用了专为目标检测定制的ObjectDetectionPipeline实例。在HuggingFace Transformers的PyTorch版本中,为此目的使用了AutoModelForObjectDetection。有趣的是,在TensorFlow版本中,目前尚未提供此管道的实现... 还没呢?正如你所学到的,默认情况下使用了facebook/detr-resnet-50模型来提取图像特征:

DEtection TRansformer(DETR)模型是在COCO 2017目标检测数据集(118,000个带有注释的图像)上进行端到端训练的。它在Carion等人的论文"使用Transformers进行端到端目标

HuggingFace(n.d)

COCO数据集(Common Objects in Context)是用于目标检测模型的标准数据集之一,并且被用于训练此模型。不用担心,你显然也可以训练自己的基于DETR的模型。

重要提示!要使用ObjectDetectionPipeline,安装包含PyTorch图像模型的timm包非常重要。确保在尚未安装时运行以下命令:`pip install timm`。

使用Python实现简单的目标检测流程

现在让我们看看如何使用Python实现简单的目标检测解决方案。回想一下,你正在使用HuggingFace Transformers,它必须安装在你的系统上------如果还没有,请运行`pip install transformers`。

我还假设PyTorch,这些天深度学习中的主要库之一,已经安装。请记住,上面在调用`pipeline("object-detection")`时在后台将加载ObjectDetectionPipeline,而TensorFlow中没有这个实例,因此必须使用PyTorch。

以下是我们将在本文后面为其创建的目标检测流程的图像:

我们从导入开始:

javascript 复制代码
from transformers import pipeline
from PIL import Image, ImageDraw, ImageFont

显然,我们使用transformers,特别是它的管道表示。然后,我们还使用PIL,这是一个用于加载、可视化和处理图像的Python库。具体来说,我们使用第一个导入------加载图像的Image,用于绘制边界框和标签的ImageDraw,后者还需要ImageFont。说到这两者,接下来是加载字体(我们选择Arial)并初始化上面介绍的目标检测管道。

ini 复制代码
# Load font
font = ImageFont.truetype("arial.ttf", 40)
# Initialize the object detection pipeline
object_detector = pipeline("object-detection")

然后,我们创建一个名为`draw_bounding_box` 的定义,可以不出所料地用于绘制边界框。它将图像(im)、类别概率、边界框的坐标、此定义将用于的边界框列表中的边界框索引以及该列表的长度作为输入。

  • 首先,在图像顶部绘制实际边界框,表示为一个红色的`rounded_rectangle`边界框,具有较小的半径以确保边缘平滑。

  • 其次,在边界框的正上方绘制文本标签。

  • 最后,返回中间结果,以便我们可以在其上绘制下一个边界框和标签。

python 复制代码
# Draw bounding box definition
def draw_bounding_box(im, score, label, xmin, ymin, xmax, ymax, index, num_boxes):
 """ Draw a bounding box. """


 print(f"Drawing bounding box {index} of {num_boxes}...")


 # Draw the actual bounding box
 im_with_rectangle = ImageDraw.Draw(im)  
 im_with_rectangle.rounded_rectangle((xmin, ymin, xmax, ymax), outline = "red", width = 5, radius = 10)


 # Draw the label
 im_with_rectangle.text((xmin+35, ymin-25), label, fill="white", stroke_fill = "red", font = font)


 # Return the intermediate result
 return im_with_rectangle

接下来是核心部分------使用管道,然后根据其结果绘制边界框。步骤如下:

  • 图像,我们称之为`street.jpg`,它位于与Python脚本相同的目录中,将被打开并存储在一个`im` PIL对象中。我们简单地将其提供给已初始化的`object_detector`,这就足以使模型返回边界框!Transformers库会处理其余部分。

  • 我们将数据分配给一些变量,并遍历每个结果,绘制边界框。

  • 我们保存图像为`street_bboxes.jpg`。

sql 复制代码
# Open the image
with Image.open("street.jpg") as im:
    # Perform object detection
    bounding_boxes = object_detector(im)
    # Iteration elements
    num_boxes = len(bounding_boxes)
    index = 0
    # Draw bounding box for each result
    for bounding_box in bounding_boxes:
        # Get actual box
        box = bounding_box["box"]
        # Draw the bounding box
        im = draw_bounding_box(im, bounding_box["score"], bounding_box["label"],
                                box["xmin"], box["ymin"], box["xmax"], box["ymax"], index, num_boxes)
        # Increase index by one
        index += 1
    # Save image
    im.save("street_bboxes.jpg")
    # Done
    print("Done!")

使用不同模型进行目标检测

如果你创建了自己的模型,或者想要使用不同的模型,那么很容易使用它来替代基于ResNet-50的DETR Transformer。这将要求你将以下内容添加到导入中:

javascript 复制代码
from transformers import DetrFeatureExtractor, DetrForObjectDetection

然后,你可以初始化特征提取器和模型,并用它们初始化`object_detector`,而不是默认的一个。例如,如果你想要使用ResNet-101作为骨干,可以这样做:

ini 复制代码
# Initialize another model and feature extractor
feature_extractor = DetrFeatureExtractor.from_pretrained('facebook/detr-resnet-101')
model = DetrForObjectDetection.from_pretrained('facebook/detr-resnet-101')


# Initialize the object detection pipeline
object_detector = pipeline("object-detection", model = model, feature_extractor = feature_extractor)

结果

在我们的输入图像上运行目标检测流程后,我们得到了以下结果:

目标检测示例 --- 完整代码

python 复制代码
from transformers import pipeline
from PIL import Image, ImageDraw, ImageFont




# Load font
font = ImageFont.truetype("arial.ttf", 40)


# Initialize the object detection pipeline
object_detector = pipeline("object-detection")




# Draw bounding box definition
def draw_bounding_box(im, score, label, xmin, ymin, xmax, ymax, index, num_boxes):
 """ Draw a bounding box. """


 print(f"Drawing bounding box {index} of {num_boxes}...")


 # Draw the actual bounding box
 im_with_rectangle = ImageDraw.Draw(im)  
 im_with_rectangle.rounded_rectangle((xmin, ymin, xmax, ymax), outline = "red", width = 5, radius = 10)


 # Draw the label
 im_with_rectangle.text((xmin+35, ymin-25), label, fill="white", stroke_fill = "red", font = font)


 # Return the intermediate result
 return im




# Open the image
with Image.open("street.jpg") as im:


 # Perform object detection
 bounding_boxes = object_detector(im)


 # Iteration elements
 num_boxes = len(bounding_boxes)
 index = 0


 # Draw bounding box for each result
 for bounding_box in bounding_boxes:


  # Get actual box
  box = bounding_box["box"]


  # Draw the bounding box
  im = draw_bounding_box(im, bounding_box["score"], bounding_box["label"],\
   box["xmin"], box["ymin"], box["xmax"], box["ymax"], index, num_boxes)


  # Increase index by one
  index += 1


 # Save image
 im.save("street_bboxes.jpg")


 # Done
 print("Done!")

· END ·

HAPPY LIFE

本文仅供学习交流使用,如有侵权请联系作者删除

相关推荐
刀客1231 分钟前
python3+TensorFlow 2.x(四)反向传播
人工智能·python·tensorflow
SpikeKing7 分钟前
LLM - 大模型 ScallingLaws 的设计 100B 预训练方案(PLM) 教程(5)
人工智能·llm·预训练·scalinglaws·100b·deepnorm·egs
stevewongbuaa16 分钟前
一些烦人的go设置 goland
开发语言·后端·golang
小枫@码31 分钟前
免费GPU算力,不花钱部署DeepSeek-R1
人工智能·语言模型
liruiqiang0532 分钟前
机器学习 - 初学者需要弄懂的一些线性代数的概念
人工智能·线性代数·机器学习·线性回归
撸码到无法自拔35 分钟前
MATLAB中处理大数据的技巧与方法
大数据·开发语言·matlab
Icomi_36 分钟前
【外文原版书阅读】《机器学习前置知识》1.线性代数的重要性,初识向量以及向量加法
c语言·c++·人工智能·深度学习·神经网络·机器学习·计算机视觉
微学AI39 分钟前
GPU算力平台|在GPU算力平台部署可图大模型Kolors的应用实战教程
人工智能·大模型·llm·gpu算力
西猫雷婶41 分钟前
python学opencv|读取图像(四十六)使用cv2.bitwise_or()函数实现图像按位或运算
人工智能·opencv·计算机视觉
IT古董42 分钟前
【深度学习】常见模型-生成对抗网络(Generative Adversarial Network, GAN)
人工智能·深度学习·生成对抗网络