深度学习之使用Milvus向量数据库实战图搜图

python 复制代码
import torch
from torchvision import models,transforms
from torch.utils.data import Dataset , DataLoader
import os
import pickle
from PIL import Image
from tqdm import tqdm
from pymilvus import (
FieldSchema,
DataType,
db,
connections,
CollectionSchema,
Collection
)
import time
import matplotlib.pyplot as plt
python 复制代码
# Run on the GPU when one is visible to PyTorch, otherwise on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
python 复制代码
# Preprocessing shared by indexing and querying: resize to a fixed 256x256
# (aspect ratio is not preserved), take the central 224x224 crop, convert to a
# tensor, then normalize each channel with the standard ImageNet mean/std used
# by torchvision's pretrained models.
transform = transforms.Compose(
    [
        transforms.Resize((256, 256)),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)
python 复制代码
image_dir = "./flower_data/train"
# One sub-directory per class (presumably one per flower species). Build the
# list from what is actually on disk rather than zipping against a hard-coded
# count of 102, which silently drops any extra entries. Stray files (e.g.
# .DS_Store) are skipped so the per-directory os.listdir below cannot fail.
# os.listdir returns entries in arbitrary order; sorting keeps the resulting
# dataset order reproducible across machines.
image_dirs = [
    f"{image_dir}/{name}"
    for name in sorted(os.listdir(image_dir))
    if os.path.isdir(f"{image_dir}/{name}")
]
python 复制代码
# Flatten every class directory into one list of image file paths, preserving
# directory order then per-directory listing order. A comprehension is used so
# no loop variable leaks into module scope; the previous loop variable `dir`
# shadowed the builtin, which broke `dir()` in every later notebook cell.
image_paths = [
    os.path.join(class_dir, name)
    for class_dir in image_dirs
    for name in os.listdir(class_dir)
]
python 复制代码
# Notebook cell: display the collected image file paths for inspection.
image_paths
python 复制代码
# Notebook cell: display the class directories the paths were gathered from.
image_dirs
python 复制代码
# Persist the path list so ImageDataset can reload it from ./image_paths.pkl.
with open("image_paths.pkl" , "wb" ) as fw:
    pickle.dump(image_paths, fw)
python 复制代码
class ImageDataset(Dataset):
    """Dataset over the RGB images listed in a pickled list of paths.

    Each item is a dict with keys ``"idx"`` (position in this dataset),
    ``"image_path"`` (the source file path) and ``"img"`` (the image, passed
    through ``transform`` when one is given, otherwise a PIL image).
    """

    def __init__(self, transform=None, paths_file="./image_paths.pkl"):
        """Load the path list and keep only images whose mode is ``"RGB"``.

        Args:
            transform: optional callable applied to each PIL image in
                ``__getitem__``.
            paths_file: pickle file holding a list of image paths. Defaults to
                the file this script writes, so existing callers are unchanged.
        """
        super().__init__()
        self.transform = transform
        # NOTE(review): pickle.load can execute arbitrary code; only point this
        # at a file produced by this script, never at untrusted input.
        with open(paths_file, "rb") as fr:
            self.data_paths = pickle.load(fr)

        # Non-RGB images (grayscale, CMYK, RGBA, ...) are skipped rather than
        # converted — presumably so every item matches a 3-channel Normalize.
        self.data = []
        for image_path in self.data_paths:
            # Use a context manager so each file is closed right after its
            # header is inspected instead of waiting for garbage collection.
            with Image.open(image_path) as img:
                if img.mode == "RGB":
                    self.data.append(image_path)

    def __len__(self):
        """Return the number of RGB images kept."""
        return len(self.data)

    def __getitem__(self, index):
        """Return the dict for item ``index``; see the class docstring."""
        image_path = self.data[index]
        with Image.open(image_path) as img:
            # convert() decodes into a new, file-independent image, so the
            # file can be closed here. The mode was already checked to be RGB
            # in __init__, so this is a plain copy.
            img = img.convert("RGB")

        if self.transform:
            img = self.transform(img)

        return {
            "idx": index,
            "image_path": image_path,
            "img": img,
        }
python 复制代码
# Dataset of the RGB images listed in image_paths.pkl, preprocessed with
# `transform`. Despite the name, the images come from the train directory
# (image_dir), not a validation split.
valid_dataset = ImageDataset(transform=transform)
python 复制代码
# Notebook cell: number of images kept after the RGB-only filter.
len(valid_dataset)
python 复制代码
# Batches of 64 in fixed dataset order (no shuffling), so the extracted
# embeddings line up with dataset indices and paths.
valid_dataloader = DataLoader(dataset=valid_dataset, batch_size=64, shuffle=False)
python 复制代码
def load_model():
    """Return an ImageNet-pretrained ResNet-18 moved to ``device``, in eval mode.

    torchvision 0.13+ deprecates ``pretrained=True`` in favour of the
    ``weights=`` enum API, so the enum is used when it exists. Older torchvision
    has no ``ResNet18_Weights`` and falls back to ``pretrained=True``.
    ``IMAGENET1K_V1`` is the checkpoint ``pretrained=True`` maps to, so both
    paths load the same weights.
    """
    weights_enum = getattr(models, "ResNet18_Weights", None)
    if weights_enum is not None:
        model = models.resnet18(weights=weights_enum.IMAGENET1K_V1)
    else:
        model = models.resnet18(pretrained=True)
    model.to(device)
    # eval() makes BatchNorm use its stored running statistics, so the
    # extracted features do not depend on batch composition.
    model.eval()
    return model
python 复制代码
# Pretrained ResNet-18 used as the image feature extractor.
model = load_model()
python 复制代码
# Notebook cell: print the network architecture (its output follows below).
model
复制代码
ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
  )
  (layer2): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (downsample): Sequential(
        (0): Conv2d(64, 128, kernel_size=(1, 1), stride=(2, 2), bias=False)
        (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (1): BasicBlock(
      (conv1): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
  )
  (layer3): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(128, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (downsample): Sequential(
        (0): Conv2d(128, 256, kernel_size=(1, 1), stride=(2, 2), bias=False)
        (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (1): BasicBlock(
      (conv1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
  )
  (layer4): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(256, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (downsample): Sequential(
        (0): Conv2d(256, 512, kernel_size=(1, 1), stride=(2, 2), bias=False)
        (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (1): BasicBlock(
      (conv1): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
  )
  (avgpool): AdaptiveAvgPool2d(output_size=(1, 1))
  (fc): Linear(in_features=512, out_features=1000, bias=True)
)
python 复制代码
def feature_extract(model, x):
    """Run ``x`` through a ResNet up to global average pooling and flatten.

    This is the standard ResNet forward pass with the final ``fc``
    classification layer left out. The result is one feature vector per
    input; for ResNet-18 that is 512 values (``fc.in_features``).
    """
    backbone = (
        model.conv1,
        model.bn1,
        model.relu,
        model.maxpool,
        model.layer1,
        model.layer2,
        model.layer3,
        model.layer4,
        model.avgpool,
    )
    out = x
    for stage in backbone:
        out = stage(out)
    # Collapse everything after the batch dimension into one vector per sample.
    return torch.flatten(out, 1)
python 复制代码
# Run every image through the backbone and collect, in dataloader order, the
# embedding, the dataset index and the source path of each image.
feature_list = []
feature_index_list = []
feature_image_path_list = []
# Inference only: no_grad stops autograd from recording the forward pass and
# saving activations for a backward pass that never happens.
with torch.no_grad():
    for batch in tqdm(valid_dataloader):
        imgs = batch["img"].to(device)
        # Distinct name so the module-level `image_paths` list built earlier is
        # not overwritten with the last batch's paths.
        batch_paths = batch["image_path"]
        feature = feature_extract(model, imgs).cpu().numpy()
        feature_list.extend(feature)
        # The default collate turns the int indices into a tensor; store ints.
        feature_index_list.extend(batch["idx"].tolist())
        feature_image_path_list.extend(batch_paths)
python 复制代码
# Column-oriented data in schema field order: image paths, then embeddings.
entities = [feature_image_path_list, feature_list]
python 复制代码
# Notebook cell: number of embeddings extracted (one per dataset image).
len(feature_list)
python 复制代码
# Notebook cell: display the image-path column of `entities`.
entities[0]
python 复制代码
# Collection schema: the image path is the (caller-supplied) primary key, and
# the embedding is a 512-float vector matching ResNet-18's pooled features.
# Descriptions: "image path", "vector representing the image", and
# "table for image-to-image".
fields = [
    FieldSchema(
        name="image_path",
        dtype=DataType.VARCHAR,
        max_length=512,
        is_primary=True,
        auto_id=False,
        description="图片路径",
    ),
    FieldSchema(
        name="embeddings",
        dtype=DataType.FLOAT_VECTOR,
        dim=512,
        is_primary=False,
        description="向量表示图片",
    ),
]
schema = CollectionSchema(fields, description="用于图生图的表")
python 复制代码
# Open a Milvus connection registered under the alias "power_image_search"
# (the alias later passed as using=...). Host and port default to the original
# values and can be overridden through MILVUS_HOST / MILVUS_PORT, so the
# notebook can run against a local server without editing code.
connections.connect(
    "power_image_search",
    host=os.environ.get("MILVUS_HOST", "ljxwtl.cn"),
    port=int(os.environ.get("MILVUS_PORT", "19530")),
    db_name="power_image_search",
)
python 复制代码
# Handle to the "image_to_image" collection (presumably created from `schema`
# if it does not exist yet — TODO confirm against pymilvus docs) on the
# "power_image_search" connection, with strong read consistency.
table = Collection(
    name="image_to_image",
    schema=schema,
    using="power_image_search",
    consistency_level="Strong",
)
python 复制代码
# Insert in column-format batches (one list per schema field, in field order)
# instead of one row per request; this turns thousands of round trips into a
# handful. NOTE(review): image_path is a caller-supplied primary key
# (auto_id=False). Verify how the server treats re-running this cell, since it
# would insert the same keys again.
insert_batch_size = 1000
for start in range(0, len(feature_image_path_list), insert_batch_size):
    stop = start + insert_batch_size
    table.insert([
        feature_image_path_list[start:stop],
        feature_list[start:stop],
    ])
python 复制代码
# Flush pending inserts to storage — presumably so num_entities (next cell)
# reflects everything inserted above.
table.flush()
python 复制代码
# Notebook cell: number of rows stored in the collection (6552 in the run
# shown below).
table.num_entities

6552

python 复制代码
# IVF_FLAT index with Euclidean (L2) distance, clustering the vectors into
# 128 lists; searches must use the same metric.
index = dict(
    index_type="IVF_FLAT",
    metric_type="L2",
    params=dict(nlist=128),
)
python 复制代码
# Build the IVF_FLAT / L2 index described by `index` on the vector field.
table.create_index("embeddings",index_params=index)
python 复制代码
# Load the collection on the server so it can be searched.
table.load()
python 复制代码
# Use the second stored embedding as the query. Slicing with [1:2] keeps it
# wrapped in a list, since search() takes a batch of query vectors.
vectors_to_search = entities[-1][1:2]
# Search with the same L2 metric as the index, probing 10 of its 128 lists.
search_params = dict(
    metric_type="L2",
    params=dict(nprobe=10),
)
python 复制代码
# Time the top-5 nearest-neighbour search, returning each hit's image path.
# perf_counter is monotonic and high-resolution, unlike time.time, which can
# jump when the wall clock is adjusted.
start_time = time.perf_counter()
result = table.search(vectors_to_search, "embeddings", search_params, limit=5, output_fields=["image_path"])
end_time = time.perf_counter()
print(f"search took {end_time - start_time:.4f} s")
python 复制代码
# Print every hit for each query vector together with its stored image path.
for hits in result:
    for hit in hits:
        print(f"hit: {hit}, image_path field: {hit.entity.get('image_path')}")
python 复制代码
# Show the query image: entities[0][1] is the path paired with the embedding
# used as the search vector (entities[-1][1:2]).
img_data = plt.imread(entities[0][1])
plt.imshow(img_data)
plt.show()
python 复制代码
# Show flower_data/train/1/image_06766.jpg (presumably one of the hits printed
# above — TODO confirm). os.path.join replaces the original Windows-only "\\"
# separator, so the path also resolves on Linux/macOS.
img_data = plt.imread(os.path.join("./flower_data/train", "1", "image_06766.jpg"))
plt.imshow(img_data)
plt.show()
相关推荐
TL滕7 分钟前
Datawhale AI冬令营 动手学AI Agent
人工智能·笔记·学习·aigc
吕小明么10 分钟前
精读DeepSeek v3技术文档的心得感悟
人工智能·神经网络·aigc·agi
hunteritself42 分钟前
再谈ChatGPT降智:已蔓延到全端,附解决方案!
人工智能·gpt·算法·机器学习·chatgpt·openai
Trouvaille ~42 分钟前
【机器学习】分而知变,积而见道:微积分中的世界之思
人工智能·python·机器学习·ai·数据分析·微积分·梯度下降
田梓燊42 分钟前
机器学习基本概念,基本步骤,分类,简单理解,线性模型
人工智能·机器学习·分类
morning_judger43 分钟前
【AI大模型系列】prompt提示词(二)
人工智能·prompt
AIGC大时代2 小时前
如何判断一个学术论文是否具有真正的科研价值?ChatGPT如何提供帮助?
大数据·人工智能·物联网·chatgpt·aigc
岁月如歌,青春不败3 小时前
HMSC联合物种分布模型
开发语言·人工智能·python·深度学习·r语言
海域云赵从友3 小时前
香港 GPU 服务器托管引领 AI 创新,助力 AI 发展
运维·服务器·人工智能
四口鲸鱼爱吃盐4 小时前
Pytorch | 利用GRA针对CIFAR10上的ResNet分类器进行对抗攻击
人工智能·pytorch·python·深度学习·计算机视觉