利用maskrcnn来实现目标检测与追踪

首先下载源代码仓库,链接地址如下:

maskrcnn

能够实现的效果如图所示:

该存储库包括:

  • 基于FPN和ResNet101构建的Mask R-CNN的源代码。
  • MS COCO 的训练代码
  • MS COCO 的预训练砝码
  • Jupyter 笔记本,用于可视化每一步的检测管道
  • 用于多 GPU 训练的并行模型类
  • 对 MS COCO 指标 (AP) 的评估
  • 在自己的数据集上进行训练的示例

下载代码仓库,进行解压后的目录如下:

可以使用下面:

复制代码
pip install -r requirements.txt -i https://pypi.tuna.tsinghua.edu.cn/simple

也可以使用

复制代码
python setup.py install

来安装相关的依赖包,安装完成后,还需要下载模型文件,

下载链接地址如下:

mask_rcnn_balloon.h5

测试代码如下所示:

python 复制代码
import os
import sys
import random
import math
import numpy as np
import skimage.io
import matplotlib
import matplotlib.pyplot as plt

# Root directory of the project
ROOT_DIR = os.path.abspath("../")

# Import Mask RCNN
sys.path.append(ROOT_DIR)  # To find local version of the library
from mrcnn import utils
import mrcnn.model as modellib
from mrcnn import visualize
# Import COCO config
sys.path.append(os.path.join(ROOT_DIR, "samples/coco/"))  # To find local version
import coco

%matplotlib inline 

# Directory to save logs and trained model
MODEL_DIR = os.path.join(ROOT_DIR, "logs")

# Local path to trained weights file
COCO_MODEL_PATH = os.path.join(ROOT_DIR, "mask_rcnn_coco.h5")
# Download COCO trained weights from Releases if needed
if not os.path.exists(COCO_MODEL_PATH):
    utils.download_trained_weights(COCO_MODEL_PATH)

# Directory of images to run detection on
IMAGE_DIR = os.path.join(ROOT_DIR, "images")

class InferenceConfig(coco.CocoConfig):
    # Set batch size to 1 since we'll be running inference on
    # one image at a time. Batch size = GPU_COUNT * IMAGES_PER_GPU
    GPU_COUNT = 1
    IMAGES_PER_GPU = 1

config = InferenceConfig()
config.display()

# Create model object in inference mode.
model = modellib.MaskRCNN(mode="inference", model_dir=MODEL_DIR, config=config)

# Load weights trained on MS-COCO
model.load_weights(COCO_MODEL_PATH, by_name=True)

# COCO Class names
# Index of the class in the list is its ID. For example, to get ID of
# the teddy bear class, use: class_names.index('teddy bear')
class_names = ['BG', 'person', 'bicycle', 'car', 'motorcycle', 'airplane',
               'bus', 'train', 'truck', 'boat', 'traffic light',
               'fire hydrant', 'stop sign', 'parking meter', 'bench', 'bird',
               'cat', 'dog', 'horse', 'sheep', 'cow', 'elephant', 'bear',
               'zebra', 'giraffe', 'backpack', 'umbrella', 'handbag', 'tie',
               'suitcase', 'frisbee', 'skis', 'snowboard', 'sports ball',
               'kite', 'baseball bat', 'baseball glove', 'skateboard',
               'surfboard', 'tennis racket', 'bottle', 'wine glass', 'cup',
               'fork', 'knife', 'spoon', 'bowl', 'banana', 'apple',
               'sandwich', 'orange', 'broccoli', 'carrot', 'hot dog', 'pizza',
               'donut', 'cake', 'chair', 'couch', 'potted plant', 'bed',
               'dining table', 'toilet', 'tv', 'laptop', 'mouse', 'remote',
               'keyboard', 'cell phone', 'microwave', 'oven', 'toaster',
               'sink', 'refrigerator', 'book', 'clock', 'vase', 'scissors',
               'teddy bear', 'hair drier', 'toothbrush']
# Load a random image from the images folder
file_names = next(os.walk(IMAGE_DIR))[2]
image = skimage.io.imread(os.path.join(IMAGE_DIR, random.choice(file_names)))

# Run detection
results = model.detect([image], verbose=1)

# Visualize results
r = results[0]
visualize.display_instances(image, r['rois'], r['masks'], r['class_ids'], 
                            class_names, r['scores'])
相关推荐
学习中的数据喵5 分钟前
机器学习之逻辑回归
人工智能·机器学习·逻辑回归
kupeThinkPoem7 分钟前
vscode中continue插件介绍
人工智能
小殊小殊14 分钟前
【论文笔记】Video-RAG:开源视频理解模型也能媲美GPT-4o
人工智能·语音识别·论文笔记
人工智能训练20 分钟前
前端框架选型破局指南:Vue、React、Next.js 从差异到落地全解析
运维·javascript·人工智能·前端框架·vue·react·next.js
IT_陈寒36 分钟前
90%的Python开发者不知道:这5个内置函数让你的代码效率提升300%
前端·人工智能·后端
吴法刚36 分钟前
Gemini cli 源码分析之Chat-ContentGenerator生成式 AI 模型交互
人工智能·microsoft·ai·gemini·ai编码
拾零吖43 分钟前
CS336 Lecture_03
人工智能·pytorch·深度学习
斯文~43 分钟前
【AI论文速递】RAG-GUI:轻量VLM用SFT/RSF提升GUI性能
人工智能·ai·agent·rag·ai读论文·ai论文速递
盼小辉丶1 小时前
视觉Transformer实战 | Token-to-Token Vision Transformer(T2T-ViT)详解与实现
pytorch·深度学习·计算机视觉·transformer