import os
import argparse
import glob
import cv2
import numpy as np
import onnxruntime
import tqdm
import pymysql
import time
import json
from datetime import datetime
# Expose only GPU 0 to CUDA-based libraries loaded by this process; this must be
# set before any CUDA context is created for it to take effect.
os.environ["CUDA_VISIBLE_DEVICES"] = "0" # use GPU 0
def get_connection(*, host='localhost', user='root', password='123456',
                   database='video_streaming_database'):
    """Create and return a new connection to the source database.

    The keyword-only defaults reproduce the previously hard-coded local
    settings, so existing ``get_connection()`` calls behave exactly as before;
    any setting can be overridden by keyword (e.g. to keep credentials out of
    the source).

    Args:
        host: MySQL server host.
        user: MySQL user name.
        password: MySQL password.
        database: Schema to select; defaults to ``'video_streaming_database'``.

    Returns:
        A new ``pymysql`` connection. The caller is responsible for closing it.
    """
    # Database connection settings come from the parameters above.
    return pymysql.connect(host=host, user=user, password=password, database=database)
def get_connection_results(*, host='localhost', user='root', password='123456',
                           database='results'):
    """Create and return a new connection to the ``results`` database.

    The keyword-only defaults reproduce the previously hard-coded local
    settings, so existing ``get_connection_results()`` calls behave exactly as
    before; any setting can be overridden by keyword.

    Args:
        host: MySQL server host.
        user: MySQL user name.
        password: MySQL password.
        database: Schema to select; defaults to ``'results'``.

    Returns:
        A new ``pymysql`` connection. The caller is responsible for closing it.
    """
    # Database connection settings come from the parameters above.
    return pymysql.connect(host=host, user=user, password=password, database=database)
def ensure_connection(connection):
    """Return a usable connection to the source database.

    ``connection.open`` only says the client still holds a socket; a
    connection the server has already dropped still reports ``open``. So an
    open connection is pinged with ``reconnect=True`` (which re-establishes it
    in place if needed) before being reused.

    Args:
        connection: An existing ``pymysql`` connection, or ``None``.

    Returns:
        *connection* itself if it is (or could be made) usable, otherwise a
        brand-new connection from ``get_connection()``.
    """
    if connection is not None and connection.open:
        try:
            connection.ping(reconnect=True)
            return connection
        except pymysql.MySQLError as exc:
            # Ping and in-place reconnect both failed; build a fresh connection.
            print(f"Connection ping failed ({exc}). Reconnecting...")
    else:
        print("Connection is invalid or closed. Reconnecting...")
    return get_connection()
def ensure_connection_results(connection):
    """Return a usable connection to the ``results`` database.

    ``connection.open`` only says the client still holds a socket; a
    connection the server has already dropped (e.g. after sitting idle between
    batch inserts) still reports ``open``. So an open connection is pinged with
    ``reconnect=True`` (which re-establishes it in place if needed) before
    being reused.

    Args:
        connection: An existing ``pymysql`` connection, or ``None``.

    Returns:
        *connection* itself if it is (or could be made) usable, otherwise a
        brand-new connection from ``get_connection_results()``.
    """
    if connection is not None and connection.open:
        try:
            connection.ping(reconnect=True)
            return connection
        except pymysql.MySQLError as exc:
            # Ping and in-place reconnect both failed; build a fresh connection.
            print(f"Connection ping failed ({exc}). Reconnecting...")
    else:
        print("Connection is invalid or closed. Reconnecting...")
    return get_connection_results()
def get_parser():
    """Build the command-line parser for the ONNX inference script.

    Returns:
        An ``argparse.ArgumentParser`` defining ``--model-path``, ``--input``
        (always parsed as a list), ``--output``, ``--height`` and ``--width``.
    """
    parser = argparse.ArgumentParser(description="onnx model inference")
    parser.add_argument(
        "--model-path",
        default="/home/hitsz/yk_workspace/Yolov5_track/weights/sbs_r50_0206_export_params_True.onnx",
        help="onnx model path",
    )
    parser.add_argument(
        "--input",
        # nargs="+" yields a list whenever the flag is given, so the default must
        # be a list as well; a bare string default would make args.input a str
        # that len()/iteration treat character by character.
        # NOTE(review): "*jpg" (no dot) matches any name ending in "jpg".
        default=[
            "/home/hitsz/yk_workspace/Yolov5_track/test_4S_videos/test_yk1_det3/save_crops/test_yk1/person/1/*jpg"
        ],
        nargs="+",
        help="A list of space separated input images; "
        "or a single glob pattern such as 'directory/*.jpg'",
    )
    parser.add_argument(
        "--output",
        default='/home/hitsz/yk_workspace/Yolov5_track/02_output_det/onnx_output',
        help='path to save the output features',
    )
    parser.add_argument(
        "--height",
        type=int,
        default=384,
        help="height of image",
    )
    parser.add_argument(
        "--width",
        type=int,
        default=128,
        help="width of image",
    )
    return parser
def preprocess(image_path, image_height, image_width):
    """Load an image and turn it into a normalized NCHW float32 batch of one.

    Args:
        image_path: Path of the image file to read with ``cv2.imread``.
        image_height: Target height in pixels.
        image_width: Target width in pixels.

    Returns:
        ``np.ndarray`` of dtype float32 and shape
        ``(1, 3, image_height, image_width)``, channels in RGB order,
        normalized with the ImageNet mean/std below.

    Raises:
        ValueError: if OpenCV cannot read *image_path* (``cv2.imread``
            returns ``None`` for missing or undecodable files).
    """
    bgr_image = cv2.imread(image_path)
    if bgr_image is None:
        raise ValueError(f"could not read image: {image_path}")
    # cv2 decodes to BGR. The mean/std below are in RGB order, so flip to RGB
    # *before* normalizing; normalizing first applied the R statistics to the
    # B channel and vice versa.
    rgb_image = bgr_image[:, :, ::-1]
    norm_mean = np.array([0.485, 0.456, 0.406])
    norm_std = np.array([0.229, 0.224, 0.225])
    # NOTE(review): confirm the exported ONNX model does not already normalize
    # its input internally, otherwise this normalization is applied twice.
    normalized_img = (rgb_image / 255.0 - norm_mean) / norm_std
    img = cv2.resize(normalized_img, (image_width, image_height), interpolation=cv2.INTER_CUBIC)
    img = img.astype("float32").transpose(2, 0, 1)[np.newaxis]  # (1, 3, h, w)
    return img
def normalize(nparray, order=2, axis=-1):
    """Scale *nparray* to unit length along *axis*.

    Args:
        nparray: Input array of any rank.
        order: Norm order passed to ``np.linalg.norm`` (2 = Euclidean).
        axis: Axis along which each vector is normalized.

    Returns:
        ``nparray`` divided by its per-vector norm. A float32 machine epsilon is
        added to the denominator so all-zero vectors stay zero instead of
        producing NaN.
    """
    eps = np.finfo(np.float32).eps
    magnitude = np.linalg.norm(nparray, ord=order, axis=axis, keepdims=True)
    denominator = magnitude + eps
    return nparray / denominator
# A pass only processes frames up to (newest crop frame - this lag); presumably
# this gives the tracker time to finish writing those frames' JSON — confirm.
_TRACK_LAG_FRAMES = 1440
# How many frames back from a *_track.json to search for its *_track_name.json.
_NAME_LOOKBACK_FRAMES = 48
# Position of the action string inside each per-id tracking record
# (assumes each record is a list with at least 7 entries — TODO confirm schema).
_ACTION_FIELD_INDEX = 6
_INSERT_SQL = """
INSERT INTO time_results (camera_ip, frame_number, tracking_id, matched_id, action_recognized, event_datetime)
VALUES (%s, %s, %s, %s, %s, %s);
"""


def _fetch_latest_crop_path(connection):
    """Return ``crop_image_path`` of row ``max(1, MAX(id) - 1)``, or None if unavailable."""
    with connection.cursor() as cursor:
        cursor.execute("SELECT MAX(id) FROM new_detection_tracking_results_1")
        max_id = cursor.fetchone()[0]
        print(max_id)
        if max_id is None:
            return None
        # Read the second-newest row rather than the newest one.
        end_id = max(1, max_id - 1)
        cursor.execute(
            "SELECT crop_image_path FROM new_detection_tracking_results_1 WHERE id = %s",
            (end_id,),
        )
        row = cursor.fetchone()
    # The row may be missing (id gap) or its path NULL; callers skip this pass then.
    return row[0] if row else None


def _parse_crop_path(crop_path):
    """Return (grandparent dir, camera ip, frame index) parsed from a crop image path.

    The file name is assumed to look like ``<cam_ip>_<frame>_...`` — TODO confirm.
    """
    dir_path = os.path.dirname(os.path.dirname(crop_path))
    file_fields = os.path.basename(crop_path).split("_")
    return dir_path, file_fields[0], int(file_fields[1])


def _find_track_name_json(dir_path, cam_ip, frame_idx):
    """Return the nearest existing *_track_name.json at or before *frame_idx*, else None."""
    for offset in range(_NAME_LOOKBACK_FRAMES):
        candidate = os.path.join(
            dir_path, cam_ip + "_" + str(frame_idx - offset).zfill(8) + "_track_name.json"
        )
        if os.path.exists(candidate):
            return candidate
    return None


def _rows_for_frame(tracking_data, id_name, cam_ip, frame_number, event_time):
    """Build ``time_results`` rows for every tracked id that has a matched name."""
    rows = []
    for key, record in tracking_data.items():
        # Show at most the first two "||"-separated action labels.
        action_show = " ".join(record[_ACTION_FIELD_INDEX].split("||")[:2])
        padded_key = key.zfill(4)
        if padded_key not in id_name:
            continue
        # Label is "<first '_' field>: 0.<first two chars of last '_' field>";
        # the last field presumably holds a match score — confirm.
        match_fields = id_name[padded_key].split("_")
        name = match_fields[0] + ": 0." + match_fields[-1][:2]
        rows.append((cam_ip, int(frame_number), int(key), name, action_show, event_time))
    return rows


def _flush_rows(connection_results, rows):
    """Insert *rows* into ``time_results``, commit, and return the connection used."""
    connection_results = ensure_connection_results(connection_results)
    with connection_results.cursor() as cursor:
        cursor.executemany(_INSERT_SQL, rows)
    connection_results.commit()
    return connection_results


data2 = []

if __name__ == "__main__":
    args = get_parser().parse_args()
    # Rows are inserted once at least this many have been buffered.
    batch_size = 500
    # First frame index the scan starts from.
    pre_end_frame_idx = 10000

    connection = get_connection()
    connection_results = get_connection_results()
    try:
        while True:
            connection = ensure_connection(connection)
            try:
                crop_path = _fetch_latest_crop_path(connection)
                # End the read transaction before closing.
                connection.commit()
            finally:
                connection.close()

            if crop_path is not None:
                dir_path, cam_ip, newest_frame_idx = _parse_crop_path(crop_path)
                end_frame_idx = newest_frame_idx - _TRACK_LAG_FRAMES
                for i in range(pre_end_frame_idx, end_frame_idx):
                    json_path = os.path.join(dir_path, cam_ip + "_" + str(i).zfill(8) + "_track.json")
                    # A "<...>_track" directory next to the JSON marks the frame as done.
                    marker_dir = json_path[:-5]
                    if not os.path.exists(json_path) or os.path.exists(marker_dir):
                        continue
                    json_name_path = _find_track_name_json(dir_path, cam_ip, i)
                    if json_name_path is None:
                        continue
                    # getctime is the inode change time on Unix, not a true creation time.
                    creation_time = os.path.getctime(json_path)
                    formatted_creation_time = datetime.fromtimestamp(creation_time).strftime('%Y-%m-%d %H:%M:%S')
                    with open(json_name_path, 'r', encoding='utf-8') as f1:
                        id_name = json.load(f1)
                    with open(json_path, 'r', encoding='utf-8') as f:
                        tracking_data = json.load(f)
                    # NOTE(review): every row of this pass is stored with
                    # frame_number=end_frame_idx, not the frame i it came from —
                    # confirm this is intended.
                    data2.extend(_rows_for_frame(
                        tracking_data, id_name, cam_ip, end_frame_idx, formatted_creation_time
                    ))
                    os.makedirs(marker_dir, exist_ok=True)
                    print('---------len(data2) is:', len(data2))
                    if len(data2) >= batch_size:
                        connection_results = _flush_rows(connection_results, data2)
                        data2 = []
                pre_end_frame_idx = end_frame_idx
            time.sleep(5)
    except KeyboardInterrupt:
        # Frames are marked done as soon as their rows are buffered, so rows
        # still buffered here would never be re-read; write them out first.
        if data2:
            connection_results = _flush_rows(connection_results, data2)
            data2 = []
        raise