深度学习之使用Milvus向量数据库实战图搜图

python 复制代码
import torch
from torchvision import models,transforms
from torch.utils.data import Dataset , DataLoader
import os
import pickle
from PIL import Image
from tqdm import tqdm
from pymilvus import (
FieldSchema,
DataType,
db,
connections,
CollectionSchema,
Collection
)
import time
import matplotlib.pyplot as plt
python 复制代码
# Run on CUDA when PyTorch can see a GPU, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
python 复制代码
# Preprocessing applied to every image before feature extraction:
#   1. resize to a fixed 256x256 (aspect ratio is not preserved),
#   2. keep the central 224x224 crop,
#   3. convert the PIL image to a float tensor in [0, 1],
#   4. normalise each channel with the ImageNet mean/std that torchvision's
#      pretrained models expect.
transform = transforms.Compose(
    [
        transforms.Resize(size=(256, 256)),
        transforms.CenterCrop(size=224),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)
python 复制代码
# Root of the training split; each sub-directory holds one group of images
# (presumably one per flower class).
image_dir = "./flower_data/train"
# All sub-directories of image_dir, joined with "/" so the stored image paths
# keep the same form as before. The class count is not hard-coded: a previous
# zip() against 102 copies of image_dir silently dropped every directory past
# the 102nd. Plain files (e.g. .DS_Store) are skipped so the per-directory
# listing that follows does not fail on them.
image_dirs = [
    f"{image_dir}/{entry}"
    for entry in os.listdir(image_dir)
    if os.path.isdir(os.path.join(image_dir, entry))
]
python 复制代码
# Flat list of every image file inside the class directories, in directory
# listing order. The loop variable is no longer called `dir`, which shadowed
# the builtin dir() for the rest of the session. Non-file entries (e.g. nested
# folders) are skipped because ImageDataset opens every path as an image.
image_paths = [
    os.path.join(class_dir, file_name)
    for class_dir in image_dirs
    for file_name in os.listdir(class_dir)
    if os.path.isfile(os.path.join(class_dir, file_name))
]
python 复制代码
# Notebook display: echo the collected image file paths for inspection.
image_paths
python 复制代码
# Notebook display: echo the class directories that were scanned.
image_dirs
python 复制代码
# Persist the path list to ./image_paths.pkl so the dataset can reload it
# without walking the directory tree again.
with open("image_paths.pkl", "wb") as path_file:
    path_file.write(pickle.dumps(image_paths))
python 复制代码
class ImageDataset(Dataset):
    """Map-style dataset over the RGB images listed in a pickled path list.

    Each item is a dict with keys ``"idx"`` (position in this dataset),
    ``"image_path"`` (the path string) and ``"img"``. ``"img"`` is the
    transformed image, or the PIL image when no transform is given.

    Args:
        transform: Optional callable applied to each PIL image in __getitem__.
        paths_file: Pickle file holding a list of image paths. Defaults to
            ``"./image_paths.pkl"``, the file written earlier in this script.
            Only load pickles you created yourself: unpickling untrusted data
            can execute arbitrary code.
    """

    def __init__(self, transform=None, paths_file="./image_paths.pkl"):
        super().__init__()
        self.transform = transform
        with open(paths_file, "rb") as fr:
            self.data_paths = pickle.load(fr)

        # Keep only images whose mode is already "RGB"; other modes are skipped.
        # The `with` block closes each file right after its header is read.
        # Previously one handle per image stayed open until garbage collection.
        self.data = []
        for image_path in self.data_paths:
            with Image.open(image_path) as img:
                is_rgb = img.mode == "RGB"
            if is_rgb:
                self.data.append(image_path)

    def __len__(self):
        """Return the number of RGB images kept by __init__."""
        return len(self.data)

    def __getitem__(self, index):
        """Load and optionally transform the image at *index*, with metadata."""
        image_path = self.data[index]
        # convert() returns an in-memory copy, so the file can be closed here.
        # Every kept path is already RGB, so this only changes when the file
        # handle is released, not the pixels.
        with Image.open(image_path) as src:
            img = src.convert("RGB")

        if self.transform:
            img = self.transform(img)

        return {
            "idx": index,
            "image_path": image_path,
            "img": img,
        }
python 复制代码
# Dataset over every RGB training image, preprocessed with `transform`
# (used here only to extract embeddings).
valid_dataset = ImageDataset(transform)
python 复制代码
# Notebook display: number of RGB images kept by ImageDataset.
len(valid_dataset)
python 复制代码
# Batches of 64 in dataset order (no shuffling), so embeddings come out in the
# same order as valid_dataset.
valid_dataloader = DataLoader(
    dataset=valid_dataset,
    shuffle=False,
    batch_size=64,
)
python 复制代码
def load_model():
    """Return an ImageNet-pretrained ResNet-18 moved to ``device``, in eval mode.

    torchvision >= 0.13 deprecates ``pretrained=True`` in favour of the
    ``weights=`` enum. ``IMAGENET1K_V1`` is the checkpoint that
    ``pretrained=True`` selected, so both branches load the same weights.
    The fallback keeps older torchvision releases working.

    Uses the module-level ``device`` chosen earlier in the script.
    """
    weights_enum = getattr(models, "ResNet18_Weights", None)
    if weights_enum is not None:
        model = models.resnet18(weights=weights_enum.IMAGENET1K_V1)
    else:
        model = models.resnet18(pretrained=True)
    model.to(device)
    # eval(): BatchNorm uses its running statistics, so embeddings do not
    # depend on batch composition.
    model.eval()
    return model
python 复制代码
# Build the backbone once; it is reused for every batch of feature extraction.
model = load_model()
python 复制代码
# Notebook display: echo the module tree (the printed repr follows).
model
复制代码
ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
  )
  (layer2): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (downsample): Sequential(
        (0): Conv2d(64, 128, kernel_size=(1, 1), stride=(2, 2), bias=False)
        (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (1): BasicBlock(
      (conv1): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
  )
  (layer3): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(128, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (downsample): Sequential(
        (0): Conv2d(128, 256, kernel_size=(1, 1), stride=(2, 2), bias=False)
        (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (1): BasicBlock(
      (conv1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
  )
  (layer4): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(256, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (downsample): Sequential(
        (0): Conv2d(256, 512, kernel_size=(1, 1), stride=(2, 2), bias=False)
        (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (1): BasicBlock(
      (conv1): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
  )
  (avgpool): AdaptiveAvgPool2d(output_size=(1, 1))
  (fc): Linear(in_features=512, out_features=1000, bias=True)
)
python 复制代码
def feature_extract(model, x):
    """Run *x* through a ResNet backbone and return pooled, flattened features.

    Applies the same stage sequence as the ResNet forward pass but stops
    before the final ``fc`` layer. The result is the average-pooled output of
    ``layer4`` flattened to (batch, channels), i.e. 512 values per image for
    ResNet-18.

    Args:
        model: A ResNet-like module exposing conv1, bn1, relu, maxpool,
            layer1-layer4 and avgpool.
        x: Image batch tensor, e.g. (N, 3, 224, 224) from the DataLoader.

    Returns:
        Tensor of shape (N, C), where C is layer4's output channel count.
    """
    stages = (
        model.conv1,
        model.bn1,
        model.relu,
        model.maxpool,
        model.layer1,
        model.layer2,
        model.layer3,
        model.layer4,
        model.avgpool,
    )
    out = x
    for stage in stages:
        out = stage(out)
    return torch.flatten(out, 1)
python 复制代码
# Embed every image in valid_dataloader with the ResNet backbone.
# - torch.no_grad(): inference only, so activations are not retained for a
#   backward pass, which saves memory and time.
# - Batch fields get their own names. This loop used to rebind `image_paths`,
#   overwriting the global path list with the last batch's paths.
# - Dataset indices are stored as Python ints rather than 0-d tensors.
feature_list = []             # one 1-D numpy array (512 values) per image
feature_index_list = []       # dataset index of each embedding
feature_image_path_list = []  # image path of each embedding (Milvus primary key)
with torch.no_grad():
    for batch in tqdm(valid_dataloader):
        batch_features = feature_extract(model, batch["img"].to(device))
        feature_list.extend(batch_features.cpu().numpy())
        feature_index_list.extend(batch["idx"].tolist())
        feature_image_path_list.extend(batch["image_path"])
python 复制代码
# Column-ordered payload matching the collection schema below:
# [image paths (primary key), embedding vectors].
entities = [feature_image_path_list, feature_list]
python 复制代码
# Notebook display: number of extracted embeddings (one per image).
len(feature_list)
python 复制代码
# Notebook display: the image-path column of the insert payload.
entities[0]
python 复制代码
# Collection schema. The image path (VARCHAR, up to 512 chars) is a
# caller-supplied primary key (auto_id=False). Each row also carries one
# 512-dimensional float vector, the width of the ResNet-18 pooled features.
# Descriptions are Chinese: "image path", "vector representing the image",
# "table for image-to-image search".
fields = [
    FieldSchema(
        name="image_path",
        dtype=DataType.VARCHAR,
        max_length=512,
        is_primary=True,
        auto_id=False,
        description="图片路径",
    ),
    FieldSchema(
        name="embeddings",
        dtype=DataType.FLOAT_VECTOR,
        dim=512,
        is_primary=False,
        description="向量表示图片",
    ),
]
schema = CollectionSchema(fields=fields, description="用于图生图的表")
python 复制代码
# Open a Milvus connection under the alias "power_image_search"; later calls
# reference this alias via `using=`. Host, port and database can be overridden
# with the MILVUS_HOST, MILVUS_PORT and MILVUS_DB environment variables. The
# defaults are the previously hard-coded values.
connections.connect(
    "power_image_search",
    host=os.environ.get("MILVUS_HOST", "ljxwtl.cn"),
    port=int(os.environ.get("MILVUS_PORT", "19530")),
    db_name=os.environ.get("MILVUS_DB", "power_image_search"),
)
python 复制代码
# Handle to the "image_to_image" collection on the "power_image_search"
# connection, created from `schema` if it does not exist yet. Strong
# consistency makes freshly inserted rows visible to searches.
table = Collection(
    name="image_to_image",
    schema=schema,
    using="power_image_search",
    consistency_level="Strong",
)
python 复制代码
# Insert rows in batches instead of one insert RPC per image (previously
# ~6.5k round trips). Each batch is column-ordered like the schema:
# [image paths, embeddings].
# NOTE(review): Milvus insert does not deduplicate primary keys, so re-running
# this cell adds duplicate rows.
insert_batch_size = 1000
for batch_start in range(0, len(feature_image_path_list), insert_batch_size):
    batch_end = batch_start + insert_batch_size
    table.insert([
        feature_image_path_list[batch_start:batch_end],
        feature_list[batch_start:batch_end],
    ])
python 复制代码
# Persist the inserted rows (seal in-memory segments) before counting them below.
table.flush()
python 复制代码
# Notebook display: row count of the collection (6552 in the run shown below).
table.num_entities

6552

python 复制代码
# IVF_FLAT index using L2 distance: vectors are partitioned into nlist=128
# clusters, and a search scans only the clusters selected by nprobe.
index = dict(
    index_type="IVF_FLAT",
    metric_type="L2",
    params=dict(nlist=128),
)
python 复制代码
# Build the IVF_FLAT index on the vector field.
table.create_index(field_name="embeddings", index_params=index)
python 复制代码
# Load the collection into memory; Milvus requires this before search().
table.load()
python 复制代码
# Query with the embedding at row 1. Slicing [1:2] keeps it a one-element
# list, i.e. a batch of one query vector.
vectors_to_search = entities[-1][1:2]
# Must use the same metric as the index; probe 10 of the 128 IVF clusters.
search_params = dict(metric_type="L2", params=dict(nprobe=10))
python 复制代码
# Time only the search call. perf_counter() is a monotonic, high-resolution
# clock intended for measuring intervals, whereas time.time() can jump if the
# system clock is adjusted. Elapsed seconds = end_time - start_time.
start_time = time.perf_counter()
result = table.search(vectors_to_search, "embeddings", search_params, limit=5, output_fields=["image_path"])
end_time = time.perf_counter()
python 复制代码
# Print every hit of every query: the hit's own string form plus its stored
# image_path field.
for hit in (h for query_hits in result for h in query_hits):
    print(f"hit: {hit}, image_path field: {hit.entity.get('image_path')}")
python 复制代码
# Show the image whose embedding was used as the query (row 1 of the payload).
query_image_path = entities[0][1]
img_data = plt.imread(query_image_path)
plt.imshow(img_data)
plt.show()
python 复制代码
# Show a specific image for comparison (presumably one of the hits printed
# above). The path was hard-coded with a Windows backslash, which only
# resolves on Windows. os.path.join produces the identical string there and
# a valid path on POSIX systems.
img_data = plt.imread(os.path.join("./flower_data/train/1", "image_06766.jpg"))
plt.imshow(img_data)
plt.show()
相关推荐
萱仔学习自我记录2 小时前
PEFT库和transformers库在NLP大模型中的使用和常用方法详解
人工智能·机器学习
BulingQAQ4 小时前
论文阅读:PET/CT Cross-modal medical image fusion of lung tumors based on DCIF-GAN
论文阅读·深度学习·生成对抗网络·计算机视觉·gan
hsling松子4 小时前
使用PaddleHub智能生成,献上浓情国庆福
人工智能·算法·机器学习·语言模型·paddlepaddle
正在走向自律5 小时前
机器学习框架
人工智能·机器学习
好吃番茄5 小时前
U mamba配置问题;‘KeyError: ‘file_ending‘
人工智能·机器学习
CV-King6 小时前
opencv实战项目(三十):使用傅里叶变换进行图像边缘检测
人工智能·opencv·算法·计算机视觉
禁默6 小时前
2024年计算机视觉与艺术研讨会(CVA 2024)
人工智能·计算机视觉
slomay7 小时前
关于对比学习(简单整理
经验分享·深度学习·学习·机器学习
whaosoft-1437 小时前
大模型~合集3
人工智能
Dream-Y.ocean7 小时前
文心智能体平台AgenBuilder | 搭建智能体:情感顾问叶晴
人工智能·智能体