YOLOv8 训练及测试(Ubuntu 18.04、TensorRT、ROS)

1 数据集制作

1.1标注数据

Linux/Ubuntu/Mac

至少需要 Python 2.6 (推荐使用 Python 3 或更高版本 及 PyQt5)

Ubuntu Linux (Python 3 + Qt5)

代码(bash):
# Download labelImg (a GitHub mirror hosted on gitcode, per the URL) and run it with Python 3 + Qt5.
git clone https://gitcode.com/gh_mirrors/la/labelImg.git
# Qt5 development tools; presumably required by `make qt5py3` below.
sudo apt-get install pyqt5-dev-tools
cd labelImg
sudo pip3 install -r requirements/requirements-linux-python3.txt
make qt5py3
python3 labelImg.py

运行 python3 labelImg.py 时报错:File "/home/wyh/environment_setting/labelImg-master/libs/labelDialog.py", line 37, in __init__ layout.addWidget(bb, alignment=Qt.AlignmentFlag.AlignLeft) AttributeError: type object 'AlignmentFlag' has no attribute 'AlignLeft'

原因:PyQt / PySide 版本不同,枚举的访问方式也不同。Qt.AlignmentFlag.AlignLeft 这种写法在部分 PyQt5 版本中不受支持。

解决:如果确定使用的是 PyQt5,将 libs/labelDialog.py 中的 layout.addWidget(bb, alignment=Qt.AlignmentFlag.AlignLeft) 改为 layout.addWidget(bb, alignment=Qt.AlignLeft)

1.2 建立对应的数据文件夹


各文件夹含义如下:
- images:图片数据
- labels:由 xml 转换得到的 YOLO 格式 txt 标注文件
- xmls:labelImg 标注生成的 xml 格式数据
- class.txt:类别标签列表文件

1.3 将标注后的xml转为txt

代码(python):
#! /usr/local/bin/ python
# -*- coding: utf-8 -*-
# .xml文件转换成.txt文件
 
import copy
import os
import pickle
import xml.etree.ElementTree as ET
from os import getcwd, listdir
from os.path import join
from xml.etree.ElementTree import Element, ElementTree, SubElement, tostring
 
# Detection classes. The position of a name in this list becomes its YOLO class id.
# NOTE(review): these ship classes are an example. Replace them with the
# labels used in your own labelImg annotations.
classes = [
    "ore carrier",
    "passenger ship",
    "container ship",
    "bulk cargo carrier",
    "general cargo ship",
    "fishing boat",
]

# Absolute directory of this script; the data folders are resolved relative to it.
CURRENT_DIR = os.path.dirname(os.path.abspath(__file__))
 
def convert(size, box):
    """Convert a Pascal VOC pixel box into normalised YOLO coordinates.

    :param size: (image_width, image_height) in pixels
    :param box: (x_min, x_max, y_min, y_max) in pixels. The order is
                x, x, y, y, not x_min, y_min, x_max, y_max.
    :return: tuple (x_center, y_center, width, height). The x values are
             scaled by 1/image_width and the y values by 1/image_height.
    """
    inv_w = 1. / size[0]
    inv_h = 1. / size[1]
    x_min, x_max, y_min, y_max = box[0], box[1], box[2], box[3]

    center_x = (x_min + x_max) / 2.0
    center_y = (y_min + y_max) / 2.0
    box_w = x_max - x_min
    box_h = y_max - y_min

    return (center_x * inv_w, center_y * inv_h, box_w * inv_w, box_h * inv_h)
 
def convert_annotation(image_id, xml_dir=None, txt_dir=None):
    """Convert one labelImg (Pascal VOC) .xml annotation into a YOLO .txt label.

    Reads ``<xml_dir>/<image_id>.xml`` and writes ``<txt_dir>/<image_id>.txt``.
    The output has one line per object: ``class_id x_center y_center width height``.
    Objects whose name is not in ``classes`` are skipped.

    :param image_id: file-name stem shared by the .xml and .txt files
    :param xml_dir: folder holding the .xml files. Defaults to the '地址1'
                    placeholder folder next to this script; replace it with your path.
    :param txt_dir: folder that receives the .txt files. Defaults to the '地址2'
                    placeholder folder next to this script. It is created if missing.
    """
    if xml_dir is None:
        xml_dir = os.path.join(CURRENT_DIR, '地址1')
    if txt_dir is None:
        txt_dir = os.path.join(CURRENT_DIR, '地址2')
    os.makedirs(txt_dir, exist_ok=True)

    # os.path.join keeps the script portable; a literal '\' separator only
    # works on Windows, and '\%' is not a valid string escape.
    xml_file = os.path.join(xml_dir, '%s.xml' % image_id)
    txt_file = os.path.join(txt_dir, '%s.txt' % image_id)

    with open(xml_file, encoding='UTF-8') as in_file:
        root = ET.parse(in_file).getroot()

    size = root.find('size')
    # int(float(...)) also accepts sizes written as "640.0".
    w = int(float(size.find('width').text))
    h = int(float(size.find('height').text))

    with open(txt_file, 'w', encoding='UTF-8') as out_file:
        for obj in root.iter('object'):
            cls = obj.find('name').text
            if cls not in classes:
                continue
            cls_id = classes.index(cls)
            xmlbox = obj.find('bndbox')
            # Order expected by convert(): (x_min, x_max, y_min, y_max).
            b = (float(xmlbox.find('xmin').text), float(xmlbox.find('xmax').text),
                 float(xmlbox.find('ymin').text), float(xmlbox.find('ymax').text))
            bb = convert((w, h), b)
            out_file.write(str(cls_id) + " " + " ".join(str(a) for a in bb) + '\n')
 
# Folder holding the labelImg .xml files (replace the '地址1' placeholder).
xml_path = os.path.join(CURRENT_DIR, '地址1/')

# Convert only .xml files, in a stable order. Other files in the folder
# (e.g. class.txt) would otherwise be treated as annotations.
img_xmls = sorted(f for f in os.listdir(xml_path) if f.lower().endswith('.xml'))
for img_xml in img_xmls:
    # splitext keeps dots inside the stem: 'img.001.xml' -> 'img.001'.
    label_name = os.path.splitext(img_xml)[0]
    print(label_name)
    convert_annotation(label_name)

需要修改的内容:
- 将代码中的“地址1”替换为 xml 标注文件夹的实际路径。
- 将“地址2”替换为 txt 输出文件夹的实际路径。
- 将 classes 改为自己数据集的类别名。

2 将yolo数据拆分为train、val、test

代码(python):
import os
import random
import shutil

def split_dataset(images_dir, labels_dir, output_dir, split_ratio=(0.8, 0.1, 0.1),
                  seed=None, image_exts=('.jpg', '.png')):
    """
    Split an image/label dataset into train, val and test subsets.

    Creates ``<output_dir>/{train,val,test}/{images,labels}`` and copies each
    image, together with its same-stem ``.txt`` label, into one subset.

    :param images_dir: folder containing the images
    :param labels_dir: folder containing the YOLO .txt labels
    :param output_dir: destination root folder
    :param split_ratio: (train, val, test) fractions. The train and val counts
        are truncated with int(); every remaining image goes to test.
    :param seed: optional seed for a reproducible shuffle. None uses the
        global ``random`` state, like ``random.shuffle``.
    :param image_exts: image extensions to include (matched case-insensitively)
    """
    # Make sure the output folders exist.
    os.makedirs(output_dir, exist_ok=True)
    for subdir in ('train', 'val', 'test'):
        os.makedirs(os.path.join(output_dir, subdir, 'images'), exist_ok=True)
        os.makedirs(os.path.join(output_dir, subdir, 'labels'), exist_ok=True)

    exts = tuple(e.lower() for e in image_exts)
    # Sort first so a given seed gives the same split whatever order listdir returns.
    images = sorted(f for f in os.listdir(images_dir) if f.lower().endswith(exts))
    if not images:
        print("No images found in %s" % images_dir)
        return

    if seed is None:
        random.shuffle(images)
    else:
        random.Random(seed).shuffle(images)

    # Split points.
    num_train = int(len(images) * split_ratio[0])
    num_val = int(len(images) * split_ratio[1])

    for i, image in enumerate(images):
        if i < num_train:
            subset = 'train'
        elif i < num_train + num_val:
            subset = 'val'
        else:
            subset = 'test'

        # The label shares the image's stem. splitext only touches the real
        # extension, not a '.jpg' that appears elsewhere in the name.
        label = os.path.splitext(image)[0] + '.txt'
        shutil.copy(os.path.join(images_dir, image),
                    os.path.join(output_dir, subset, 'images', image))

        label_src = os.path.join(labels_dir, label)
        if os.path.isfile(label_src):
            shutil.copy(label_src, os.path.join(output_dir, subset, 'labels', label))
        else:
            # Keep going instead of aborting half-way. An image without a label
            # file is treated by Ultralytics as a background (negative) sample.
            print("Warning: no label file for %s; copied image only" % image)

# Example call: split origin_data into split_data using the default
# 80% / 10% / 10% train / val / test ratio.
split_dataset(
    images_dir='/home/wyh/artrc_catkin/src/artrc_yolov8/datasets/origin_data/images',
    labels_dir='/home/wyh/artrc_catkin/src/artrc_yolov8/datasets/origin_data/labels',
    output_dir='/home/wyh/artrc_catkin/src/artrc_yolov8/datasets/split_data',
)

运行后如图所示

3 根据数据集添加yaml文件

代码(python):
import yaml
import os
def create_yaml(output_dir, train_dir, val_dir, test_dir, class_names, num_classes=None):
    """
    Write a YOLOv8 dataset config named ``dataset.yaml`` into ``output_dir``.

    :param output_dir: folder that receives dataset.yaml (created if missing)
    :param train_dir: training image folder
    :param val_dir: validation image folder
    :param test_dir: test image folder
    :param class_names: class names, indexed by class id
    :param num_classes: number of classes. Defaults to ``len(class_names)``.
    :return: path of the written yaml file
    :raises ValueError: if num_classes differs from len(class_names)
    """
    if num_classes is None:
        num_classes = len(class_names)
    elif num_classes != len(class_names):
        # YOLOv8 expects nc == len(names); fail here rather than at training time.
        raise ValueError("num_classes (%d) does not match len(class_names) (%d)"
                         % (num_classes, len(class_names)))

    data = {
        'train': train_dir,
        'val': val_dir,
        'test': test_dir,
        'nc': num_classes,
        # list(): a tuple would be dumped as a '!!python/tuple' tag.
        'names': list(class_names),
    }
    os.makedirs(output_dir, exist_ok=True)
    yaml_path = os.path.join(output_dir, 'dataset.yaml')
    # sort_keys=False keeps the key order above.
    # allow_unicode=True writes non-ASCII (e.g. Chinese) class names as-is.
    with open(yaml_path, 'w', encoding='utf-8') as f:
        yaml.dump(data, f, default_flow_style=False, sort_keys=False, allow_unicode=True)
    return yaml_path

# Example call: write split_data/dataset.yaml for the 9 defect classes.
create_yaml(
    output_dir='/home/wyh/artrc_catkin/src/artrc_yolov8/datasets/split_data',
    train_dir='/home/wyh/artrc_catkin/src/artrc_yolov8/datasets/split_data/train/images',
    val_dir='/home/wyh/artrc_catkin/src/artrc_yolov8/datasets/split_data/val/images',
    test_dir='/home/wyh/artrc_catkin/src/artrc_yolov8/datasets/split_data/test/images',
    class_names=[
        'corrosion', 'craze', 'hide_craze', 'surface_attach', 'surface_corrosion',
        'surface_eye', 'surface_injure', 'surface_oil', 'thunderstrike',
    ],
    num_classes=9,
)

运行后生成的 dataset.yaml 文件如下:

4 训练数据集

代码(bash):
# Train a detector starting from the yolov8n.pt weights on GPU 0.
cd ultralytics
# data= must point at the dataset.yaml produced in step 3 (copy it to this path or pass its full path).
yolo task=detect mode=train model=yolov8n.pt data=ultralytics/cfg/datasets/dataset.yaml batch=8 epochs=200 imgsz=640 workers=32 device=0

5 训练后使用

5.1 训练后模型的各种格式转换

5.1.1 将.pt转换为onnx

方式一:利用下述pt_to_onnx.py进行转换

代码(python,保存为 pt_to_onnx.py):
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""Export the trained YOLOv8 weights (best.pt) to ONNX for onnxruntime / TensorRT."""
from ultralytics import YOLO

model = YOLO("best.pt")

# dynamic=True exports a variable batch dimension. The trtexec
# --minShapes/--optShapes/--maxShapes options in step 5.1.2 need it.
# NOTE(review): older TensorRT ONNX parsers may not accept opset 17; lower
# `opset` if trtexec fails to parse the exported model.
onnx_path = model.export(format="onnx", half=False, dynamic=True, opset=17)

# export() returns the path of the written file.
print("exported:", onnx_path)
代码(bash):
# Run the export from the folder that contains best.pt (the script loads it by relative path).
cd ultralytics
python pt_to_onnx.py

方式二:命令行操作转换

代码(bash):
# Change to the folder that contains the weight files
cd ultralytics 
# NOTE(review): `setconda` is not a standard command -- presumably a user alias that initialises conda in this shell.
setconda
conda activate yolov8
yolo mode=export model=yolov8n.pt format=onnx dynamic=True    # optionally add simplify=True
yolo mode=export model=yolov8s.pt format=onnx dynamic=True    # same command for another model size

5.1.2 将 .onnx 转换为 .trt

代码(bash):
# Build dynamic-batch TensorRT engines (.trt) from the exported ONNX models.
# Both paths are examples: adjust them to where tensorrt-alpha and TensorRT
# are installed (here under ~/environment_setting, not the filesystem root).
cd ~/environment_setting/tensorrt-alpha/data/yolov8
TRTEXEC=../../../TensorRT-8.4.1.5/bin/trtexec

# Batch 1..8 (optimised for 4) at 640x640 on the ONNX input named "images".
# The ONNX must have been exported with dynamic=True for these shape ranges.
for model in best yolov8s yolov8m; do
    "$TRTEXEC" --onnx="${model}.onnx" --saveEngine="${model}.trt" --buildOnly \
        --minShapes=images:1x3x640x640 \
        --optShapes=images:4x3x640x640 \
        --maxShapes=images:8x3x640x640
done

5.2 利用pt文件进行检测

代码(python):
#!/home/wyh/.conda/envs/yolov8/bin/python3.8
# -*- coding: utf-8 -*-
import cv2
import torch
import rospy
import numpy as np
from ultralytics import YOLO
from time import time
from std_msgs.msg import Header
from sensor_msgs.msg import Image
from artrc_yolov8.msg import BoundingBox, BoundingBoxes

class Yolo_Dect:
    """Run an Ultralytics YOLOv8 ``.pt`` model and publish its detections over ROS.

    In this version ``__init__`` runs the detector once on a fixed image file
    (``load_and_detect``). ``image_callback`` is still a placeholder, so frames
    arriving on ``~image_topic`` are not processed.

    ROS parameters read:
        ~weight_path   path to the ``.pt`` weights
        ~image_topic   image topic to subscribe to
        ~pub_topic     topic for the BoundingBoxes message
        ~camera_frame  frame_id written into published headers
        ~conf          confidence threshold used for prediction (default 0.5)
        ~visualize     stored as a bool in ``self.visualize``
        /use_cpu       (global) run on CPU when true (default), otherwise on ``cuda:0``
    """

    def __init__(self):
        # load parameters
        weight_path = rospy.get_param('~weight_path', '')
        image_topic = rospy.get_param(
            '~image_topic', '/camera/color/image_raw')
        pub_topic = rospy.get_param('~pub_topic', '/yolov8/BoundingBoxes')
        self.camera_frame = rospy.get_param('~camera_frame', '')
        # Convert explicitly: parameters may arrive as strings, and a plain
        # truth test would treat the string 'False' as True.
        self.conf = float(rospy.get_param('~conf', 0.5))
        self.visualize = self._as_bool(rospy.get_param('~visualize', True))

        # Device passed to the model at prediction time ('cpu' or 'cuda:0').
        if self._as_bool(rospy.get_param('/use_cpu', True)):
            self.device = 'cpu'
        else:
            self.device = 'cuda:0'

        self.model = YOLO(weight_path)
        self.model.fuse()

        # Kept for code that reads model.conf. The threshold that actually
        # applies is the conf= argument passed in load_and_detect().
        self.model.conf = self.conf
        self.color_image = Image()
        self.getImageStatus = False

        # Load class color
        self.classes_colors = {}

        # image subscribe
        self.color_sub = rospy.Subscriber(image_topic, Image, self.image_callback,
                                          queue_size=1, buff_size=52428800)

        # output publishers
        # NOTE(review): rospy messages published right after a Publisher is
        # created can be missed by subscribers that have not connected yet.
        # Consider latch=True or a short wait if the one-shot result is lost.
        self.position_pub = rospy.Publisher(
            pub_topic, BoundingBoxes, queue_size=1)

        self.image_pub = rospy.Publisher(
            '/yolov8/detection_image', Image, queue_size=1)

        # Load image and detect
        self.load_and_detect()

    @staticmethod
    def _as_bool(value):
        """Interpret a ROS parameter as a bool, accepting strings like 'true'/'False'."""
        if isinstance(value, str):
            return value.strip().lower() in ('1', 'true', 'yes', 'on')
        return bool(value)

    def image_callback(self, image):
        # Placeholder: incoming topic images are currently ignored.
        pass

    def load_and_detect(self):
        """Run the detector once on a fixed image file and publish the results."""
        image_path = '/home/wyh/artrc_catkin/src/artrc_yolov8/image/60.jpg'  # Replace with your image path
        self.color_image = cv2.imread(image_path)
        if self.color_image is None:
            rospy.logerr("Failed to load image from path: %s", image_path)
            return

        # cv2.imread returns BGR, which is what Ultralytics expects for numpy
        # input, so no colour conversion is done. Converting to RGB here would
        # swap the channels the model sees and break the 'bgr8' output image.
        results = self.model(self.color_image, show=False, conf=self.conf, device=self.device)

        self.dectshow(results, self.color_image.shape[0], self.color_image.shape[1])

        cv2.waitKey(3)

    def dectshow(self, results, height, width):
        """Publish boxes and an annotated image for one Ultralytics result list."""
        self.frame = results[0].plot()  # annotated copy of the input image
        inference_ms = results[0].speed['inference']
        print(str(inference_ms))
        # Guard against a 0 ms reading.
        fps = 1000.0 / inference_ms if inference_ms > 0 else 0.0
        cv2.putText(self.frame, f'FPS: {int(fps)}', (20, 50), cv2.FONT_HERSHEY_SIMPLEX, 0.6, (0, 255, 0), 2, cv2.LINE_AA)

        stamp = rospy.Time.now()
        self.boundingBoxes = BoundingBoxes()
        self.boundingBoxes.header = Header(stamp=stamp, frame_id=self.camera_frame)
        self.boundingBoxes.image_header = Header(stamp=stamp, frame_id=self.camera_frame)

        # Per-class and total detection counts.
        class_count = {}
        total_count = 0

        for result in results[0].boxes:
            # int() truncates toward zero, as np.int64 did on these floats.
            xmin, ymin, xmax, ymax = (int(v) for v in result.xyxy[0].tolist())
            boundingBox = BoundingBox()
            boundingBox.xmin = xmin
            boundingBox.ymin = ymin
            boundingBox.xmax = xmax
            boundingBox.ymax = ymax
            boundingBox.Class = results[0].names[int(result.cls.item())]
            boundingBox.probability = result.conf.item()
            self.boundingBoxes.bounding_boxes.append(boundingBox)

            class_count[boundingBox.Class] = class_count.get(boundingBox.Class, 0) + 1
            total_count += 1
            print("cl:", boundingBox.Class)

        self.position_pub.publish(self.boundingBoxes)
        self.publish_image(self.frame, height, width)
        print("data", self.boundingBoxes)
        print("Class Count:", class_count)
        print("total count:", total_count)
        # if self.visualize:
        #     cv2.imshow('YOLOv8', self.frame)

    def publish_image(self, imgdata, height, width):
        """Publish ``imgdata`` as a ``bgr8`` sensor_msgs/Image on /yolov8/detection_image.

        Assumes ``imgdata`` is a height x width x 3 uint8 array (step = width * 3).
        """
        image_temp = Image()
        header = Header(stamp=rospy.Time.now())
        header.frame_id = self.camera_frame
        image_temp.height = height
        image_temp.width = width
        image_temp.encoding = 'bgr8'
        image_temp.data = np.array(imgdata).tobytes()
        image_temp.header = header
        image_temp.step = width * 3
        self.image_pub.publish(image_temp)

def main():
    """Start the ``yolov8_ros`` node, build the detector and block until shutdown."""
    rospy.init_node('yolov8_ros', anonymous=True)
    _detector = Yolo_Dect()  # held for the lifetime of rospy.spin()
    rospy.spin()


if __name__ == "__main__":
    main()

5.3 利用.onnx文件进行检测

代码(python):
#!/home/wyh/.conda/envs/yolov8/bin/python3.8
# -*- coding: utf-8 -*-
import onnxruntime as rt
import numpy as np
import cv2
import matplotlib.pyplot as plt


# Class labels, indexed by the class id predicted by the model.
CLASS_NAMES = ['corrosion','craze', 'hide_craze','surface_attach','surface_corrosion','surface_eye',
        'surface_injure','surface_oil','thunderstrike']  # Replace with the class labels your model was trained on

# Box colours keyed "label_<class id>". These tuples are passed to OpenCV
# drawing calls, which read them as (B, G, R); each comment names the colour
# that actually appears on screen.
COLOR_MAP = {
    "label_0": (255, 0, 0),       # blue
    "label_1": (0, 255, 0),       # green
    "label_2": (0, 0, 255),       # red
    "label_3": (255, 255, 0),     # cyan
    "label_4": (255, 0, 255),     # magenta
    "label_5": (0, 255, 255),     # yellow
    "label_6": (128, 0, 128),     # purple
    "label_7": (255, 165, 0),     # azure / light blue (orange would be (0, 165, 255))
    "label_8": (128, 128, 128),   # grey
}

def nms(pred, conf_thres, iou_thres):
    """Class-wise greedy non-maximum suppression for YOLOv8 ONNX output.

    :param pred: array of shape (N, 5 + num_classes). Each row is
                 [cx, cy, w, h, conf, class scores...], where conf is a score
                 column the caller inserted at index 4.
    :param conf_thres: rows with conf <= conf_thres are discarded
    :param iou_thres: within one class, a box whose IoU with a kept box is
                      greater than iou_thres is suppressed
    :return: list of 1-D arrays [cx, cy, w, h, conf, class_id]. The list is
             grouped by ascending class id, and within a class by descending
             confidence.
    """
    pred = np.asarray(pred)
    if pred.ndim != 2 or pred.shape[0] == 0:
        return []
    box = pred[pred[:, 4] > conf_thres]
    if box.shape[0] == 0:
        return []
    cls_ids = np.argmax(box[:, 5:], axis=1)

    output_box = []
    for clss in np.unique(cls_ids):
        cls_box = box[cls_ids == clss, :6].copy()
        cls_box[:, 5] = clss
        # Greedy NMS needs the candidates sorted by confidence, best first,
        # so that each kept box is the best one remaining.
        cls_box = cls_box[np.argsort(-cls_box[:, 4], kind='stable')]
        while cls_box.shape[0] > 0:
            best = cls_box[0]
            output_box.append(best)
            rest = cls_box[1:]
            cls_box = rest[_iou_xywh(best, rest) <= iou_thres]
    return output_box


def _iou_xywh(box, others):
    """IoU of one [cx, cy, w, h, ...] box with each row of ``others`` (0 when the union is empty)."""
    bx1, by1 = box[0] - box[2] / 2, box[1] - box[3] / 2
    bx2, by2 = box[0] + box[2] / 2, box[1] + box[3] / 2
    ox1 = others[:, 0] - others[:, 2] / 2
    oy1 = others[:, 1] - others[:, 3] / 2
    ox2 = others[:, 0] + others[:, 2] / 2
    oy2 = others[:, 1] + others[:, 3] / 2
    inter_w = np.clip(np.minimum(bx2, ox2) - np.maximum(bx1, ox1), 0, None)
    inter_h = np.clip(np.minimum(by2, oy2) - np.maximum(by1, oy1), 0, None)
    inter = inter_w * inter_h
    union = box[2] * box[3] + others[:, 2] * others[:, 3] - inter
    return np.where(union > 0, inter / np.where(union > 0, union, 1), 0.0)
 
def getIou(box1, box2, inter_area):
    """Return the IoU of two [cx, cy, w, h, ...] boxes given their intersection area.

    :param box1: first box; box1[2] and box1[3] are its width and height
    :param box2: second box, same layout
    :param inter_area: intersection area of the two boxes
    :return: inter_area / union. Returns 0.0 when the union is not positive
             (e.g. two zero-area boxes) instead of dividing by zero.
    """
    box1_area = box1[2] * box1[3]
    box2_area = box2[2] * box2[3]
    union = box1_area + box2_area - inter_area
    if union <= 0:
        return 0.0
    return inter_area / union
 
def getInter(box1, box2):
    """Return the intersection area of two [cx, cy, w, h, ...] boxes (0 if disjoint)."""
    box1_x1, box1_y1 = box1[0] - box1[2] / 2, box1[1] - box1[3] / 2
    box1_x2, box1_y2 = box1[0] + box1[2] / 2, box1[1] + box1[3] / 2
    # Each box's corners come from its own width/height (box2 uses box2[3]).
    box2_x1, box2_y1 = box2[0] - box2[2] / 2, box2[1] - box2[3] / 2
    box2_x2, box2_y2 = box2[0] + box2[2] / 2, box2[1] + box2[3] / 2

    x_inter = min(box1_x2, box2_x2) - max(box1_x1, box2_x1)
    y_inter = min(box1_y2, box2_y2) - max(box1_y1, box2_y1)
    if x_inter <= 0 or y_inter <= 0:
        return 0
    return x_inter * y_inter
 
# Draw boxes and labels
def draw(img, xscale, yscale, pred):
    """Draw nms() detections and their class names onto a copy of ``img``.

    :param img: image to annotate (not modified)
    :param xscale: factor mapping network-input x coordinates back to ``img``
    :param yscale: factor mapping network-input y coordinates back to ``img``
    :param pred: iterable of [cx, cy, w, h, conf, class_id] rows
    :return: the annotated copy
    """
    img_ = img.copy()
    for detect in pred:
        label = int(detect[5])
        # Fall back to the numeric id for a class without a name.
        label_name = CLASS_NAMES[label] if 0 <= label < len(CLASS_NAMES) else str(label)
        x1 = int((detect[0] - detect[2] / 2) * xscale)
        y1 = int((detect[1] - detect[3] / 2) * yscale)
        x2 = int((detect[0] + detect[2] / 2) * xscale)
        y2 = int((detect[1] + detect[3] / 2) * yscale)

        # COLOR_MAP may be keyed by class name or by "label_<id>"; use white if neither matches.
        color = COLOR_MAP.get(label_name, COLOR_MAP.get("label_%d" % label, (255, 255, 255)))

        # Bounding box.
        img_ = cv2.rectangle(img_, (x1, y1), (x2, y2), color, 2)

        # Caption above the box, kept inside the image near the top edge.
        img_ = cv2.putText(img_, label_name, (x1, max(y1 - 5, 10)),
                           cv2.FONT_HERSHEY_SIMPLEX, 0.5, color, 1)

    return img_
 
if __name__ == '__main__':
    height, width = 640, 640
    img0 = cv2.imread('/home/wyh/artrc_catkin/src/artrc_yolov8/image/60.jpg')
    if img0 is None:
        raise SystemExit('Failed to read the input image')
    x_scale = img0.shape[1] / width
    y_scale = img0.shape[0] / height

    # Ultralytics YOLOv8 models are fed RGB images scaled to [0, 1].
    # cv2.imread returns BGR, so convert before normalising.
    img = cv2.cvtColor(img0, cv2.COLOR_BGR2RGB)
    # Plain resize (no letterbox); x_scale / y_scale map the boxes back to img0.
    img = cv2.resize(img, (width, height)).astype(np.float32) / 255.0
    img = np.transpose(img, (2, 0, 1))      # HWC -> CHW
    data = np.expand_dims(img, axis=0)      # add batch dimension -> (1, 3, H, W)

    sess = rt.InferenceSession('/home/wyh/artrc_catkin/src/artrc_yolov8/weights/best.onnx')
    input_name = sess.get_inputs()[0].name
    label_name = sess.get_outputs()[0].name
    pred = sess.run([label_name], {input_name: data})[0]
    # Assumes output shape (1, 4 + num_classes, num_anchors) -- TODO confirm for your export.
    pred = np.squeeze(pred)
    pred = np.transpose(pred, (1, 0))       # -> (num_anchors, 4 + num_classes)

    # Insert the best class score as column 4 so each row matches nms()'s
    # [cx, cy, w, h, conf, class scores...] layout.
    pred_class = pred[..., 4:]
    pred_conf = np.max(pred_class, axis=-1)
    pred = np.insert(pred, 4, pred_conf, axis=-1)
    result = nms(pred, 0.3, 0.45)
    ret_img = draw(img0, x_scale, y_scale, result)

    # Show the result with OpenCV.
    cv2.imshow('Detection Result', ret_img)
    cv2.waitKey(0)            # wait for a key press
    cv2.destroyAllWindows()   # close all OpenCV windows

5.4 利用 .trt 文件进行检测

代码(C++):
#include <ros/ros.h>
#include <image_transport/image_transport.h>
#include <cv_bridge/cv_bridge.h>
#include <sensor_msgs/image_encodings.h>
#include <std_msgs/Header.h>
#include <opencv2/opencv.hpp>
#include "../include/artrc_yolov8/yolo.h"
#include "../include/artrc_yolov8/yolov8.h"
#include <NvInfer.h>
#include <NvUtils.h>
#include <opencv2/opencv.hpp>
#include <opencv2/core.hpp>
#include <opencv2/highgui.hpp>
#include <opencv2/imgproc.hpp>
#include <fstream>
#include "../include/artrc_yolov8/yolov8_trt.h"
#include <mutex>
// Global frame buffer.
// NOTE(review): the code shown here reads YoloResultData::image_ (main() uses
// YoloResultData_node.image_), and that member hides this global inside the
// member functions below. Check whether anything else (e.g. an extern in
// yolov8_trt.h) uses this global before removing it.
cv::Mat image_;
namespace artrc_yolov8
{
	YoloResultData::~YoloResultData(){
		;
	}

	// Subscribe to the camera topic and advertise the detection outputs.
	void YoloResultData::init(){
		ros::NodeHandle nh;
		img_receive_sub_ = nh.subscribe("/usb_camera/image_raw",1,&YoloResultData::image_receive_callback,this);
		img_detect_pub_ = nh.advertise<sensor_msgs::Image>("/detect_img", 1);
		boundingbox_result_pub_ = nh.advertise<artrc_yolov8::boundingbox_result_msgs>("/boundingbox_result",1);
	}

	// Example hook: copies the most recent frame stored by image_receive_callback().
	void YoloResultData::processImage()
	{
		if (!image_.empty()) {
			cv::Mat processedImage = image_.clone(); // process this copy as needed
		} else {
			ROS_WARN("No image received yet.");
		}
	}

	// Default detector settings: COCO-80 classes, 640x640 input, batch 8.
	// Replace class_names / num_class when using a custom-trained engine.
	void YoloResultData::setParameters(utils::InitParameter& initParameters)
	{
		initParameters.class_names = utils::dataSets::coco80;
		initParameters.num_class = 80; // for coco
		initParameters.batch_size = 8;
		initParameters.dst_h = 640;
		initParameters.dst_w = 640;
		initParameters.input_output_names = { "images",  "output0" };
		initParameters.conf_thresh = 0.25f;
		initParameters.iou_thresh = 0.45f;
		initParameters.save_path = "";
	}

	// Run one batch through the engine (copy -> preprocess -> infer -> postprocess).
	// Then publish one boundingbox_result message per detection and every batch
	// image on /detect_img, and reset the detector for the next batch.
	// NOTE(review): whether the published images show boxes depends on
	// utils::show() drawing into imgsBatch (only called when isShow) -- confirm.
	void YoloResultData::task(YOLOV8& yolo, const utils::InitParameter& param, std::vector<cv::Mat>& imgsBatch, const int& delayTime, const int& batchi,
		const bool& isShow, const bool& isSave)
	{
		if (imgsBatch.empty()) {
			std::cerr << "Input image batch is empty." << std::endl;
			return;
		}
		// Per-stage timings measured with utils::DeviceTimer.
		utils::DeviceTimer d_t0; yolo.copy(imgsBatch);	      float t0 = d_t0.getUsedTime();
		utils::DeviceTimer d_t1; yolo.preprocess(imgsBatch);  float t1 = d_t1.getUsedTime();
		utils::DeviceTimer d_t2; yolo.infer();				  float t2 = d_t2.getUsedTime();
		utils::DeviceTimer d_t3; yolo.postprocess(imgsBatch); float t3 = d_t3.getUsedTime();

		if(isShow)
			utils::show(yolo.getObjectss(), param.class_names, delayTime, imgsBatch);
		// Saving is disabled. To enable it when isSave is set, call:
		// utils::save(yolo.getObjectss(), param.class_names, param.save_path, imgsBatch, param.batch_size, batchi);

		// Publish the detection results.
		YoloResultData::result_show(yolo, param, t1, t2, t3);

		// Give every published image a timestamp; all images of a batch share it.
		const ros::Time stamp = ros::Time::now();
		for (size_t bi = 0; bi < imgsBatch.size(); bi++)
		{
			cv_bridge::CvImage out_img;
			out_img.header.stamp = stamp;
			out_img.encoding = sensor_msgs::image_encodings::BGR8;
			out_img.image = imgsBatch[bi];
			img_detect_pub_.publish(out_img.toImageMsg());
		}
		yolo.reset();
	}

	// Publish one boundingbox_result message per detected box.
	// t1/t2/t3 (stage timings) are currently unused.
	void YoloResultData::result_show(const YOLOV8& yolo, const utils::InitParameter& param, float t1, float t2, float t3)
	{
		const auto& objectss = yolo.getObjectss();
		for (size_t bi = 0; bi < objectss.size(); bi++)
		{
			for (const auto& box : objectss[bi])
			{
				// Use the numeric id if the engine reports a label outside class_names.
				const bool known_label = box.label >= 0 &&
					static_cast<size_t>(box.label) < param.class_names.size();
				pub_msg_.label = known_label ? param.class_names[box.label] : std::to_string(box.label);
				pub_msg_.confidence = box.confidence;
				pub_msg_.xmin = box.left;
				pub_msg_.xmax = box.right;
				pub_msg_.ymin = box.top;
				pub_msg_.ymax = box.bottom;
				// Fill the [left, top, right, bottom] array, dropping the previous box's values.
				pub_msg_.bounding_box.clear();
				pub_msg_.bounding_box.push_back(box.left);
				pub_msg_.bounding_box.push_back(box.top);
				pub_msg_.bounding_box.push_back(box.right);
				pub_msg_.bounding_box.push_back(box.bottom);
				boundingbox_result_pub_.publish(pub_msg_);
			}
		}
	}

	// Store a BGR8 copy of the latest camera frame in image_ (read by main()).
	void YoloResultData::image_receive_callback(const sensor_msgs::Image& image_msg){
		cv_bridge::CvImagePtr cv_ptr;
		try {
			cv_ptr = cv_bridge::toCvCopy(image_msg, sensor_msgs::image_encodings::BGR8);
			image_ = cv_ptr->image;
		} catch (cv_bridge::Exception& e) {
			ROS_ERROR("cv_bridge exception: %s", e.what());
			return;
		}
	}
}

int main(int argc, char** argv)
{
    ros::init(argc, argv, "yolov8_ros_node");
	artrc_yolov8::YoloResultData YoloResultData_node;
	YoloResultData_node.init();

	utils::InitParameter param;
	YoloResultData_node.setParameters(param);

	std::string model_path = "/home/wyh/artrc_catkin/src/artrc_yolov8/weights/yolov8n.trt";//加载模型
	std::string video_path = "/home/wyh/artrc_catkin/src/artrc_yolov8/image/行人视频.mp4";
	std::string image_path = "/home/wyh/artrc_catkin/src/artrc_yolov8/image/6406406.jpg";

	int camera_id = 0;
	//  get input 输入源 判断
	utils::InputStream source;
	source = utils::InputStream::IMAGE;
	// source = utils::InputStream::VIDEO;
	// source = utils::InputStream::CAMERA;
	// source = utils::InputStream::TOPIC_IMAGE;

	// update params from command line parser
	int size = -1; // w or h
	int batch_size = 8;
	bool is_show = false;
	bool is_save = false;
	int total_batches = 0;
	int delay_time = 50;

	// / 从参数服务器获取参数
	ros::param::get("~size", size);
    ros::param::get("~batch_size", batch_size);
    ros::param::get("~show", is_show);
	// 参数赋值
	param.dst_h = param.dst_w = size;
    param.batch_size = batch_size;
    param.is_show = is_show;
	// cv::VideoCapture capture(1);
	cv::VideoCapture capture(1);

	if (!setInputStream(source, image_path, video_path, camera_id,
		capture, total_batches, delay_time, param))
	{
		sample::gLogError << "read the input data errors!" << std::endl;
		return -1;
		
	}
	std::vector<unsigned char> trt_file = utils::loadModel(model_path);
	
	// // read model
	if (trt_file.empty()){	
		std::cout << "trt_file is empty!" << std::endl;
	}
	else{
		std::cout << "trt_file is load!" << std::endl;}
	
	YOLOV8 yolo(param);
	
	// // init model
	if (!yolo.init(trt_file)){
		std::cout << "initEngine() ocur errors!" << std::endl;
	}
	else{
		std::cout << "initEngine() ocur success!" << std::endl;	
	}
	yolo.check();

	std::vector<cv::Mat> imgs_batch;
	imgs_batch.reserve(param.batch_size);
	int batchi = 0;
	cv::Mat frame;
	ros::Rate rate(50);

	while (ros::ok())
	{
		// std::cout << "imgs_batch_" << imgs_batch.size() << ";"<< "batch_size" << param.batch_size << std::endl;
		if (imgs_batch.size() < param.batch_size) // get input
		{	
			if (source == utils::InputStream::VIDEO)
            {
                capture.read(frame);
				// std::cout<<"00000_video"<< std::endl;
            }
			else if (source == utils::InputStream::CAMERA)
            {
                capture.read(frame);
				// std::cout<<"11111_camera"<< std::endl;
            }
            else if (source == utils::InputStream::IMAGE)
            {
				// std::cout<<"22222_image"<< std::endl;
                // frame = cv::imread(image_path);// 获取图像数据
				frame = YoloResultData_node.image_;	
            }
            else 
            {
				// std::cout<<"33333_topic"<<std::endl;
				frame = YoloResultData_node.image_;	
            }
			if (!frame.empty())
			{
                imgs_batch.emplace_back(frame.clone());
            }
            else
            {
				int delay_time = 5;
                sample::gLogWarning << "no more video or camera frame" << std::endl;
                YoloResultData_node.task(yolo, param, imgs_batch, delay_time, batchi, is_show, is_save);
                imgs_batch.clear();
                batchi++;
            }
        }

		else
		{
			int delay_time = 1;
			YoloResultData_node.task(yolo, param, imgs_batch, delay_time, batchi, is_show, is_save);
			imgs_batch.clear();
			batchi++;
		}
		ros::spinOnce();  // Handle all callbacks
		rate.sleep();     // Sleep for a while before next loop iteration
	}
	// ros::spin();
    return 0;
}
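修改说明(针对上面的 C++ 程序):
1. 在 setParameters() 中将下面两行改为自己的类别。例如本文的 9 类缺陷数据集,需要把 class_names 改为对应的类别名列表,把 num_class 改为 9:
initParameters.class_names = utils::dataSets::coco80;
initParameters.num_class = 80; 
2. 将 main() 中的 model_path 替换为自己的 .trt 权重文件路径。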
相关推荐
辛勤的程序猿22 分钟前
YOLO即插即用---PConv
深度学习·yolo·计算机视觉
富士达幸运星7 小时前
YOLOv4的网络架构解析
人工智能·yolo·目标跟踪
张小生18011 小时前
《YOLO 目标检测》—— YOLO v4 详细介绍
人工智能·python·yolo
一勺汤18 小时前
YOLOv8模型改进 第十七讲 通道压缩的自注意力机制CRA
yolo·目标检测·outlook·模块·yolov8·yolov8改进·魔改
方方爱学习20 小时前
解决一键重命名所有文件问题
人工智能·深度学习·yolo·机器学习
Limiiiing1 天前
YOLOv10改进策略【卷积层】| 利用MobileNetv4中的UIB、ExtraDW优化C2fCIB
深度学习·yolo·目标检测·计算机视觉
阿_旭2 天前
YOLOv11模型架构以及使用命令介绍
人工智能·深度学习·yolo·ai·yolo11
极智视界2 天前
无人机场景 - 目标检测数据集 - 夜间车辆检测数据集下载「包含VOC、COCO、YOLO三种格式」
yolo·目标检测·车辆检测·voc·夜间车辆检测·算法训练·无人机场景数据集
小张贼嚣张2 天前
yolov8涨点系列之损失函数替换
yolo·目标检测
阿_旭2 天前
基于YOLO11/v10/v8/v5深度学习的危险驾驶行为检测识别系统设计与实现【python源码+Pyqt5界面+数据集+训练代码】
人工智能·深度学习·yolo·目标检测·ai