AA-CLIP 复现笔记

创建一个predict.py文件

predict.py 的完整代码(Python):
import argparse
import csv
import os
import re
from glob import glob

import cv2
import numpy as np
import torch
from PIL import Image
from tqdm import tqdm
from torchvision import transforms

from dataset.constants import PROMPTS
from forward_utils import apply_ad_scoremap, calculate_similarity_map
from model.adapter import AdaptedCLIP
from model.clip import create_model
from model.tokenizer import tokenize


# File extensions (lower-case, with the leading dot) that collect_images
# accepts as images; candidate extensions are lower-cased before the lookup.
IMAGE_EXTS = {".jpg", ".jpeg", ".png", ".bmp", ".tif", ".tiff", ".webp"}


def collect_images(inputs: list[str]) -> list[str]:
    """Expand files, directories and glob patterns into image paths.

    Each entry of ``inputs`` is handled as follows: a directory is walked
    recursively, an existing file is taken as-is, and anything else is
    treated as a glob pattern (only matches that are files are kept). A path
    is accepted only when its lower-cased extension is in ``IMAGE_EXTS``.

    Returns:
        The accepted paths, de-duplicated and sorted lexicographically.
    """

    def _is_image(candidate: str) -> bool:
        # Extension comparison is case-insensitive.
        suffix = os.path.splitext(candidate)[1]
        return suffix.lower() in IMAGE_EXTS

    found: set[str] = set()
    for entry in inputs:
        if os.path.isdir(entry):
            for dirpath, _dirnames, filenames in os.walk(entry):
                found.update(
                    os.path.join(dirpath, fname)
                    for fname in filenames
                    if _is_image(fname)
                )
            continue
        if os.path.isfile(entry):
            if _is_image(entry):
                found.add(entry)
            continue
        # Neither a directory nor a file: interpret the entry as a glob.
        found.update(
            hit for hit in glob(entry) if os.path.isfile(hit) and _is_image(hit)
        )
    return sorted(found)


def safe_name(path: str) -> str:
    """Flatten a path into a single file-name-friendly token.

    Leading/trailing ``os.sep`` characters are dropped, the remaining
    separators become ``"__"``, and every run of whitespace becomes ``"_"``.
    """
    components = path.strip(os.sep).split(os.sep)
    flattened = "__".join(components)
    return re.sub(r"\s+", "_", flattened)


def select_image_ckpt(save_path: str, image_ckpt: str | None) -> str:
    if image_ckpt:
        return image_ckpt
    default_ckpt = os.path.join(save_path, "image_adapter.pth")
    if os.path.isfile(default_ckpt):
        return default_ckpt
    candidates = glob(os.path.join(save_path, "image_adapter_*.pth"))
    if not candidates:
        raise FileNotFoundError(
            f"No image adapter checkpoint found under: {save_path}"
        )

    def _epoch(path: str) -> int:
        match = re.search(r"image_adapter_(\d+)\.pth", os.path.basename(path))
        return int(match.group(1)) if match else -1

    candidates = sorted(candidates, key=_epoch)
    return candidates[-1]


def build_text_embeddings(model, class_name: str, device: torch.device) -> torch.Tensor:
    """Encode the normal and abnormal prompt ensembles for ``class_name``.

    For each prompt group (``PROMPTS["prompt_normal"]`` first, then
    ``PROMPTS["prompt_abnormal"]``) every state phrase is filled with
    ``class_name`` and expanded through every entry of
    ``PROMPTS["prompt_templates"]``. The sentences are tokenized, passed to
    ``model.encode_text``, L2-normalised along the last dimension, averaged
    over dimension 0, and the mean is normalised again.

    Returns:
        The two group embeddings stacked along ``dim=1``: index 0 on that
        axis is the normal group, index 1 the abnormal group.
    """
    state_groups = (PROMPTS["prompt_normal"], PROMPTS["prompt_abnormal"])
    templates = PROMPTS["prompt_templates"]

    group_embeddings = []
    for state_phrases in state_groups:
        # Cartesian expansion: phrases in the outer loop, templates inner.
        sentences = [
            template.format(phrase.format(class_name))
            for phrase in state_phrases
            for template in templates
        ]
        tokens = tokenize(sentences).to(device)
        embeddings = model.encode_text(tokens)
        embeddings = embeddings / embeddings.norm(dim=-1, keepdim=True)
        pooled = embeddings.mean(dim=0)
        group_embeddings.append(pooled / pooled.norm())
    return torch.stack(group_embeddings, dim=1).to(device)


def get_transform(img_size: int) -> transforms.Compose:
    """Build the preprocessing pipeline applied to each input image.

    The pipeline resizes to ``img_size`` x ``img_size`` with bicubic
    interpolation, converts to a tensor, and normalises the three channels
    with the fixed mean/std below.
    """
    # NOTE(review): presumably the CLIP preprocessing statistics - confirm
    # they match what the checkpoint was trained with.
    channel_mean = (0.48145466, 0.4578275, 0.40821073)
    channel_std = (0.26862954, 0.26130258, 0.27577711)
    steps = [
        transforms.Resize((img_size, img_size), Image.BICUBIC),
        transforms.ToTensor(),
        transforms.Normalize(mean=channel_mean, std=channel_std),
    ]
    return transforms.Compose(steps)


def main() -> None:
    """Command-line entry point: score images with AA-CLIP and write a CSV.

    Steps:
      1. Parse arguments and gather image paths via ``collect_images``.
      2. Build the CLIP backbone and the ``AdaptedCLIP`` wrapper, then load
         the adapter checkpoints found under ``--save_path``. The text
         adapter is optional; the image adapter is required.
      3. Build the normal/abnormal text embeddings for ``--class_name``.
      4. Run the images in batches and compute, per image, ``det_score``,
         ``pixel_score`` (the maximum of the summed patch score maps) and
         ``score``. ``score`` is the pixel maximum alone for the Medical
         domain, otherwise the average of the two.
      5. Write ``<output_dir>/predictions.csv``; with ``--save_heatmap`` also
         write per-image score maps and overlays.

    Raises:
        ValueError: if no image is found from ``--input``.
        FileNotFoundError: if no image-adapter checkpoint can be resolved.
        SystemExit: on invalid command-line arguments (raised by argparse).
    """
    parser = argparse.ArgumentParser(description="AA-CLIP single image prediction")
    # Inputs and outputs.
    parser.add_argument("--input", nargs="+", required=True)
    parser.add_argument("--save_path", type=str, required=True)
    parser.add_argument("--image_ckpt", type=str, default=None)
    parser.add_argument("--output_dir", type=str, default="predict_outputs")
    # Model / data configuration.
    parser.add_argument("--model_name", type=str, default="ViT-L-14-336")
    parser.add_argument("--img_size", type=int, default=518)
    parser.add_argument("--batch_size", type=int, default=8)
    parser.add_argument("--class_name", type=str, default="object")
    parser.add_argument("--domain", type=str, default="Industrial", choices=["Industrial", "Medical"])
    # Visualisation.
    parser.add_argument("--alpha", type=float, default=0.5)
    parser.add_argument("--save_heatmap", action="store_true")
    # Adapter hyper-parameters forwarded to AdaptedCLIP.
    parser.add_argument("--relu", action="store_true")
    parser.add_argument("--text_adapt_weight", type=float, default=0.1)
    parser.add_argument("--image_adapt_weight", type=float, default=0.1)
    parser.add_argument("--text_adapt_until", type=int, default=3)
    parser.add_argument("--image_adapt_until", type=int, default=6)
    args = parser.parse_args()

    # A zero step makes range() raise and a negative step yields no batches
    # at all, which would silently write an empty CSV.
    if args.batch_size < 1:
        parser.error("--batch_size must be a positive integer")

    image_paths = collect_images(args.input)
    if not image_paths:
        raise ValueError("No images found from --input")

    os.makedirs(args.output_dir, exist_ok=True)
    heatmap_dir = os.path.join(args.output_dir, "heatmaps")
    scoremap_dir = os.path.join(args.output_dir, "scoremaps")
    if args.save_heatmap:
        os.makedirs(heatmap_dir, exist_ok=True)
        os.makedirs(scoremap_dir, exist_ok=True)

    use_cuda = torch.cuda.is_available()
    device = torch.device("cuda:0" if use_cuda else "cpu")

    clip_model = create_model(
        model_name=args.model_name,
        img_size=args.img_size,
        device=device,
        pretrained="openai",
        require_pretrained=True,
    )
    clip_model.eval()

    model = AdaptedCLIP(
        clip_model=clip_model,
        text_adapt_weight=args.text_adapt_weight,
        image_adapt_weight=args.image_adapt_weight,
        text_adapt_until=args.text_adapt_until,
        image_adapt_until=args.image_adapt_until,
        relu=args.relu,
    ).to(device)
    model.eval()

    # NOTE(security): torch.load unpickles the file, which can execute
    # arbitrary code. Only point --save_path / --image_ckpt at checkpoints
    # from a trusted source.
    text_ckpt = os.path.join(args.save_path, "text_adapter.pth")
    use_text_adapter = False
    if os.path.isfile(text_ckpt):
        checkpoint = torch.load(text_ckpt, map_location=device)
        model.text_adapter.load_state_dict(checkpoint["text_adapter"])
        use_text_adapter = True

    image_ckpt = select_image_ckpt(args.save_path, args.image_ckpt)
    image_checkpoint = torch.load(image_ckpt, map_location=device)
    model.image_adapter.load_state_dict(image_checkpoint["image_adapter"])

    # Without a trained text adapter, encode the prompts with the plain CLIP
    # model instead of the adapted wrapper.
    text_model = model if use_text_adapter else clip_model
    # Inference only: do not record autograd state for the text encoder.
    with torch.no_grad():
        text_features = build_text_embeddings(text_model, args.class_name, device)

    transform = get_transform(args.img_size)

    rows = []
    with torch.no_grad():
        for start in tqdm(range(0, len(image_paths), args.batch_size)):
            batch_paths = image_paths[start : start + args.batch_size]
            batch_images = []
            for path in batch_paths:
                # The context manager closes the underlying file promptly
                # instead of leaving it open until garbage collection.
                with Image.open(path) as img:
                    batch_images.append(transform(img.convert("RGB")))
            images = torch.stack(batch_images, dim=0).to(device)

            patch_features, det_feature = model(images)
            det_logits = det_feature @ text_features
            # Column 1 is the abnormal prompt group; map it from [-1, 1]
            # (cosine-similarity range) to [0, 1].
            det_scores = (det_logits[:, 1] + 1) / 2

            patch_preds = []
            for f in patch_features:
                patch_pred = calculate_similarity_map(
                    f,
                    text_features,
                    args.img_size,
                    test=True,
                    domain=args.domain,
                )
                patch_preds.append(patch_pred)
            # Sum the score maps contributed by the individual feature levels.
            patch_preds = torch.cat(patch_preds, dim=1).sum(1)

            pixel_max = patch_preds.view(patch_preds.shape[0], -1).max(1)[0]
            if args.domain == "Medical":
                final_scores = pixel_max
            else:
                final_scores = 0.5 * pixel_max + 0.5 * det_scores

            # One device-to-host transfer per tensor instead of one per image.
            batch_scores = final_scores.tolist()
            batch_det_scores = det_scores.tolist()
            batch_pixel_scores = pixel_max.tolist()

            for i, path in enumerate(batch_paths):
                rows.append(
                    {
                        "image_path": path,
                        "score": batch_scores[i],
                        "det_score": batch_det_scores[i],
                        "pixel_score": batch_pixel_scores[i],
                    }
                )

                if args.save_heatmap:
                    # Per-image min-max normalisation to [0, 255]; the epsilon
                    # guards against division by zero on a constant map.
                    scoremap = patch_preds[i].cpu().numpy()
                    scoremap = (scoremap - scoremap.min()) / (
                        scoremap.max() - scoremap.min() + 1e-6
                    )
                    scoremap_uint8 = (scoremap * 255).astype(np.uint8)

                    name = safe_name(path) + ".png"
                    # The score map does not depend on re-reading the image,
                    # so save it before the overlay step can bail out.
                    cv2.imwrite(os.path.join(scoremap_dir, name), scoremap_uint8)

                    image = cv2.imread(path)
                    if image is None:
                        # tqdm.write keeps the progress bar intact.
                        tqdm.write(
                            f"Warning: OpenCV could not read {path}; "
                            "skipping heatmap overlay"
                        )
                        continue
                    # cv2.resize takes (width, height); shape is (rows, cols).
                    image = cv2.resize(image, scoremap_uint8.shape[::-1])
                    # NOTE(review): apply_ad_scoremap lives in forward_utils
                    # (not visible here) - confirm it expects a uint8 map in
                    # [0, 255] and a BGR image as passed.
                    overlay = apply_ad_scoremap(image, scoremap_uint8, alpha=args.alpha)
                    cv2.imwrite(os.path.join(heatmap_dir, name), overlay)

    csv_path = os.path.join(args.output_dir, "predictions.csv")
    # Explicit UTF-8 so non-ASCII image paths cannot fail the write under a
    # narrower locale encoding after all inference has already been done.
    with open(csv_path, "w", newline="", encoding="utf-8") as f:
        writer = csv.DictWriter(
            f, fieldnames=["image_path", "score", "det_score", "pixel_score"]
        )
        writer.writeheader()
        writer.writerows(rows)

    print(f"Done. Wrote {len(rows)} predictions to {csv_path}")


# Run the command-line interface only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()

使用方式(以 InsPLAD 数据集为例):

对有监督划分(defect_supervised)的 val 目录做预测(bash):
python predict.py \
  --input /root/data1/InsPLAD/defect_supervised/glass-insulator/val \
  --save_path /root/data1/AA-CLIP/ckpt/visa_full \
  --class_name "glass insulator" \
  --domain Industrial \
  --save_heatmap
对无监督划分(defect_unsupervised)的 test 目录做预测(bash):
python predict.py \
  --input /root/data1/InsPLAD/defect_unsupervised/glass-insulator/test \
  --save_path /root/data1/AA-CLIP/ckpt/visa_full \
  --class_name "glass insulator" \
  --domain Industrial \
  --save_heatmap

结果默认写到 predict_outputs/predictions.csv(输出目录可用 --output_dir 修改),每张图一行,包含 image_path、score、det_score、pixel_score 四列。

加上 --save_heatmap 后,会在 predict_outputs/heatmaps 里保存叠加在原图上的热力图,在 predict_outputs/scoremaps 里保存归一化到 0–255 的灰度分数图。

可调参数

--class_name 就是提示词里的"物体名称",不需要在 constants 里注册;你可以换成 yoke suspension、vari grip 等更贴近实际的描述。

--image_ckpt 用来指定某个 image_adapter_*.pth;不传时优先使用 --save_path 下的 image_adapter.pth,若该文件不存在,则自动选择编号最大的 image_adapter_*.pth。

相关推荐
TMT星球15 小时前
他用WPS笔记,把AI报错变成了可复用的“避坑指南”
笔记·wps
lcj251115 小时前
【list】手撕C++ list!从0到1实现双向链表,迭代器、const迭代器、模板全解析,面试官都惊呆了!
c++·笔记·链表·list
niaiheni15 小时前
MySQL JDBC 不出网攻击 → Spring 临时文件利用:完整攻击链复现笔记
笔记·mysql·spring
kgduu16 小时前
cosmos学习笔记
笔记·学习
05候补工程师16 小时前
【408 数据结构】图论核心算法(拓扑/关键路径)与二叉搜索树精髓夺分笔记
数据结构·经验分享·笔记·考研·算法·图论
烛之武16 小时前
《深度学习基础与概念》笔记(2)
人工智能·笔记·深度学习
whyTeaFo16 小时前
MIT 6.1810: xv6 book Chapter6: Interrupts and device drivers 笔记
笔记
map1e_zjc17 小时前
Redis入门笔记(2)
数据库·redis·笔记
@zulnger17 小时前
WebDriver API及对象识别技术
笔记·python·selenium
05候补工程师17 小时前
【期末/408冲刺】软件工程核心考点与大题通关秘籍(附图解与解题套路)
大数据·hadoop·经验分享·笔记·软件工程