《使用 YOLOV8 和 KerasCV 进行高效目标检测》

《使用 YOLOV8 和 KerasCV 进行高效目标检测》

作者: Gitesh Chawda
创建日期: 2023/06/26
最后修改时间: 2023/06/26
**描述:**使用 KerasCV 训练自定义 YOLOV8 对象检测模型。
(i) 此示例使用 Keras 2

在 Colab 中查看

GitHub 源


介绍

KerasCV 是 Keras 的扩展,用于计算机视觉任务。在此示例中,我们将看到 如何使用 KerasCV 训练 YOLOV8 对象检测模型。

KerasCV 包括适用于常用计算机视觉数据集的预训练模型,例如 ImageNet、COCO 和 Pascal VOC,可用于迁移学习。KerasCV 还 提供了一系列用于检查中间表示的可视化工具 由模型学习,用于可视化对象检测和分割的结果 任务。

如果您有兴趣了解使用 KerasCV 进行对象检测,我强烈建议您 看看 Lukewood 创建的指南。此资源可在使用 KerasCV 进行对象检测中获得。 全面概述了基本概念和技术 使用 KerasCV 构建对象检测模型时需要。

复制代码
!pip` `install` `--upgrade` `git+https://github.com/keras-team/keras-cv` `-q`
`
复制代码
`[33mWARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv[0m[33m [0m `

设置

复制代码
import` `os`
`from` `tqdm.auto` `import` `tqdm`
`import` `xml.etree.ElementTree` `as` `ET`

`import` `tensorflow` `as` `tf`
`from` `tensorflow` `import` `keras`

`import` `keras_cv`
`from` `keras_cv` `import` `bounding_box`
`from` `keras_cv` `import` `visualization`
`
复制代码
`/opt/conda/lib/python3.10/site-packages/tensorflow_io/python/ops/__init__.py:98: UserWarning: unable to load libtensorflow_io_plugins.so: unable to open file: libtensorflow_io_plugins.so, from paths: ['/opt/conda/lib/python3.10/site-packages/tensorflow_io/python/ops/libtensorflow_io_plugins.so'] caused by: ['/opt/conda/lib/python3.10/site-packages/tensorflow_io/python/ops/libtensorflow_io_plugins.so: undefined symbol: _ZN3tsl6StatusC1EN10tensorflow5error4CodeESt17basic_string_viewIcSt11char_traitsIcEENS_14SourceLocationE'] warnings.warn(f"unable to load libtensorflow_io_plugins.so: {e}") /opt/conda/lib/python3.10/site-packages/tensorflow_io/python/ops/__init__.py:104: UserWarning: file system plugins are not loaded: unable to open file: libtensorflow_io.so, from paths: ['/opt/conda/lib/python3.10/site-packages/tensorflow_io/python/ops/libtensorflow_io.so'] caused by: ['/opt/conda/lib/python3.10/site-packages/tensorflow_io/python/ops/libtensorflow_io.so: undefined symbol: _ZTVN10tensorflow13GcsFileSystemE'] warnings.warn(f"file system plugins are not loaded: {e}") `

加载数据

在本指南中,我们将使用从 roboflow 获取的自动驾驶汽车数据集。为了 使数据集更易于管理,我提取了较大数据集的子集,该子集 最初由 15,000 个数据样本组成。从这个子集中,我选择了 7,316 个 模型训练示例。

为了简化手头的任务并集中精力,我们将与减少的 对象类的数量。具体来说,我们将考虑 5 个主要类别 检测和分类:汽车、行人、红绿灯、骑自行车的人和卡车。这些 类表示 自动驾驶汽车的背景。

通过将数据集缩小到这些特定类,我们可以专注于构建 强大的对象检测模型,可以准确识别和分类这些重要 对象。

TensorFlow Datasets 库提供了一种下载和使用各种 数据集,包括对象检测数据集。对于那些人来说,这可能是一个不错的选择 想要快速开始处理数据而无需手动下载和 预处理它。

您可以在此处查看各种对象检测数据集 TensorFlow 数据集

但是,在此代码示例中,我们将演示如何从头开始加载数据集 使用 TensorFlow 的 tf.data 流水线。这种方法提供了更大的灵活性,并允许 您可以根据需要自定义预处理步骤。

加载 TensorFlow 数据集库中不可用的自定义数据集就是其中之一 使用 tf.data 管道的主要优势。此方法允许您 创建针对特定需求量身定制的自定义数据预处理管道,以及 要求。


超参数

复制代码
SPLIT_RATIO` `=` `0.2`
`BATCH_SIZE` `=` `4`
`LEARNING_RATE` `=` `0.001`
`EPOCH` `=` `5`
`GLOBAL_CLIPNORM` `=` `10.0`
`

创建一个字典以将每个类名映射到唯一的数字标识符。这 mapping 用于在训练和推理期间对类标签进行编码和解码 对象检测任务。

复制代码
class_ids` `=` `[`
    `"car",`
    `"pedestrian",`
    `"trafficLight",`
    `"biker",`
    `"truck",`
`]`
`class_mapping` `=` `dict(zip(range(len(class_ids)),` `class_ids))`

`# Path to images and annotations`
`path_images` `=` `"/kaggle/input/dataset/data/images/"`
`path_annot` `=` `"/kaggle/input/dataset/data/annotations/"`

`# Get all XML file paths in path_annot and sort them`
`xml_files` `=` `sorted(`
    `[`
        `os.path.join(path_annot,` `file_name)`
        `for` `file_name` `in` `os.listdir(path_annot)`
        `if` `file_name.endswith(".xml")`
    `]`
`)`

`# Get all JPEG image file paths in path_images and sort them`
`jpg_files` `=` `sorted(`
    `[`
        `os.path.join(path_images,` `file_name)`
        `for` `file_name` `in` `os.listdir(path_images)`
        `if` `file_name.endswith(".jpg")`
    `]`
`)`
`

下面的函数读取 XML 文件并查找图像名称和路径,然后 迭代 XML 文件中的每个对象以提取边界框坐标,并且 class 标签。

该函数返回三个值:图像路径、边界框列表(每个 表示为四个浮点数的列表:xmin、ymin、xmax、ymax)和类 ID 列表 (以整数表示)对应于每个边界框。获取类 ID 通过使用名为 的字典将类标签映射到整数值。class_mapping

复制代码
def` `parse_annotation(xml_file):`
    `tree` `=` `ET.parse(xml_file)`
    `root` `=` `tree.getroot()`

    `image_name` `=` `root.find("filename").text`
    `image_path` `=` `os.path.join(path_images,` `image_name)`

    `boxes` `=` `[]`
    `classes` `=` `[]`
    `for` `obj` `in` `root.iter("object"):`
        `cls` `=` `obj.find("name").text`
        `classes.append(cls)`

        `bbox` `=` `obj.find("bndbox")`
        `xmin` `=` `float(bbox.find("xmin").text)`
        `ymin` `=` `float(bbox.find("ymin").text)`
        `xmax` `=` `float(bbox.find("xmax").text)`
        `ymax` `=` `float(bbox.find("ymax").text)`
        `boxes.append([xmin,` `ymin,` `xmax,` `ymax])`

    `class_ids` `=` `[`
        `list(class_mapping.keys())[list(class_mapping.values()).index(cls)]`
        `for` `cls` `in` `classes`
    `]`
    `return` `image_path,` `boxes,` `class_ids`


`image_paths` `=` `[]`
`bbox` `=` `[]`
`classes` `=` `[]`
`for` `xml_file` `in` `tqdm(xml_files):`
    `image_path,` `boxes,` `class_ids` `=` `parse_annotation(xml_file)`
    `image_paths.append(image_path)`
    `bbox.append(boxes)`
    `classes.append(class_ids)`
`
复制代码
` 0%| | 0/7316 [00:00<?, ?it/s] `

在这里,我们使用 tf.ragged.constant 从 和 列表创建不规则张量。参差不齐的张量是一种可以处理不同长度的 数据。这在处理具有 可变长度序列,例如文本或时间序列数据。bbox``classes

复制代码
classes` `=` `[`
    `[8,` `8,` `8,` `8,` `8],`      `# 5 classes`
    `[12,` `14,` `14,` `14],`     `# 4 classes`
    `[1],`                  `# 1 class`
    `[7,` `7],`               `# 2 classes`
 `...]`
`
复制代码
bbox` `=` `[`
    `[[199.0,` `19.0,` `390.0,` `401.0],`
    `[217.0,` `15.0,` `270.0,` `157.0],`
    `[393.0,` `18.0,` `432.0,` `162.0],`
    `[1.0,` `15.0,` `226.0,` `276.0],`
    `[19.0,` `95.0,` `458.0,` `443.0]],`     `#image 1 has 4 objects`
    `[[52.0,` `117.0,` `109.0,` `177.0]],`   `#image 2 has 1 object`
    `[[88.0,` `87.0,` `235.0,` `322.0],`
    `[113.0,` `117.0,` `218.0,` `471.0]],`   `#image 3 has 2 objects`
 `...]`
`

在这种情况下,每个图像的 and 列表具有不同的长度, 取决于图像中的对象数量和相应的边界框,以及 类。为了处理这种可变性,使用参差不齐的张量而不是常规张量。bbox``classes

稍后,这些参差不齐的张量用于使用该方法创建 tf.data.Dataset 。该方法通过以下方式从输入张量创建数据集 沿第一维度对它们进行切片。通过使用不规则张量,数据集可以处理 每张图像的数据长度不同,并提供灵活的输入管道以进一步 加工。from_tensor_slices

复制代码
bbox` `=` `tf.ragged.constant(bbox)`
`classes` `=` `tf.ragged.constant(classes)`
`image_paths` `=` `tf.ragged.constant(image_paths)`

`data` `=` `tf.data.Dataset.from_tensor_slices((image_paths,` `classes,` `bbox))`
`

在训练和验证数据中拆分数据

复制代码
# Determine the number of validation samples`
`num_val` `=` `int(len(xml_files)` `*` `SPLIT_RATIO)`

`# Split the dataset into train and validation sets`
`val_data` `=` `data.take(num_val)`
`train_data` `=` `data.skip(num_val)`
`

让我们看看数据加载和边界框格式化以使事情顺利进行。边界 KerasCV 中的框具有预先确定的格式。为此,您必须捆绑边界 框添加到符合下列要求的词典中:

复制代码
bounding_boxes` `=` `{`
    `# num_boxes may be a Ragged dimension`
    `'boxes':` `Tensor(shape=[batch,` `num_boxes,` `4]),`
    `'classes':` `Tensor(shape=[batch,` `num_boxes])`
`}`
`

字典有两个键 和 ,每个键都映射到 TensorFlow RaggedTensor 或 Tensor 对象。Tensor 的形状为 ,其中 batch 是 batch 中的图像数,num_boxes 是 任何图像中的最大边界框数。4 表示 定义边界框:xmin、ymin、xmax、ymax。'boxes'``'classes'``'boxes'``[batch, num_boxes, 4]

Tensor 的形状为 ,其中每个元素表示 Tensor 中相应边界框的类标签。num_boxes 尺寸可能参差不齐,这意味着 批次。'classes'``[batch, num_boxes]``'boxes'

最终 dict 应该是:

复制代码
{"images":` `images,` `"bounding_boxes":` `bounding_boxes}`
`
复制代码
def` `load_image(image_path):`
    `image` `=` `tf.io.read_file(image_path)`
    `image` `=` `tf.image.decode_jpeg(image,` `channels=3)`
    `return` `image`


`def` `load_dataset(image_path,` `classes,` `bbox):`
    `# Read Image`
    `image` `=` `load_image(image_path)`
    `bounding_boxes` `=` `{`
        `"classes":` `tf.cast(classes,` `dtype=tf.float32),`
        `"boxes":` `bbox,`
    `}`
    `return` `{"images":` `tf.cast(image,` `tf.float32),` `"bounding_boxes":` `bounding_boxes}`
`

在这里,我们创建一个图层,将图像大小调整为 640x640 像素,同时保持 原始纵横比。与图像关联的边界框以格式指定。如有必要,调整大小后的图像将用零填充,以保持 原始纵横比。xyxy

KerasCV 支持的边界框格式: 1. CENTER_XYWH 2. XYWH 3. XYXY 4. REL_XYXY 5. REL_XYWH 6. YXYX 7. REL_YXYX

你可以在 docs 中阅读更多关于 KerasCV 边界框格式的信息。

此外,还可以在任意两对之间执行格式转换:

复制代码
boxes` `=` `keras_cv.bounding_box.convert_format(`
        `bounding_box,`
        `images=image,`
        `source="xyxy",`  `# Original Format`
        `target="xywh",`  `# Target Format (to which we want to convert)`
    `)`
`

数据增强

构建对象检测管道时最具挑战性的任务之一是数据 增大。它涉及对输入图像应用各种转换,以 增加训练数据的多样性,提高模型的能力 概括。但是,在处理对象检测任务时,它变得更加 复杂,因为这些转换需要了解底层边界框和 相应地更新它们。

KerasCV 为边界框增强提供原生支持。KerasCV 提供了一个 大量专为处理边界而设计的数据增强层 盒。这些图层会根据图像的原样智能地调整边界框坐标 transformed,确保边界框保持准确并与 增强图像。

通过利用 KerasCV 的功能,开发人员可以方便地集成边界 将 Box 友好的数据增强到他们的对象检测管道中。通过执行 在 tf.data 流水线中进行动态增强,该过程变得无缝且 高效,从而实现更好的训练和更准确的对象检测结果。

复制代码
augmenter` `=` `keras.Sequential(`
    `layers=[`
        `keras_cv.layers.RandomFlip(mode="horizontal",` `bounding_box_format="xyxy"),`
        `keras_cv.layers.RandomShear(`
            `x_factor=0.2,` `y_factor=0.2,` `bounding_box_format="xyxy"`
        `),`
        `keras_cv.layers.JitteredResize(`
            `target_size=(640,` `640),` `scale_factor=(0.75,` `1.3),` `bounding_box_format="xyxy"`
        `),`
    `]`
`)`
`

创建训练数据集

复制代码
train_ds` `=` `train_data.map(load_dataset,` `num_parallel_calls=tf.data.AUTOTUNE)`
`train_ds` `=` `train_ds.shuffle(BATCH_SIZE` `*` `4)`
`train_ds` `=` `train_ds.ragged_batch(BATCH_SIZE,` `drop_remainder=True)`
`train_ds` `=` `train_ds.map(augmenter,` `num_parallel_calls=tf.data.AUTOTUNE)`
`

创建验证数据集

复制代码
resizing` `=` `keras_cv.layers.JitteredResize(`
    `target_size=(640,` `640),`
    `scale_factor=(0.75,` `1.3),`
    `bounding_box_format="xyxy",`
`)`

`val_ds` `=` `val_data.map(load_dataset,` `num_parallel_calls=tf.data.AUTOTUNE)`
`val_ds` `=` `val_ds.shuffle(BATCH_SIZE` `*` `4)`
`val_ds` `=` `val_ds.ragged_batch(BATCH_SIZE,` `drop_remainder=True)`
`val_ds` `=` `val_ds.map(resizing,` `num_parallel_calls=tf.data.AUTOTUNE)`
`

可视化

复制代码
def` `visualize_dataset(inputs,` `value_range,` `rows,` `cols,` `bounding_box_format):`
    `inputs` `=` `next(iter(inputs.take(1)))`
    `images,` `bounding_boxes` `=` `inputs["images"],` `inputs["bounding_boxes"]`
    `visualization.plot_bounding_box_gallery(`
        `images,`
        `value_range=value_range,`
        `rows=rows,`
        `cols=cols,`
        `y_true=bounding_boxes,`
        `scale=5,`
        `font_scale=0.7,`
        `bounding_box_format=bounding_box_format,`
        `class_mapping=class_mapping,`
    `)`


`visualize_dataset(`
    `train_ds,` `bounding_box_format="xyxy",` `value_range=(0,` `255),` `rows=2,` `cols=2`
`)`

`visualize_dataset(`
    `val_ds,` `bounding_box_format="xyxy",` `value_range=(0,` `255),` `rows=2,` `cols=2`
`)`
`

我们需要从 preprocessing 字典中提取输入并准备好它们 馈送到模型中。

复制代码
def` `dict_to_tuple(inputs):`
    `return` `inputs["images"],` `inputs["bounding_boxes"]`


`train_ds` `=` `train_ds.map(dict_to_tuple,` `num_parallel_calls=tf.data.AUTOTUNE)`
`train_ds` `=` `train_ds.prefetch(tf.data.AUTOTUNE)`

`val_ds` `=` `val_ds.map(dict_to_tuple,` `num_parallel_calls=tf.data.AUTOTUNE)`
`val_ds` `=` `val_ds.prefetch(tf.data.AUTOTUNE)`
`

创建模型

YOLOv8 是一款尖端的 YOLO 模型,用于各种计算机视觉任务, 例如对象检测、图像分类和实例分割。Ultralytics, YOLOv5 的创建者还开发了 YOLOv8,其中包含许多改进和 与前代产品相比,架构和开发人员体验发生了变化。YOLOv8 是 在业内受到高度评价的最新最新型号。

下表比较了 5 种不同 YOLOv8 模型的性能指标与 不同大小(以像素为单位):YOLOv8n、YOLOv8s、YOLOv8m、YOLOv8l 和 YOLOv8x。 这些指标包括不同 验证数据的交并比 (IoU) 阈值,CPU 上的推理速度 ONNX 格式和 A100 TensorRT 、参数数量和浮点数 操作 (FLOP)(分别以百万和数十亿为单位)。由于 model 增加时,mAP、参数和 FLOPs 通常增加,而速度 减少。YOLOv8x 的 mAP、参数和 FLOP 最高,但也是最慢的 推理速度,而 YOLOv8n 具有最小的尺寸、最快的推理速度和最低的推理速度 mAP、参数和 FLOPs。

您可以在此 RoboFlow 博客中阅读有关 YOLOV8 及其架构的更多信息

首先,我们将创建一个 backbone 实例,供我们的 yolov8 检测器使用 类。

KerasCV 中提供的 YOLOV8 Backbones:

  1. 无权重:
复制代码
`1. yolo_v8_xs_backbone 2. yolo_v8_s_backbone 3. yolo_v8_m_backbone 4. yolo_v8_l_backbone 5. yolo_v8_xl_backbone `
  1. 使用预先训练的 coco 重量:
复制代码
backbone` `=` `keras_cv.models.YOLOV8Backbone.from_preset(`
    `"yolo_v8_s_backbone_coco"`  `# We will use yolov8 small backbone with coco weights`
`)`
`
复制代码
`1. yolo_v8_xs_backbone_coco 2. yolo_v8_s_backbone_coco 2. yolo_v8_m_backbone_coco 2. yolo_v8_l_backbone_coco 2. yolo_v8_xl_backbone_coco Downloading data from https://storage.googleapis.com/keras-cv/models/yolov8/coco/yolov8_s_backbone.h5 20596968/20596968 [==============================] - 0s 0us/step `

接下来,让我们使用 构建一个 YOLOV8 模型,它接受一个特征 extractor 作为参数,则指定数字 of 对象类来根据列表的大小进行检测,该参数通知模型 数据集,最后,特征金字塔网络 (FPN) 深度由参数指定。YOLOV8Detector``backbone``num_classes``class_mapping``bounding_box_format``fpn_depth

使用上述任何 backbone 构建 YOLOV8 都很简单,这要归功于 KerasCV 的

复制代码
yolo` `=` `keras_cv.models.YOLOV8Detector(`
    `num_classes=len(class_mapping),`
    `bounding_box_format="xyxy",`
    `backbone=backbone,`
    `fpn_depth=1,`
`)`
`

编译模型

用于 YOLOV8 的损失

  1. 分类损失:此损失函数计算预期 类概率和实际类概率。在这种情况下,二进制分类问题的一个突出解决方案是 利用。我们利用了二进制交叉熵,因为每个被识别的事物都是 被归类为属于或不属于某个对象类(例如,一个人、一个 汽车等)。binary_crossentropy

  2. Box Loss:是用于衡量 预测边界框和地面实况。在这种情况下,完整 IoU (CIoU) 指标,它不仅衡量预测值和真实值之间的重叠 边界框,但还要考虑纵横比、中心距和 盒子大小。这些损失函数共同帮助优化对象检测模型,方法是 最小化 Predicted 和 Ground Truth 类概率之间的差异,以及 边界框。box_loss

复制代码
optimizer` `=` `tf.keras.optimizers.Adam(`
    `learning_rate=LEARNING_RATE,`
    `global_clipnorm=GLOBAL_CLIPNORM,`
`)`

`yolo.compile(`
    `optimizer=optimizer,` `classification_loss="binary_crossentropy",` `box_loss="ciou"`
`)`
`

COCO 指标回调

我们将使用 KerasCV 来评估模型并计算 Map(Mean Average Precision) 分数、Recall 和 Precision。我们还会在 mAP 评分提高。BoxCOCOMetrics

复制代码
class` `EvaluateCOCOMetricsCallback(keras.callbacks.Callback):`
    `def` `__init__(self,` `data,` `save_path):`
        `super().__init__()`
        `self.data` `=` `data`
        `self.metrics` `=` `keras_cv.metrics.BoxCOCOMetrics(`
            `bounding_box_format="xyxy",`
            `evaluate_freq=1e9,`
        `)`

        `self.save_path` `=` `save_path`
        `self.best_map` `=` `-1.0`

    `def` `on_epoch_end(self,` `epoch,` `logs):`
        `self.metrics.reset_state()`
        `for` `batch` `in` `self.data:`
            `images,` `y_true` `=` `batch[0],` `batch[1]`
            `y_pred` `=` `self.model.predict(images,` `verbose=0)`
            `self.metrics.update_state(y_true,` `y_pred)`

        `metrics` `=` `self.metrics.result(force=True)`
        `logs.update(metrics)`

        `current_map` `=` `metrics["MaP"]`
        `if` `current_map` `>` `self.best_map:`
            `self.best_map` `=` `current_map`
            `self.model.save(self.save_path)`  `# Save the model when mAP improves`

        `return` `logs`
`

训练模型

复制代码
yolo.fit(`
    `train_ds,`
    `validation_data=val_ds,`
    `epochs=3,`
    `callbacks=[EvaluateCOCOMetricsCallback(val_ds,` `"model.h5")],`
`)`
`
复制代码
`Epoch 1/3 1463/1463 [==============================] - 633s 390ms/step - loss: 10.1535 - box_loss: 2.5659 - class_loss: 7.5876 - val_loss: 3.9852 - val_box_loss: 3.1973 - val_class_loss: 0.7879 - MaP: 0.0095 - MaP@[IoU=50]: 0.0193 - MaP@[IoU=75]: 0.0074 - MaP@[area=small]: 0.0021 - MaP@[area=medium]: 0.0164 - MaP@[area=large]: 0.0010 - Recall@[max_detections=1]: 0.0096 - Recall@[max_detections=10]: 0.0160 - Recall@[max_detections=100]: 0.0160 - Recall@[area=small]: 0.0034 - Recall@[area=medium]: 0.0283 - Recall@[area=large]: 0.0010 Epoch 2/3 1463/1463 [==============================] - 554s 378ms/step - loss: 2.6961 - box_loss: 2.2861 - class_loss: 0.4100 - val_loss: 3.8292 - val_box_loss: 3.0052 - val_class_loss: 0.8240 - MaP: 0.0077 - MaP@[IoU=50]: 0.0197 - MaP@[IoU=75]: 0.0043 - MaP@[area=small]: 0.0075 - MaP@[area=medium]: 0.0126 - MaP@[area=large]: 0.0050 - Recall@[max_detections=1]: 0.0088 - Recall@[max_detections=10]: 0.0154 - Recall@[max_detections=100]: 0.0154 - Recall@[area=small]: 0.0075 - Recall@[area=medium]: 0.0191 - Recall@[area=large]: 0.0280 Epoch 3/3 1463/1463 [==============================] - 558s 381ms/step - loss: 2.5930 - box_loss: 2.2018 - class_loss: 0.3912 - val_loss: 3.4796 - val_box_loss: 2.8472 - val_class_loss: 0.6323 - MaP: 0.0145 - MaP@[IoU=50]: 0.0398 - MaP@[IoU=75]: 0.0072 - MaP@[area=small]: 0.0077 - MaP@[area=medium]: 0.0227 - MaP@[area=large]: 0.0079 - Recall@[max_detections=1]: 0.0120 - Recall@[max_detections=10]: 0.0257 - Recall@[max_detections=100]: 0.0258 - Recall@[area=small]: 0.0093 - Recall@[area=medium]: 0.0396 - Recall@[area=large]: 0.0226 <keras.callbacks.History at 0x7f3e01ca6d70> `

可视化预测

复制代码
def` `visualize_detections(model,` `dataset,` `bounding_box_format):`
    `images,` `y_true` `=` `next(iter(dataset.take(1)))`
    `y_pred` `=` `model.predict(images)`
    `y_pred` `=` `bounding_box.to_ragged(y_pred)`
    `visualization.plot_bounding_box_gallery(`
        `images,`
        `value_range=(0,` `255),`
        `bounding_box_format=bounding_box_format,`
        `y_true=y_true,`
        `y_pred=y_pred,`
        `scale=4,`
        `rows=2,`
        `cols=2,`
        `show=True,`
        `font_scale=0.7,`
        `class_mapping=class_mapping,`
    `)`


`visualize_detections(yolo,` `dataset=val_ds,` `bounding_box_format="xyxy")`
`
复制代码
`1/1 [==============================] - 0s 115ms/step `
相关推荐
huaqianzkh24 分钟前
理解构件的3种分类方法
人工智能·分类·数据挖掘
后端码匠25 分钟前
Spring Boot3+Vue2极速整合:10分钟搭建DeepSeek AI对话系统
人工智能·spring boot·后端
用户2314349781426 分钟前
使用 Trae AI 编程平台生成扫雷游戏
人工智能·设计
神经美学_茂森40 分钟前
神经网络防“失忆“秘籍:弹性权重固化如何让AI学会“温故知新“
人工智能·深度学习·神经网络
大囚长42 分钟前
AI工作流+专业知识库+系统API的全流程任务自动化
运维·人工智能·自动化
阿_旭44 分钟前
【超详细】神经网络的可视化解释
人工智能·深度学习·神经网络
Se7en2581 小时前
提升 AI 服务的稳定性:Higress AI 网关的降级功能介绍
人工智能
武乐乐~1 小时前
QARepVGG--含demo实现
深度学习
机器视觉知识推荐、就业指导1 小时前
【数字图像处理二】图像增强与空域处理
图像处理·人工智能·经验分享·算法·计算机视觉
陈辛chenxin1 小时前
【论文带读系列(1)】《End-to-End Object Detection with Transformers》论文超详细带读 + 翻译
人工智能·目标检测·计算机视觉