LocateAnything 视觉-语言定位推理 多GPU并行大批量图片目标检测的实现

LocateAnything 视觉-语言定位推理 多GPU并行大批量图片目标检测的实现

flyfish

Fast and High-Quality Vision-Language Grounding with Parallel Box Decoding

参考地址

bash 复制代码
https://research.nvidia.com/labs/lpr/locate-anything/

已实现的功能

  1. 批量目标检测
    自动读取指定文件夹内的所有图片,使用 LocateAnything 模型 + 自定义提示词,批量检测图片中是否存在指定目标(如人、车、物体等)。
  2. 有效图片自动过滤
    自动剔除尺寸不符合要求、图片过小、损坏无法加载的无效图片,只保留符合模型输入要求的图片进行推理。
  3. 检测结果自动归档
    检测到目标的图片 自动复制保存到输出目录,且自动分文件夹存储(避免单个文件夹文件过多),未检测到目标的图片直接跳过。

考虑到的

  1. 多GPU并行加速
    支持指定多张GPU(如 0,1,2),自动将图片均匀分配到每个GPU,实现多卡并发推理,最大化利用硬件性能,处理海量图片。
  2. 缓存加速,避免重复工作
    自动缓存图片扫描结果,重复运行程序时无需重新扫描图片,直接读取缓存,大幅提升启动速度。
  3. 批量推理优化
    支持批次(Batch)处理图片,而非逐张推理,显著提升推理吞吐量。
  4. 图片预处理降显存
    自动缩放超大图片,适配模型输入,降低GPU显存占用。
    执行命令
bash 复制代码
 python main.py --config ./config.yaml --gpu 0 1 --scan-workers 8

代码如下

main.py

python 复制代码
import torch
import os
import time
import logging
import argparse
import yaml
import json
import hashlib
import sys
import io
from dataclasses import dataclass
from typing import List, Tuple, Optional
from PIL import Image
from multiprocessing import Queue
from queue import Empty
from concurrent.futures import ProcessPoolExecutor
import shutil
from tqdm import tqdm

# 模型工作类
from locateanything_worker import LocateAnythingWorker

# ====================== 日志配置模块 ======================
def setup_logging(log_file: str, log_level: str):
    """
    配置全局日志系统
    Args:
        log_file: 日志保存文件路径
        log_level: 日志级别 (INFO/DEBUG/WARN/ERROR)
    Returns:
        日志实例
    """
    # 将字符串日志级别转为对应数值
    numeric_level = getattr(logging, log_level.upper(), None)
    if not isinstance(numeric_level, int):
        numeric_level = logging.INFO

    # 配置日志:同时输出到控制台和文件,UTF-8编码防止中文乱码
    logging.basicConfig(
        level=numeric_level,
        format="%(asctime)s [%(levelname)s] %(message)s",
        handlers=[
            logging.StreamHandler(),       # 控制台输出
            logging.FileHandler(log_file, encoding="utf-8")  # 文件输出
        ]
    )
    return logging.getLogger(__name__)


# ====================== 配置数据类(类型安全) ======================
@dataclass
class InferenceConfig:
    """推理配置类:封装所有推理参数,避免字典硬编码"""
    model_path: str                  # 模型文件路径
    input_image_dir: str             # 输入图片目录
    output_dir: str                  # 输出保存目录
    allowed_image_sizes: List[Tuple[int, int]]  # 允许的图片尺寸列表
    gpu_list: List[int]              # 使用的GPU编号列表
    max_images_per_folder: int       # 每个文件夹最大保存图片数(自动分文件夹)
    detection_prompt: str            # 检测提示词
    prompt_file: str                 # 提示词文件路径
    load_prompt_from_file: bool      # 是否从文件加载提示词
    batch_size: int                  # 推理批次大小
    min_image_size: Tuple[int, int]  # 图片最小尺寸限制
    max_side: int                    # 图片最大边长(缩放用)
    log_file: str                    # 日志文件路径
    log_level: str                   # 日志级别
    scan_cache_file: str = "./scan_cache.json"  # 图片扫描缓存文件
    scan_workers: int = 4            # 图片扫描并行进程数
    show_progress: bool = True       # 是否显示GPU进度条
    progress_bar_position: bool = True  # 多GPU进度条是否独立显示


def load_config(config_path: str) -> InferenceConfig:
    """
    从YAML配置文件加载参数
    Args:
        config_path: 配置文件路径
    Returns:
        配置类实例
    """
    with open(config_path, "r", encoding="utf-8") as f:
        cfg = yaml.safe_load(f)

    # 解析YAML参数,转为配置类
    return InferenceConfig(
        model_path=cfg["model_path"],
        input_image_dir=cfg["input_image_dir"],
        output_dir=cfg["output_dir"],
        allowed_image_sizes=[tuple(s) for s in cfg["allowed_image_sizes"]],
        gpu_list=cfg["gpu_list"],
        max_images_per_folder=cfg["max_images_per_folder"],
        detection_prompt=cfg["detection_prompt"],
        prompt_file=cfg["prompt_file"],
        load_prompt_from_file=cfg["load_prompt_from_file"],
        batch_size=cfg["batch_size"],
        min_image_size=tuple(cfg["min_image_size"]),
        max_side=cfg["max_side"],
        log_file=cfg["log_file"],
        log_level=cfg["log_level"],
        scan_cache_file=cfg.get("scan_cache_file", "./scan_cache.json"),
        scan_workers=cfg.get("scan_workers", 4),
        show_progress=cfg.get("show_progress", True),
        progress_bar_position=cfg.get("progress_bar_position", True),
    )


# ====================== 图片扫描缓存模块(避免重复扫描) ======================
def get_dir_signature(input_dir: str, allowed_sizes: List[Tuple[int, int]]) -> str:
    """
    生成目录唯一签名,用于校验缓存是否有效
    签名包含:目录路径、允许尺寸、修改时间、文件数量
    Returns:
        MD5签名字符串
    """
    dir_stat = os.stat(input_dir)
    # 统计目录下图片数量
    file_count = len([f for f in os.listdir(input_dir) if f.lower().endswith((".jpg", ".jpeg", ".png"))])

    # 组装签名数据
    signature_data = {
        "dir": input_dir,
        "allowed_sizes": sorted(allowed_sizes),
        "mtime": dir_stat.st_mtime,
        "file_count": file_count,
    }
    # 序列化为字符串并生成MD5
    signature_str = json.dumps(signature_data, sort_keys=True)
    return hashlib.md5(signature_str.encode()).hexdigest()


def save_scan_cache(cache_file: str, valid_images: List[str], input_dir: str,
                    allowed_sizes: List[Tuple[int, int]], stats: dict):
    """保存图片扫描结果到缓存文件"""
    cache_data = {
        "signature": get_dir_signature(input_dir, allowed_sizes),
        "valid_images": valid_images,    # 合法图片路径列表
        "stats": stats,                  # 扫描统计
        "timestamp": time.strftime("%Y-%m-%d %H:%M:%S"),
    }
    with open(cache_file, "w", encoding="utf-8") as f:
        json.dump(cache_data, f, ensure_ascii=False, indent=2)


def load_scan_cache(cache_file: str, input_dir: str,
                    allowed_sizes: List[Tuple[int, int]]) -> Optional[dict]:
    """
    加载扫描缓存
    Returns:
        有效缓存数据 / None(缓存失效/不存在)
    """
    if not os.path.exists(cache_file):
        return None

    try:
        with open(cache_file, "r", encoding="utf-8") as f:
            cache_data = json.load(f)

        # 校验目录签名,判断缓存是否有效
        current_signature = get_dir_signature(input_dir, allowed_sizes)
        if cache_data.get("signature") != current_signature:
            return None

        return {
            "valid_images": cache_data.get("valid_images", []),
            "stats": cache_data.get("stats", {}),
        }
    except Exception:
        # 缓存文件损坏,返回None
        return None


# ====================== 命令行参数解析 ======================
def parse_args():
    """解析命令行参数,支持覆盖配置文件参数"""
    parser = argparse.ArgumentParser(description="LocateAnything 批量推理 pipeline")
    parser.add_argument("--config", type=str, default="./config.yaml", help="配置文件路径")
    parser.add_argument("--prompt", type=str, default=None, help="检测提示词(覆盖配置文件)")
    parser.add_argument("--gpu", type=int, nargs="+", default=None, help="指定GPU列表(覆盖配置文件)")
    parser.add_argument("--batch-size", type=int, default=None, help="批处理大小(覆盖配置文件)")
    parser.add_argument("--scan-workers", type=int, default=None, help="图片扫描进程数(覆盖配置文件)")
    parser.add_argument("--no-progress", action="store_true", help="关闭进度条")
    return parser.parse_args()


# ====================== 多进程并行扫描图片 ======================
def scan_single_image(args: Tuple) -> Tuple[str, str]:
    """
    单张图片校验(多进程子任务)
    Returns:
        (状态, 图片路径)
        状态:valid(合法)/size_mismatch(尺寸不符)/too_small(尺寸过小)/load_error(加载失败)
    """
    img_path, allowed_sizes, min_size = args
    try:
        with Image.open(img_path) as img:
            # 校验图片尺寸是否在允许列表
            if img.size not in allowed_sizes:
                return ("size_mismatch", img_path)
            # 校验图片是否小于最小尺寸
            if img.size[0] < min_size[0] or img.size[1] < min_size[1]:
                return ("too_small", img_path)
            # 合法图片
            return ("valid", img_path)
    except Exception:
        # 图片损坏/无法加载
        return ("load_error", img_path)


def scan_images_parallel(input_dir: str, allowed_sizes: List[Tuple[int, int]],
                         min_size: Tuple[int, int], num_workers: int = 4) -> Tuple[List[str], dict]:
    """
    多进程并行扫描目录,过滤无效图片
    Returns:
        (合法图片路径列表, 扫描统计字典)
    """
    # 收集所有图片文件
    all_images = []
    for filename in os.listdir(input_dir):
        if filename.lower().endswith((".jpg", ".jpeg", ".png")):
            all_images.append(os.path.join(input_dir, filename))

    if not all_images:
        return [], {"size_mismatch": 0, "too_small": 0, "load_error": 0, "valid": 0}

    # 组装子任务参数
    scan_args = [(img, allowed_sizes, min_size) for img in all_images]

    valid_images = []
    stats = {"size_mismatch": 0, "too_small": 0, "load_error": 0}

    # 多进程并行执行
    with ProcessPoolExecutor(max_workers=num_workers) as executor:
        results = executor.map(scan_single_image, scan_args, chunksize=100)

        for status, img_path in results:
            if status == "valid":
                valid_images.append(img_path)
            else:
                stats[status] += 1

    return valid_images, stats


# ====================== 图片预处理工具 ======================
def resize_to_max_side(img: Image.Image, max_side: int, min_size: Tuple[int, int]) -> Image.Image:
    """
    图片自适应缩放:长边不超过max_side,且不小于最小尺寸
    用于适配模型输入,降低显存占用
    """
    width, height = img.size

    # 小于最小尺寸,不缩放
    if width < min_size[0] or height < min_size[1]:
        return img

    # 长边小于限制,不缩放
    if max(width, height) <= max_side:
        return img

    # 计算缩放比例
    scale = max_side / max(width, height)
    new_width = int(width * scale)
    new_height = int(height * scale)

    # 高质量缩放,兼容新旧PIL版本
    try:
        return img.resize((new_width, new_height), Image.Resampling.LANCZOS)
    except Exception:
        return img.resize((new_width, new_height), Image.BILINEAR)


# ====================== 独立IO保存进程(解耦推理与文件IO) ======================
def io_worker(save_queue: Queue, output_root: str, max_count: int, counter_dict: dict, lock):
    """
    独立的文件保存进程(当前主流程未启用)
    作用:将图片保存从推理进程剥离,避免IO阻塞推理
    """
    while True:
        try:
            task = save_queue.get(timeout=1)
            if task is None:  # 终止信号
                break

            source_path, filename = task
            current_count = counter_dict["count"]
            # 自动分文件夹
            batch_number = (current_count // max_count) + 1
            save_folder = f"{output_root}_batch_{batch_number}"
            os.makedirs(save_folder, exist_ok=True)
            target_path = os.path.join(save_folder, filename)

            # 复制图片
            shutil.copy2(source_path, target_path)

            # 线程安全计数
            with lock:
                counter_dict["count"] += 1

        except Empty:
            continue
        except Exception as e:
            logging.error(f"IO保存失败:{e}")


# ====================== 核心推理函数(单GPU执行) ======================
def process_batch(args: Tuple) -> dict:
    """
    单GPU批次推理主函数
    Args:
        args: 包含GPU编号、图片列表、提示词、模型路径等所有参数
    Returns:
        单GPU推理统计字典
    """
    # 解包参数
    gpu_id, image_paths, prompt, model_path, batch_size, max_side, min_size, show_progress, progress_bar_position, output_root, max_images_per_folder = args

    # 指定当前进程使用的GPU
    torch.cuda.set_device(gpu_id)
    device = f"cuda:{gpu_id}"

    # GPU独立日志
    logger = logging.getLogger(f"GPU{gpu_id}")
    logger.info(f"初始化模型 | 设备:{device} | 待处理图片:{len(image_paths)} 张")

    # 加载模型
    worker = LocateAnythingWorker(model_path, device=device)
    torch.cuda.empty_cache()  # 清空显存缓存
    logger.info("模型加载完成,开始推理")

    # 初始化统计指标
    stats = {
        "total": len(image_paths),
        "detected": 0,      # 检测到目标数量
        "skipped": 0,       # 跳过数量
        "failed": 0,        # 失败数量
        "saved": 0,         # 保存数量
        "batch_times": [],  # 批次耗时
        "avg_batch_time": 0.0,
    }

    saved_count = 0
    os.makedirs(output_root, exist_ok=True)

    # 初始化GPU独立进度条(输出到stderr,避免和模型日志冲突)
    pbar = None
    if show_progress:
        pbar = tqdm(
            total=len(image_paths), 
            desc=f"GPU{gpu_id}", 
            position=gpu_id if progress_bar_position else 0, 
            leave=True,
            file=sys.stderr,
            bar_format="{l_bar}{bar}| {n_fmt}/{total_fmt} [{elapsed}<{remaining}, {rate_fmt}]"
        )

    # 批次循环处理图片
    for i in range(0, len(image_paths), batch_size):
        batch_start_time = time.time()
        batch_paths = image_paths[i:i + batch_size]
        batch_images = []

        # 加载并预处理当前批次图片
        for img_path in batch_paths:
            try:
                img = Image.open(img_path).convert("RGB")
                img = resize_to_max_side(img, max_side, min_size)
                batch_images.append(img)
            except Exception as e:
                logger.warning(f"图片加载失败:{img_path} | 错误:{e}")
                stats["skipped"] += 1
                if pbar:
                    pbar.update(1)

        if not batch_images:
            continue

        # 执行批量推理(重定向stdout,静默模型原生日志)
        try:
            old_stdout = sys.stdout
            sys.stdout = io.StringIO()
            try:
                results = worker.ground_batch(batch_images, prompt, batch_size=len(batch_images))
            finally:
                sys.stdout = old_stdout
        except AttributeError:
            # 兼容无batch接口的模型,逐张推理
            results = []
            for img in batch_images:
                try:
                    result = worker.ground_multi(img, prompt)
                    results.append(result)
                except Exception as e:
                    logger.warning(f"推理失败:{e}")
                    results.append({"answer": ""})

        # 记录批次耗时
        batch_time = time.time() - batch_start_time
        stats["batch_times"].append(batch_time)

        # 处理推理结果
        for img_path, result in zip(batch_paths, results):
            filename = os.path.basename(img_path)
            answer = result.get("answer", "") if isinstance(result, dict) else str(result)

            # 判断是否检测到目标:包含<box>且不是None
            is_detected = "<box>" in answer and "<box>None</box>" not in answer

            if is_detected:
                stats["detected"] += 1
                # 自动分文件夹保存图片
                batch_num = saved_count // max_images_per_folder + 1
                save_folder = f"{output_root}_batch_{batch_num}"
                os.makedirs(save_folder, exist_ok=True)
                target_path = os.path.join(save_folder, filename)
                shutil.copy2(img_path, target_path)
                saved_count += 1
                stats["saved"] += 1
                # 打印检测结果
                print(f"[GPU{gpu_id}] {filename} | 保存到:{os.path.basename(save_folder)} | 检测结果:{answer}")
                logger.info(f"检测到目标 | {filename} | 保存到:{os.path.basename(save_folder)}")
            else:
                # 无目标/推理失败,跳过
                stats["skipped"] += 1

        # 更新进度条
        if pbar:
            pbar.update(len(batch_paths))
            pbar.set_postfix(
                detected=f"{stats['detected']}",
                skipped=f"{stats['skipped']}",
                batch_time=f"{batch_time:.2f}s"
            )

    if pbar:
        pbar.close()

    # 计算平均批次耗时
    if stats["batch_times"]:
        stats["avg_batch_time"] = sum(stats["batch_times"]) / len(stats["batch_times"])

    logger.info(f"处理完成 | 总计:{stats['total']} | 检测:{stats['detected']} | 跳过:{stats['skipped']} | 失败:{stats['failed']} | 平均Batch:{stats['avg_batch_time']:.2f}s")
    return stats


# ====================== 主流程调度 ======================
def main():
    # 1. 解析命令行参数
    args = parse_args()

    # 2. 加载YAML配置
    config = load_config(args.config)

    # 3. 命令行参数覆盖配置文件(优先级更高)
    if args.prompt:
        config.detection_prompt = args.prompt
    if args.gpu:
        config.gpu_list = args.gpu
    if args.batch_size:
        config.batch_size = args.batch_size
    if args.scan_workers:
        config.scan_workers = args.scan_workers
    if args.no_progress:
        config.show_progress = False

    # 4. 初始化全局日志
    setup_logging(config.log_file, config.log_level)
    logger = logging.getLogger("main")

    logger.info("=" * 60)
    logger.info("LocateAnything 推理 Pipeline")
    logger.info("=" * 60)

    # 5. 加载检测提示词(文件/配置二选一)
    if config.load_prompt_from_file and os.path.exists(config.prompt_file):
        with open(config.prompt_file, "r", encoding="utf-8") as f:
            config.detection_prompt = f.read().strip()
        logger.info(f"从文件加载提示词:{config.detection_prompt}")
    else:
        logger.info(f"使用默认提示词:{config.detection_prompt}")

    # 6. 扫描有效图片(优先读取缓存,提速)
    cache_data = load_scan_cache(
        config.scan_cache_file,
        config.input_image_dir,
        config.allowed_image_sizes
    )

    if cache_data:
        # 缓存有效,直接读取
        logger.info("从缓存加载扫描结果...")
        valid_images = cache_data["valid_images"]
        stats = cache_data["stats"]
        skipped_size_mismatch = stats.get("size_mismatch", 0)
        skipped_small = stats.get("too_small", 0)
        skipped_error = stats.get("load_error", 0)
        logger.info(f"缓存加载完成 | 有效图片:{len(valid_images)} 张")
    else:
        # 缓存无效,多进程重新扫描
        num_scan_workers = args.scan_workers if args.scan_workers else config.scan_workers
        logger.info(f"多进程扫描有效图片中... (workers={num_scan_workers})")
        valid_images, scan_stats = scan_images_parallel(
            config.input_image_dir,
            config.allowed_image_sizes,
            config.min_image_size,
            num_workers=num_scan_workers
        )
        skipped_size_mismatch = scan_stats.get("size_mismatch", 0)
        skipped_small = scan_stats.get("too_small", 0)
        skipped_error = scan_stats.get("load_error", 0)
        logger.info(f"扫描完成 | 有效图片:{len(valid_images)} 张")

        # 保存新的扫描缓存
        save_scan_cache(
            config.scan_cache_file,
            valid_images,
            config.input_image_dir,
            config.allowed_image_sizes,
            scan_stats
        )
        logger.info(f"扫描结果已缓存到:{config.scan_cache_file}")

    # 打印图片跳过统计
    logger.info(f"跳过统计 | 尺寸不匹配:{skipped_size_mismatch} | 过小:{skipped_small} | 加载失败:{skipped_error}")

    # 无有效图片,直接退出
    if not valid_images:
        logger.warning("无有效图片,退出程序")
        return

    # 7. 负载均衡:将图片均匀分配给所有GPU
    gpu_num = len(config.gpu_list)
    base_num = len(valid_images) // gpu_num
    remainder = len(valid_images) % gpu_num  # 余数图片分配给前N个GPU

    image_splits = []
    start_idx = 0
    for i, gpu_id in enumerate(config.gpu_list):
        end_idx = start_idx + base_num + (1 if i < remainder else 0)
        image_splits.append((gpu_id, valid_images[start_idx:end_idx]))
        start_idx = end_idx

    # 8. 组装多进程推理参数
    process_args = [
        (gpu_id, images, config.detection_prompt, config.model_path,
         config.batch_size, config.max_side, config.min_image_size,
         config.show_progress, config.progress_bar_position,
         config.output_dir, config.max_images_per_folder)
        for gpu_id, images in image_splits
    ]

    logger.info(f"{gpu_num} 个 GPU 并发处理 | Batch Size:{config.batch_size}")
    total_start_time = time.time()

    # 9. 多进程启动多GPU并发推理
    all_stats = []
    with ProcessPoolExecutor(max_workers=gpu_num) as executor:
        futures = [executor.submit(process_batch, p_args) for p_args in process_args]
        for future in futures:
            try:
                stats = future.result()
                all_stats.append(stats)
            except Exception as e:
                logger.error(f"进程执行失败:{e}")

    # 10. 汇总所有GPU统计数据
    total_stats = {
        "total": sum(s["total"] for s in all_stats),
        "detected": sum(s["detected"] for s in all_stats),
        "skipped": sum(s["skipped"] for s in all_stats),
        "failed": sum(s["failed"] for s in all_stats),
        "saved": sum(s.get("saved", 0) for s in all_stats),
    }

    total_time = round(time.time() - total_start_time, 2)
    
    # 11. 打印最终汇总报告
    logger.info("=" * 70)
    logger.info("                       推理统计汇总")
    logger.info("=" * 70)
    logger.info(f"总耗时          : {total_time} 秒")
    logger.info(f"处理图片总数    : {total_stats['total']} 张")
    logger.info(f"检测到目标      : {total_stats['detected']} 张")
    logger.info(f"已保存图片      : {total_stats['saved']} 张")
    logger.info(f"跳过图片        : {total_stats['skipped']} 张")
    logger.info(f"失败图片        : {total_stats['failed']} 张")
    logger.info(f"平均速度        : {round(total_stats['total'] / total_time, 2)} 张/秒")
    logger.info(f"检测率          : {round(total_stats['detected'] / total_stats['total'] * 100, 2)}%")
    
    logger.info("-" * 70)
    logger.info("各 GPU 详细统计:")
    for gpu_idx, stats in enumerate(all_stats):
        gpu_id = config.gpu_list[gpu_idx]
        avg_batch = stats.get("avg_batch_time", 0)
        saved = stats.get("saved", 0)
        logger.info(f"  GPU{gpu_id}: 处理 {stats['total']} 张 | 检测 {stats['detected']} 张 | 保存 {saved} 张 | 平均 {avg_batch:.2f}s/batch")
    
    logger.info("=" * 70)


if __name__ == "__main__":
    # Windows/Linux兼容:多进程启动方式设置为spawn
    torch.multiprocessing.set_start_method("spawn", force=True)
    main()

config.yaml

yaml 复制代码
model_path: "/media/user//models/nv-community/LocateAnything-3B/"
input_image_dir: "/media/user/images"
output_dir: "/media/user/data/detected_person_images"

# 图片尺寸过滤
allowed_image_sizes:
  - [1920, 1080]
  - [1280, 720]
  - [2560, 1440]

# GPU配置
gpu_list: [0, 1]

# 存储配置
max_images_per_folder: 500

# 检测配置
detection_prompt: "people wearing red clothes"
prompt_file: "./detect_prompt.txt"
load_prompt_from_file: false

# Batch处理配置
batch_size: 2

# 缩放配置
min_image_size: [1280, 720]  # 最小尺寸限制
max_side: 1920  # 长边最大像素

# 日志配置
log_file: "./inference.log"
log_level: "INFO"

# 扫描缓存配置
scan_cache_file: "./scan_cache.json"  # 扫描结果缓存文件
scan_workers: 4  # 扫描图片时的并行进程数

# 进度条配置
show_progress: true  # 是否显示进度条
progress_bar_position: true  # 各 GPU 进度条独立位置显示

locateanything_worker.py

python 复制代码
# 正则表达式库,用于解析模型输出的检测框/点位字符串
import re
import torch
from PIL import Image
# Transformers库:加载预训练模型、分词器、图像处理器
from transformers import AutoModel, AutoTokenizer, AutoProcessor


class LocateAnythingWorker:
    """
    LocateAnything 模型推理工作类
    核心作用:一次性加载模型,持续处理视觉-语言定位请求(避免重复加载模型)
    支持:目标检测、短语定位、文本检测、GUI定位、点位标注等多种任务
    """

    def __init__(self, model_path: str, device: str = "cuda", dtype=torch.bfloat16):
        """
        初始化模型、分词器、图像处理器
        Args:
            model_path: 模型路径(本地路径或HuggingFace模型名)
            device: 运行设备 cuda(显卡)/cpu
            dtype: 模型精度 bfloat16(显存高效)/float32(高精度)
        """
        # 运行设备
        self.device = device
        # 模型计算精度
        self.dtype = dtype

        # 加载文本分词器:将文本转为模型可识别的token
        self.tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
        # 加载多模态处理器:处理图片+文本的输入格式
        self.processor = AutoProcessor.from_pretrained(model_path, trust_remote_code=True)
        # 加载LocateAnything核心模型,指定精度+设备,设置为评估模式(关闭训练特性)
        self.model = AutoModel.from_pretrained(
            model_path,
            torch_dtype=dtype,
            trust_remote_code=True,
        ).to(device).eval()

    @torch.no_grad()
    def predict(
        self,
        image: Image.Image,
        question: str,
        generation_mode: str = "hybrid",
        max_new_tokens: int = 2048,
        temperature: float = 0.7,
        verbose: bool = True,
    ) -> dict:
        """
        【推理方法】执行单张图片的视觉-语言定位推理
        Args:
            image: PIL格式图片(必须RGB通道)
            question: 任务提示词(检测/定位/文本等指令)
            generation_mode: 生成模式 fast(快速)/slow(精准)/hybrid(混合)
            max_new_tokens: 模型最大生成文本长度
            temperature: 生成温度 0=贪心解码(精准),值越高越随机
            verbose: 是否返回耗时统计信息
        Returns:
            结果字典: answer(模型输出文本), stats(统计信息,可选), history(生成历史,可选)
        """
        # 构造多模态输入消息格式:图片 + 文本提示
        messages = [
            {"role": "user", "content": [
                {"type": "image", "image": image},   # 输入图片
                {"type": "text", "text": question},  # 输入文本指令
            ]}
        ]

        # 应用聊天模板,将消息格式化为模型要求的输入文本
        text = self.processor.py_apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )
        # 处理图片/视频视觉信息
        images, videos = self.processor.process_vision_info(messages)
        # 预处理:将文本+图片转为模型可接收的张量格式
        inputs = self.processor(
            text=[text], images=images, videos=videos, return_tensors="pt"
        ).to(self.device)

        # 提取输入张量:像素值、文本token、图片网格信息
        pixel_values = inputs["pixel_values"].to(self.dtype)
        input_ids = inputs["input_ids"]
        image_grid_hws = inputs.get("image_grid_hws", None)

        # 【模型核心推理】生成定位结果
        response = self.model.generate(
            pixel_values=pixel_values,       # 图片像素张量
            input_ids=input_ids,             # 文本token
            attention_mask=inputs["attention_mask"],  # 注意力掩码
            image_grid_hws=image_grid_hws,   # 图片网格信息
            tokenizer=self.tokenizer,        # 分词器
            max_new_tokens=max_new_tokens,   # 最大生成长度
            use_cache=True,                  # 启用缓存加速
            generation_mode=generation_mode, # 生成模式
            temperature=temperature,         # 生成温度
            do_sample=True,                  # 启用采样生成
            top_p=0.9,                       # 核采样参数
            repetition_penalty=1.1,          # 重复惩罚系数
            verbose=verbose,                 # 输出统计信息
        )

        # 封装返回结果
        result = {"answer": response[0] if isinstance(response, tuple) else response}
        # 如果返回元组,额外添加生成历史和统计信息
        if isinstance(response, tuple) and len(response) >= 3:
            result["history"] = response[1]
            result["stats"] = response[2]
        return result

    # ===================== 便捷任务封装方法 =====================
    # 针对不同视觉-语言定位任务,封装固定提示词,简化调用

    def detect(self, image: Image.Image, categories: list[str], **kwargs) -> dict:
        """
        目标检测 / 文档布局分析
        Args:
            image: 输入图片
            categories: 检测类别列表 例: ["人", "车", "自行车"]
        """
        # 拼接类别字符串,构造检测提示词
        cats = "</c>".join(categories)
        prompt = f"Locate all the instances that matches the following description: {cats}."
        return self.predict(image, prompt, **kwargs)

    def ground_single(self, image: Image.Image, phrase: str, **kwargs) -> dict:
        """
        短语定位(仅定位单个目标)
        Args:
            image: 输入图片
            phrase: 描述文本 例: "穿红色衣服的人"
        """
        prompt = f"Locate a single instance that matches the following description: {phrase}."
        return self.predict(image, prompt, **kwargs)

    def ground_multi(self, image: Image.Image, phrase: str, **kwargs) -> dict:
        """
        短语定位(定位所有目标)→ 主流程使用的核心方法
        Args:
            image: 输入图片
            phrase: 描述文本 例: "所有红色的汽车"
        """
        prompt = f"Locate all the instances that match the following description: {phrase}."
        return self.predict(image, prompt, **kwargs)

    def ground_text(self, image: Image.Image, phrase: str, **kwargs) -> dict:
        """文本定位:定位图片中指定的文字内容"""
        prompt = f"Please locate the text referred as {phrase}."
        return self.predict(image, prompt, **kwargs)

    def detect_text(self, image: Image.Image, **kwargs) -> dict:
        """场景文本检测:检测图片中的所有文字并返回框坐标"""
        prompt = "Detect all the text in box format."
        return self.predict(image, prompt, **kwargs)

    def ground_gui(self, image: Image.Image, phrase: str, output_type: str = "box", **kwargs) -> dict:
        """
        GUI界面定位:定位按钮/输入框等界面元素
        Args:
            output_type: box(边框) / point(点位)
        """
        if output_type == "point":
            prompt = f"Point to: {phrase}."
        else:
            prompt = f"Locate the region that matches the following description: {phrase}."
        return self.predict(image, prompt, **kwargs)

    def point(self, image: Image.Image, phrase: str, **kwargs) -> dict:
        """点位标注:在图片上指出目标的中心点坐标"""
        prompt = f"Point to: {phrase}."
        return self.predict(image, prompt, **kwargs)

    # ===================== 工具方法:解析模型输出 =====================

    @staticmethod
    def parse_boxes(answer: str, image_width: int, image_height: int) -> list[dict]:
        """
        静态方法:解析模型输出的文本,转为【像素级边界框坐标】
        模型输出坐标规则:归一化整数 [0, 1000]
        Args:
            answer: 模型输出的结果文本
            image_width: 原图宽度
            image_height: 原图高度
        Returns:
            边框列表: [{"x1":左上x, "y1":左上y, "x2":右下x, "y2":右下y}, ...]
        """
        boxes = []
        # 正则匹配模型输出的框格式:<box><x1><y1><x2><y2></box>
        for m in re.finditer(r"<box><(\d+)><(\d+)><(\d+)><(\d+)></box>", answer):
            x1, y1, x2, y2 = [int(g) for g in m.groups()]
            # 归一化坐标 → 原图像素坐标
            boxes.append({
                "x1": x1 / 1000 * image_width,
                "y1": y1 / 1000 * image_height,
                "x2": x2 / 1000 * image_width,
                "y2": y2 / 1000 * image_height,
            })
        return boxes

    @staticmethod
    def parse_points(answer: str, image_width: int, image_height: int) -> list[dict]:
        """
        静态方法:解析模型输出的文本,转为【像素级点位坐标】
        Args:
            answer: 模型输出的结果文本
            image_width: 原图宽度
            image_height: 原图高度
        Returns:
            点位列表: [{"x": 横坐标, "y": 纵坐标}, ...]
        """
        points = []
        # 正则匹配模型输出的点位格式:<box><x><y></box>
        for m in re.finditer(r"<box><(\d+)><(\d+)></box>", answer):
            x, y = int(m.group(1)), int(m.group(2))
            # 归一化坐标 → 原图像素坐标
            points.append({
                "x": x / 1000 * image_width,
                "y": y / 1000 * image_height,
            })
        return points


# --------------- 使用示例(测试代码)---------------
if __name__ == "__main__":
    # 初始化工作类,加载模型
    worker = LocateAnythingWorker("nvidia/LocateAnything-3B")
    # 加载测试图片
    img = Image.open("example.jpg").convert("RGB")

    # 1. 目标检测示例
    result = worker.detect(img, ["person", "car", "bicycle"])
    print("检测结果:", result["answer"])

    # 2. 多目标短语定位示例
    result = worker.ground_multi(img, "people wearing red shirts")
    print("短语定位结果:", result["answer"])

    # 3. 场景文本检测示例
    result = worker.detect_text(img)
    print("文本检测结果:", result["answer"])

    # 4. 点位标注示例
    result = worker.point(img, "the traffic light")
    print("点位标注结果:", result["answer"])

    # 5. GUI界面定位(点位)示例
    result = worker.ground_gui(img, "the search button", output_type="point")
    print("GUI定位结果:", result["answer"])

    # 6. 解析模型输出,转为结构化坐标
    w, h = img.size
    boxes = LocateAnythingWorker.parse_boxes(result["answer"], w, h)
    print("解析后的像素坐标框:", boxes)
相关推荐
深度学习lover2 小时前
<数据集>yolo航拍视角垃圾识别<目标检测>
人工智能·深度学习·yolo·目标检测·数据集·航拍视角垃圾识别
动物园猫2 小时前
无人机灾害场景人体目标检测数据集分享(适用于YOLO系列深度学习分类检测任务)
yolo·目标检测·无人机
JeJe同学15 小时前
LabelImg报错:IndexError: list index out of range 解决方法
深度学习·目标检测
code_pgf21 小时前
PointPillars 3D 目标检测详解
人工智能·目标检测·3d
人工智能算法研究院1 天前
【目标检测论文解读复现NO.43】基于改进YOLOv10n的植物叶片病害轻量化检测模型
yolo·目标检测·目标跟踪
深度学习lover1 天前
<数据集>yolo月球陨石坑识别<目标检测>
人工智能·yolo·目标检测·计算机视觉·数据集·月球陨石坑识别
奔袭的算法工程师1 天前
论文解读--BEV-radar:: bidirectional radar-camera fusion for 3D object detection
人工智能·算法·目标检测·计算机视觉·自动驾驶·信号处理
AI学长1 天前
数据集|二维码目标检测QRCodeDetection
人工智能·目标检测·计算机视觉·二维码目标检测
深度学习lover1 天前
<数据集>yolo樱桃识别<目标检测>
人工智能·深度学习·yolo·目标检测·计算机视觉·数据集·樱桃识别