All practical content, no filler
Hyperparameter settings
python
# Hyperparameter settings
PARAMS = {
    "clip_model": "openai/clip-vit-base-patch32",  # name of the CLIP checkpoint used for inference
    "video_folder": "./video_test",  # path to the folder of videos
    "text_description": "A photo of a person wearing pink clothes",  # text description to match against
    "frame_extraction_interval": 10,  # extract one frame every N seconds
    "save_frames_dir": "saved_frames",  # directory for saving matched frames
}
Configure the CLIP backbone
python
from transformers import CLIPProcessor, CLIPModel
import torch
# Load the CLIP model and processor
model = CLIPModel.from_pretrained(PARAMS["clip_model"])
processor = CLIPProcessor.from_pretrained(PARAMS["clip_model"])
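If a GPU is available, inference is noticeably faster. The lines below are an optional sketch that moves the model to CUDA and switches it to eval mode; the device variable is our addition, and the rest of the script would also need to move its input tensors with .to(device) for this to work end to end.
python
import torch

# Pick the GPU if one is available, otherwise fall back to CPU
device = "cuda" if torch.cuda.is_available() else "cpu"
model = model.to(device)
model.eval()  # deterministic inference, no dropout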
Text encoding
python
text_inputs = processor(text=[PARAMS["text_description"]], return_tensors="pt", padding=True)
with torch.no_grad():
    text_embedding = model.get_text_features(**text_inputs)
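get_text_features returns one pooled vector per prompt, i.e. shape (1, 512) for the ViT-B/32 checkpoint. Because only the direction of the vector matters for cosine similarity later on, it can optionally be L2-normalized up front; this is an equivalent reformulation, not something the original code requires.
python
print(text_embedding.shape)  # torch.Size([1, 512]) for clip-vit-base-patch32

# Optional: L2-normalize so that a plain dot product equals cosine similarity
text_embedding_norm = text_embedding / text_embedding.norm(dim=-1, keepdim=True)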
Extract and encode video frames
python
from moviepy.editor import VideoFileClip
import numpy as np

def extract_frames(video_path, interval):
    clip = VideoFileClip(video_path)
    frame_times = np.arange(0, int(clip.duration), interval)
    frames = [clip.get_frame(t) for t in frame_times]
    return frames, frame_times

def get_frame_embeddings(frames):
    frame_embeddings = []
    for frame in frames:
        frame_inputs = processor(images=frame, return_tensors="pt", padding=True)
        with torch.no_grad():
            # Use get_image_features to obtain the image embedding
            frame_outputs = model.get_image_features(**frame_inputs)
        # get_image_features already returns the projected features, so there is no need to access pooler_output
        frame_embeddings.append(frame_outputs)
    return torch.vstack(frame_embeddings)
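Encoding frames one at a time is simple but slow. As an alternative sketch (assuming all extracted frames fit in memory at once), the processor also accepts a list of images, so the loop can be collapsed into a single batched forward pass. The get_frame_embeddings_batched name is ours, not from the original post.
python
def get_frame_embeddings_batched(frames):
    # Encode every frame in one forward pass instead of one call per frame
    inputs = processor(images=list(frames), return_tensors="pt", padding=True)
    with torch.no_grad():
        return model.get_image_features(**inputs)  # shape: (num_frames, 512)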
Video frame matching
python
def find_best_matching_frames(video_path, text_embedding, interval, top_k=10):
    frames, frame_times = extract_frames(video_path, interval)
    frame_embeddings = get_frame_embeddings(frames)
    # Cosine similarity between the single text embedding (1, 512) and every frame embedding (N, 512)
    similarities = torch.nn.functional.cosine_similarity(text_embedding, frame_embeddings)
    # Never ask for more frames than were actually extracted
    top_k = min(top_k, similarities.shape[0])
    top_k_values, top_k_indices = similarities.topk(top_k)
    return [(frames[idx], frame_times[idx]) for idx in top_k_indices.cpu().numpy()]
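cosine_similarity broadcasts the (1, 512) text embedding against the (N, 512) frame embeddings along the feature dimension, giving one score per frame, and topk then picks the best-scoring frames. A toy example with 2-dimensional vectors (purely illustrative, not part of the pipeline) shows the behavior:
python
import torch
import torch.nn.functional as F

# Toy example: one "text" vector vs. three "frame" vectors
text_vec = torch.tensor([[1.0, 0.0]])
frame_vecs = torch.tensor([[1.0, 0.0], [0.0, 1.0], [0.7, 0.7]])
scores = F.cosine_similarity(text_vec, frame_vecs)  # tensor([1.0000, 0.0000, 0.7071])
print(scores.topk(2))  # values ≈ [1.0000, 0.7071], indices = [0, 2]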
Save the matched frames
python
from PIL import Image
import os

def save_frames(frames_info, video_path, save_dir):
    # Use the video file name (without extension) as the subfolder name
    video_name = os.path.splitext(os.path.basename(video_path))[0]
    save_path = os.path.join(save_dir, video_name)
    os.makedirs(save_path, exist_ok=True)
    for i, (frame, time) in enumerate(frames_info):
        frame_image = Image.fromarray(frame)
        frame_image.save(os.path.join(save_path, f"frame_at_{time:.2f}s_{i+1}.png"))
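The matcher and the saver can also be used on their own, outside the main loop. A quick single-video check might look like the following; the file path is hypothetical and only for illustration.
python
# Hypothetical standalone usage of the two helpers above
video = "./video_test/example.mp4"
info = find_best_matching_frames(video, text_embedding,
                                 PARAMS["frame_extraction_interval"], top_k=3)
save_frames(info, video, PARAMS["save_frames_dir"])
for frame, t in info:
    print(f"match at {t:.2f}s, frame shape {frame.shape}")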
Full pipeline
python
def main(PARAMS):
    video_paths = [os.path.join(PARAMS["video_folder"], f) for f in os.listdir(PARAMS["video_folder"]) if f.endswith('.mp4')]
    # Directory where the matched frames are saved
    save_dir = PARAMS.get("save_frames_dir", "./saved_frames")
    os.makedirs(save_dir, exist_ok=True)
    for video_path in video_paths:
        best_frames_info = find_best_matching_frames(video_path, text_embedding, PARAMS["frame_extraction_interval"], top_k=10)
        save_frames(best_frames_info, video_path, save_dir)
        print(f"Saved top 10 frames for video {video_path} in {save_dir}.")

if __name__ == "__main__":
    main(PARAMS)