AA-CLIP 复现笔记

在 AA-CLIP 仓库根目录下创建 predict.py 文件(脚本会从 dataset、model 包以及 forward_utils 导入模块,因此需要在仓库根目录下运行):

完整代码(Python):
import argparse
import csv
import os
import re
from glob import glob

import cv2
import numpy as np
import torch
from PIL import Image
from tqdm import tqdm
from torchvision import transforms

from dataset.constants import PROMPTS
from forward_utils import apply_ad_scoremap, calculate_similarity_map
from model.adapter import AdaptedCLIP
from model.clip import create_model
from model.tokenizer import tokenize


# File extensions (lower-case, with leading dot) treated as image inputs.
# collect_images lower-cases each file's extension before checking membership,
# so ".JPG" and ".jpg" are both accepted.
IMAGE_EXTS: set[str] = {".jpg", ".jpeg", ".png", ".bmp", ".tif", ".tiff", ".webp"}


def collect_images(inputs: list[str]) -> list[str]:
    paths: list[str] = []
    for item in inputs:
        if os.path.isdir(item):
            for root, _, files in os.walk(item):
                for name in files:
                    ext = os.path.splitext(name)[1].lower()
                    if ext in IMAGE_EXTS:
                        paths.append(os.path.join(root, name))
        elif os.path.isfile(item):
            ext = os.path.splitext(item)[1].lower()
            if ext in IMAGE_EXTS:
                paths.append(item)
        else:
            for match in glob(item):
                if os.path.isfile(match):
                    ext = os.path.splitext(match)[1].lower()
                    if ext in IMAGE_EXTS:
                        paths.append(match)
    return sorted(list(dict.fromkeys(paths)))


def safe_name(path: str) -> str:
    """Flatten *path* into a single file-name component for output files.

    Directory separators become ``"__"`` and runs of whitespace become ``"_"``,
    so images from different sub-folders do not overwrite each other in a flat
    output directory. Leading/trailing separators are dropped first.

    Both ``os.sep`` and ``os.altsep`` count as separators. On Windows, input
    paths often use ``/``; leaving it in the name would make the writer target
    a non-existent sub-directory. On POSIX ``os.altsep`` is ``None``, so
    behaviour there is unchanged.
    """
    separators = os.sep + (os.altsep or "")
    clean = path.strip(separators)
    for sep in separators:
        clean = clean.replace(sep, "__")
    return re.sub(r"\s+", "_", clean)


def select_image_ckpt(save_path: str, image_ckpt: str | None) -> str:
    if image_ckpt:
        return image_ckpt
    default_ckpt = os.path.join(save_path, "image_adapter.pth")
    if os.path.isfile(default_ckpt):
        return default_ckpt
    candidates = glob(os.path.join(save_path, "image_adapter_*.pth"))
    if not candidates:
        raise FileNotFoundError(
            f"No image adapter checkpoint found under: {save_path}"
        )

    def _epoch(path: str) -> int:
        match = re.search(r"image_adapter_(\d+)\.pth", os.path.basename(path))
        return int(match.group(1)) if match else -1

    candidates = sorted(candidates, key=_epoch)
    return candidates[-1]


def build_text_embeddings(model, class_name: str, device: torch.device) -> torch.Tensor:
    """Encode the normal/abnormal prompt ensembles for *class_name*.

    For each state (``PROMPTS["prompt_normal"]`` first, then
    ``PROMPTS["prompt_abnormal"]``), the steps are:

    1. Format every phrase of the state with *class_name*.
    2. Wrap each formatted phrase in every ``PROMPTS["prompt_templates"]``
       entry.
    3. Tokenize the sentences and encode them with ``model.encode_text``.
    4. L2-normalise each sentence embedding, average them, and L2-normalise
       the mean.

    Returns:
        The two state embeddings stacked along ``dim=1``: column 0 is normal
        and column 1 is abnormal, so ``feature @ result`` gives one score per
        state. Assumes ``encode_text`` returns a (num_sentences, dim) tensor;
        TODO confirm.
    """
    states = (PROMPTS["prompt_normal"], PROMPTS["prompt_abnormal"])
    templates = PROMPTS["prompt_templates"]

    def _encode_state(state_phrases) -> torch.Tensor:
        # phrase-major, template-minor ordering of the prompt sentences
        sentences = [
            template.format(phrase.format(class_name))
            for phrase in state_phrases
            for template in templates
        ]
        tokens = tokenize(sentences).to(device)
        per_sentence = model.encode_text(tokens)
        per_sentence = per_sentence / per_sentence.norm(dim=-1, keepdim=True)
        centroid = per_sentence.mean(dim=0)
        return centroid / centroid.norm()

    stacked = torch.stack([_encode_state(s) for s in states], dim=1)
    return stacked.to(device)


def get_transform(img_size: int) -> transforms.Compose:
    """Return the preprocessing pipeline applied to every input image.

    The pipeline does three things:

    * Resize to a square ``img_size x img_size`` with bicubic interpolation.
      The aspect ratio is not preserved.
    * Convert to a tensor.
    * Normalise with the OpenAI CLIP per-channel mean/std.
    """
    clip_mean = (0.48145466, 0.4578275, 0.40821073)
    clip_std = (0.26862954, 0.26130258, 0.27577711)
    steps = [
        transforms.Resize((img_size, img_size), Image.BICUBIC),
        transforms.ToTensor(),
        transforms.Normalize(mean=clip_mean, std=clip_std),
    ]
    return transforms.Compose(steps)


def _parse_args() -> argparse.Namespace:
    """Define and parse the command-line options of the prediction script."""
    parser = argparse.ArgumentParser(description="AA-CLIP single image prediction")
    parser.add_argument("--input", nargs="+", required=True,
                        help="Image files, directories (searched recursively) and/or glob patterns.")
    parser.add_argument("--save_path", type=str, required=True,
                        help="Directory with image_adapter*.pth and optionally text_adapter.pth.")
    parser.add_argument("--image_ckpt", type=str, default=None,
                        help="Explicit image adapter checkpoint; by default one is chosen from --save_path.")
    parser.add_argument("--output_dir", type=str, default="predict_outputs",
                        help="Where predictions.csv and visualisations are written.")
    parser.add_argument("--model_name", type=str, default="ViT-L-14-336")
    parser.add_argument("--img_size", type=int, default=518)
    parser.add_argument("--batch_size", type=int, default=8)
    parser.add_argument("--class_name", type=str, default="object",
                        help="Object name substituted into the text prompts.")
    parser.add_argument("--domain", type=str, default="Industrial", choices=["Industrial", "Medical"],
                        help="Medical uses the pixel score only; Industrial averages pixel and image scores.")
    parser.add_argument("--alpha", type=float, default=0.5,
                        help="Blending weight passed to apply_ad_scoremap.")
    parser.add_argument("--save_heatmap", action="store_true",
                        help="Also save per-image overlays and score maps.")
    parser.add_argument("--relu", action="store_true")
    parser.add_argument("--text_adapt_weight", type=float, default=0.1)
    parser.add_argument("--image_adapt_weight", type=float, default=0.1)
    parser.add_argument("--text_adapt_until", type=int, default=3)
    parser.add_argument("--image_adapt_until", type=int, default=6)
    return parser.parse_args()


def _load_model(args: argparse.Namespace, device: torch.device):
    """Build AdaptedCLIP, load adapter weights from --save_path and encode the prompts.

    Returns:
        ``(model, text_features)``.
    """
    clip_model = create_model(
        model_name=args.model_name,
        img_size=args.img_size,
        device=device,
        pretrained="openai",
        require_pretrained=True,
    )
    clip_model.eval()

    model = AdaptedCLIP(
        clip_model=clip_model,
        text_adapt_weight=args.text_adapt_weight,
        image_adapt_weight=args.image_adapt_weight,
        text_adapt_until=args.text_adapt_until,
        image_adapt_until=args.image_adapt_until,
        relu=args.relu,
    ).to(device)
    model.eval()

    # The text adapter is optional: without its checkpoint the prompts are
    # encoded by the plain CLIP model instead of the adapted one.
    text_ckpt = os.path.join(args.save_path, "text_adapter.pth")
    use_text_adapter = os.path.isfile(text_ckpt)
    if use_text_adapter:
        checkpoint = torch.load(text_ckpt, map_location=device)
        model.text_adapter.load_state_dict(checkpoint["text_adapter"])
        print(f"Loaded text adapter: {text_ckpt}")
    else:
        print("No text_adapter.pth found; encoding prompts with the base CLIP model.")

    image_ckpt = select_image_ckpt(args.save_path, args.image_ckpt)
    image_checkpoint = torch.load(image_ckpt, map_location=device)
    model.image_adapter.load_state_dict(image_checkpoint["image_adapter"])
    print(f"Loaded image adapter: {image_ckpt}")

    text_model = model if use_text_adapter else clip_model
    # Inference only: without no_grad, an autograd graph through the encoder
    # (and adapters, if their parameters require grad) would stay referenced
    # by text_features for the whole run.
    with torch.no_grad():
        text_features = build_text_embeddings(text_model, args.class_name, device)
    return model, text_features


def _load_batch(paths: list[str], transform, keep_rgb: bool):
    """Decode *paths* as RGB; return the stacked model input and, optionally, raw pixels.

    Returns:
        ``(images, rgb_arrays)``. ``images`` is ``transform`` applied to each
        image and stacked along a new dim 0. ``rgb_arrays[i]`` is the decoded
        RGB uint8 array of image *i* when *keep_rgb* is true, and ``None``
        otherwise. Reusing these exact pixels for the heatmap avoids
        re-decoding with ``cv2.imread``. That call applies EXIF orientation
        (PIL does not), so the overlay could end up rotated relative to what
        the model saw, and it cannot open non-ASCII paths on Windows.
    """
    tensors = []
    rgb_arrays = []
    for path in paths:
        # The context manager closes the file. convert() returns a loaded
        # image that remains valid after the file is closed.
        with Image.open(path) as img:
            rgb = img.convert("RGB")
        tensors.append(transform(rgb))
        rgb_arrays.append(np.asarray(rgb) if keep_rgb else None)
    return torch.stack(tensors, dim=0), rgb_arrays


def _write_png(path: str, image: np.ndarray) -> bool:
    """Encode *image* as PNG and write it to *path*; return True on success.

    ``cv2.imencode`` plus ``ndarray.tofile`` is used instead of ``cv2.imwrite``
    because the latter fails on non-ASCII paths on Windows.
    """
    ok, buf = cv2.imencode(".png", image)
    if not ok:
        return False
    try:
        buf.tofile(path)
    except OSError:
        return False
    return True


def _save_visualization(
    scoremap: np.ndarray,
    rgb: np.ndarray,
    name: str,
    heatmap_dir: str,
    scoremap_dir: str,
    alpha: float,
) -> None:
    """Write the normalised score map and its overlay on *rgb* under *name*.

    The score map is min-max normalised per image. Its colours therefore show
    where one image is most anomalous, not how anomalous it is compared with
    other images.
    """
    scoremap = (scoremap - scoremap.min()) / (scoremap.max() - scoremap.min() + 1e-6)
    scoremap_uint8 = (scoremap * 255).astype(np.uint8)

    # apply_ad_scoremap receives a BGR image and the result is written with
    # OpenCV (BGR), so convert from PIL's RGB.
    bgr = cv2.cvtColor(rgb, cv2.COLOR_RGB2BGR)
    # cv2.resize takes (width, height); assumes a 2-D (H, W) score map.
    bgr = cv2.resize(bgr, scoremap_uint8.shape[::-1])
    # NOTE(review): confirm whether apply_ad_scoremap expects a uint8 [0, 255]
    # map (as passed here) or a float [0, 1] map.
    overlay = apply_ad_scoremap(bgr, scoremap_uint8, alpha=alpha)

    for out_dir, image in ((heatmap_dir, overlay), (scoremap_dir, scoremap_uint8)):
        out_path = os.path.join(out_dir, name)
        if not _write_png(out_path, image):
            print(f"Warning: failed to write {out_path}")


def _write_csv(rows: list[dict], csv_path: str) -> None:
    """Write one prediction row per image to *csv_path* as UTF-8 CSV."""
    with open(csv_path, "w", newline="", encoding="utf-8") as f:
        writer = csv.DictWriter(
            f, fieldnames=["image_path", "score", "det_score", "pixel_score"]
        )
        writer.writeheader()
        writer.writerows(rows)


def main() -> None:
    """Run AA-CLIP anomaly prediction on every image found in ``--input``.

    Writes ``<output_dir>/predictions.csv`` with one row per image
    (``image_path``, ``score``, ``det_score``, ``pixel_score``). With
    ``--save_heatmap`` it also writes an overlay and a normalised score map per
    image under ``heatmaps/`` and ``scoremaps/``.

    Raises:
        ValueError: if ``--input`` resolves to no image files.
        FileNotFoundError: if no image-adapter checkpoint can be found.
    """
    args = _parse_args()

    image_paths = collect_images(args.input)
    if not image_paths:
        raise ValueError("No images found from --input")

    os.makedirs(args.output_dir, exist_ok=True)
    heatmap_dir = os.path.join(args.output_dir, "heatmaps")
    scoremap_dir = os.path.join(args.output_dir, "scoremaps")
    if args.save_heatmap:
        os.makedirs(heatmap_dir, exist_ok=True)
        os.makedirs(scoremap_dir, exist_ok=True)

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    model, text_features = _load_model(args, device)
    transform = get_transform(args.img_size)

    rows = []
    with torch.no_grad():
        for start in tqdm(range(0, len(image_paths), args.batch_size)):
            batch_paths = image_paths[start : start + args.batch_size]
            images, rgb_arrays = _load_batch(batch_paths, transform, args.save_heatmap)
            images = images.to(device)

            patch_features, det_feature = model(images)
            det_logits = det_feature @ text_features
            # Column 1 is the abnormal prompt. (x + 1) / 2 maps a cosine
            # similarity in [-1, 1] to [0, 1] (assuming normalised features).
            det_scores = (det_logits[:, 1] + 1) / 2

            patch_preds = torch.cat(
                [
                    calculate_similarity_map(
                        f,
                        text_features,
                        args.img_size,
                        test=True,
                        domain=args.domain,
                    )
                    for f in patch_features
                ],
                dim=1,
            ).sum(1)

            pixel_max = patch_preds.view(patch_preds.shape[0], -1).max(1)[0]
            if args.domain == "Medical":
                final_scores = pixel_max
            else:
                final_scores = 0.5 * pixel_max + 0.5 * det_scores

            # One device->host transfer per tensor instead of one per element.
            score_list = final_scores.cpu().tolist()
            det_list = det_scores.cpu().tolist()
            pixel_list = pixel_max.cpu().tolist()

            for i, path in enumerate(batch_paths):
                rows.append(
                    {
                        "image_path": path,
                        "score": float(score_list[i]),
                        "det_score": float(det_list[i]),
                        "pixel_score": float(pixel_list[i]),
                    }
                )
                if args.save_heatmap:
                    _save_visualization(
                        patch_preds[i].cpu().numpy(),
                        rgb_arrays[i],
                        safe_name(path) + ".png",
                        heatmap_dir,
                        scoremap_dir,
                        args.alpha,
                    )

    csv_path = os.path.join(args.output_dir, "predictions.csv")
    _write_csv(rows, csv_path)
    print(f"Done. Wrote {len(rows)} predictions to {csv_path}")


# Run the command-line entry point only when executed as a script, not on import.
if __name__ == "__main__":
    main()

使用方式(以 InsPLAD 数据集为例):

示例一:对 defect_supervised 中 glass-insulator 的 val 集进行推理(Bash):
# Score every image under InsPLAD defect_supervised/glass-insulator/val using
# the adapter checkpoints in ckpt/visa_full, and also save heatmaps/score maps.
python predict.py \
  --input /root/data1/InsPLAD/defect_supervised/glass-insulator/val \
  --save_path /root/data1/AA-CLIP/ckpt/visa_full \
  --class_name "glass insulator" \
  --domain Industrial \
  --save_heatmap
示例二:对 defect_unsupervised 中 glass-insulator 的 test 集进行推理(Bash):
# Same as above, but for InsPLAD defect_unsupervised/glass-insulator/test.
python predict.py \
  --input /root/data1/InsPLAD/defect_unsupervised/glass-insulator/test \
  --save_path /root/data1/AA-CLIP/ckpt/visa_full \
  --class_name "glass insulator" \
  --domain Industrial \
  --save_heatmap

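结果会写入 predict_outputs/predictions.csv(输出目录可用 --output_dir 修改),每张图一行,包含 image_path、score、det_score、pixel_score。其中 det_score 是图像级分数(图像全局特征与异常文本特征的相似度,按 (x+1)/2 换算),pixel_score 是像素级异常图的最大值。Industrial 域下 score = 0.5 × pixel_score + 0.5 × det_score,Medical 域下 score = pixel_score。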

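--save_heatmap 会在 predict_outputs/heatmaps(叠加图)和 predict_outputs/scoremaps(灰度异常图)里生成可视化结果。文件名由原图路径展平而来(路径分隔符替换为 __,末尾追加 .png)。注意:scoremap 对每张图单独做了 min-max 归一化,只能看出单张图内哪里最异常,不能靠颜色深浅比较不同图片的异常程度;跨图比较请使用 CSV 中的分数。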

可调参数

--class_name 就是提示词里的"物体名称",不需要在 constants 里注册;你可以换成 yoke suspension、vari grip 等更贴近实际的描述。

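--image_ckpt 用于指定某个 image_adapter_*.pth。不传时,优先使用 --save_path 下的 image_adapter.pth;若不存在,则选择 image_adapter_<N>.pth 中编号 N 最大的那个。另外,若 --save_path 下存在 text_adapter.pth,会自动加载文本适配器,否则文本端直接使用原始 CLIP。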

相关推荐
LinXunFeng6 天前
Obsidian - 使用 Share Note 分享笔记并自部署
前端·笔记·github
闪闪发亮的小星星10 天前
高斯光以及高斯光公式解释
笔记
cqbzcsq10 天前
CellFlow虚拟细胞论文阅读
论文阅读·人工智能·笔记·学习·生物信息
阿米亚波10 天前
【Windows】QEMU 启动 openEuler aarch64/arm64 架构系统 + 离线软件源
linux·windows·经验分享·笔记·架构·arm
自传.10 天前
尚硅谷 Vibe Coding|第三章(1) Claude Code深度使用与进阶技巧 学习笔记
笔记·学习·尚硅谷·vibecoding
.千余10 天前
【C++】模板进阶全解:非类型参数|全特化|偏特化|分离编译完全指南
开发语言·c++·笔记·学习·其他
自传.10 天前
尚硅谷 Vibe Coding|第二章 AI编程工具生态 学习笔记
笔记·学习·ai编程·尚硅谷·vibe coding
秋波。未央10 天前
Java Agent 开发 · Day 1 学习笔记(含作业完整标准答案)
java·笔记·学习
中屹指纹浏览器10 天前
2026指纹浏览器字体指纹、字体渲染偏差检测与全维度虚拟字体池搭建方案
经验分享·笔记