python
import torch
from torchvision import models,transforms
from torch.utils.data import Dataset , DataLoader
import os
import pickle
from PIL import Image
from tqdm import tqdm
from pymilvus import (
FieldSchema,
DataType,
db,
connections,
CollectionSchema,
Collection
)
import time
import matplotlib.pyplot as plt
python
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
python
# Preprocessing applied to every image before feature extraction:
# squash-resize to 256x256, center-crop 224x224, convert to a float tensor,
# then normalize each channel with the ImageNet mean/std values below.
transform = transforms.Compose(
    [
        transforms.Resize((256, 256)),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)
python
image_dir = "./flower_data/train"
# One sub-directory per class. Non-directories are skipped, the class count is
# not hard-coded, and the list is sorted so row order (and every downstream
# index) is the same on every OS. The "<image_dir>/<class>" f-string format is
# kept because these paths end up as the Milvus primary key.
image_dirs = sorted(
    f"{image_dir}/{name}"
    for name in os.listdir(image_dir)
    if os.path.isdir(os.path.join(image_dir, name))
)
python
image_paths = []
for class_dir in image_dirs:
    # `class_dir` rather than `dir`, which would shadow the builtin.
    for name in sorted(os.listdir(class_dir)):
        image_paths.append(os.path.join(class_dir, name))
python
image_paths
python
image_dirs
python
# Persist the path list; ImageDataset reads it back from this file.
with open("image_paths.pkl", "wb") as fw:
    pickle.dump(image_paths, fw)
python
class ImageDataset(Dataset):
    """Dataset over the RGB images listed in a pickled list of file paths.

    Each item is a dict with keys ``"idx"`` (position in this dataset),
    ``"image_path"`` and ``"img"`` (the transformed image, or a PIL image when
    no transform is given). Paths whose image mode is not ``"RGB"`` are dropped
    when the dataset is built.
    """

    def __init__(self, transform=None, paths_file="./image_paths.pkl"):
        """
        Args:
            transform: optional callable applied to each PIL image in
                ``__getitem__``.
            paths_file: pickle file holding a list of image paths. Only load
                pickles you produced yourself: ``pickle.load`` can execute
                arbitrary code from a crafted file.
        """
        super().__init__()
        self.transform = transform
        with open(paths_file, "rb") as fr:
            self.data_paths = pickle.load(fr)
        self.data = []
        for image_path in self.data_paths:
            # The context manager closes each file right after the mode check
            # instead of leaving one handle open per scanned image.
            with Image.open(image_path) as img:
                is_rgb = img.mode == "RGB"
            if is_rgb:
                self.data.append(image_path)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        image_path = self.data[index]
        with Image.open(image_path) as src:
            # copy() loads the pixels into an independent image that stays
            # valid after the source file is closed.
            img = src.copy()
        if self.transform:
            img = self.transform(img)
        return {
            "idx": index,
            "image_path": image_path,
            "img": img,
        }
python
# Every RGB image listed in image_paths.pkl, preprocessed with `transform`.
valid_dataset = ImageDataset(transform=transform)
python
len(valid_dataset)
python
# No shuffling, so batches come out in dataset-index order.
valid_dataloader = DataLoader(
    dataset=valid_dataset,
    shuffle=False,
    batch_size=64,
)
python
def load_model():
    """Return a ResNet-18 with ImageNet (IMAGENET1K_V1) weights, on ``device``, in eval mode.

    Uses torchvision's ``weights=`` API when it exists (``pretrained=`` is
    deprecated since torchvision 0.13) and falls back to ``pretrained=True``
    on older torchvision. Weights may be downloaded on first use. Reads the
    module-level ``device``.
    """
    weights_enum = getattr(models, "ResNet18_Weights", None)
    if weights_enum is not None:
        # IMAGENET1K_V1 is the checkpoint that pretrained=True selected.
        model = models.resnet18(weights=weights_enum.IMAGENET1K_V1)
    else:
        model = models.resnet18(pretrained=True)
    model.to(device)
    # eval(): BatchNorm uses running statistics, so features do not depend on
    # the other images in the batch.
    model.eval()
    return model
python
# Build the backbone used for feature extraction.
model = load_model()
python
# Display the module tree; the captured notebook output follows this cell.
model
# ResNet(
#   (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
#   (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
#   (relu): ReLU(inplace=True)
#   (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
#   (layer1): Sequential(
#     (0): BasicBlock(
#       (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
#       (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
#       (relu): ReLU(inplace=True)
#       (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
#       (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
#     )
#     (1): BasicBlock(
#       (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
#       (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
#       (relu): ReLU(inplace=True)
#       (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
#       (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
#     )
#   )
#   (layer2): Sequential(
#     (0): BasicBlock(
#       (conv1): Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
#       (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
#       (relu): ReLU(inplace=True)
#       (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
#       (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
#       (downsample): Sequential(
#         (0): Conv2d(64, 128, kernel_size=(1, 1), stride=(2, 2), bias=False)
#         (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
#       )
#     )
#     (1): BasicBlock(
#       (conv1): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
#       (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
#       (relu): ReLU(inplace=True)
#       (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
#       (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
#     )
#   )
#   (layer3): Sequential(
#     (0): BasicBlock(
#       (conv1): Conv2d(128, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
#       (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
#       (relu): ReLU(inplace=True)
#       (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
#       (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
#       (downsample): Sequential(
#         (0): Conv2d(128, 256, kernel_size=(1, 1), stride=(2, 2), bias=False)
#         (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
#       )
#     )
#     (1): BasicBlock(
#       (conv1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
#       (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
#       (relu): ReLU(inplace=True)
#       (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
#       (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
#     )
#   )
#   (layer4): Sequential(
#     (0): BasicBlock(
#       (conv1): Conv2d(256, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
#       (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
#       (relu): ReLU(inplace=True)
#       (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
#       (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
#       (downsample): Sequential(
#         (0): Conv2d(256, 512, kernel_size=(1, 1), stride=(2, 2), bias=False)
#         (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
#       )
#     )
#     (1): BasicBlock(
#       (conv1): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
#       (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
#       (relu): ReLU(inplace=True)
#       (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
#       (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
#     )
#   )
#   (avgpool): AdaptiveAvgPool2d(output_size=(1, 1))
#   (fc): Linear(in_features=512, out_features=1000, bias=True)
# )
python
def feature_extract(model, x):
    """Run ``model``'s ResNet backbone on ``x`` and return flattened pooled features.

    Applies conv1 -> bn1 -> relu -> maxpool -> layer1..layer4 -> avgpool in
    that order, then flattens every dimension after the batch dimension. The
    final ``fc`` classifier is deliberately skipped; for resnet18 this yields
    the 512-wide vectors stored in the collection.
    """
    backbone = (
        model.conv1,
        model.bn1,
        model.relu,
        model.maxpool,
        model.layer1,
        model.layer2,
        model.layer3,
        model.layer4,
        model.avgpool,
    )
    out = x
    for stage in backbone:
        out = stage(out)
    return torch.flatten(out, 1)
python
feature_list = []
feature_index_list = []
feature_image_path_list = []
# Inference only: no_grad() skips building an autograd graph for every batch,
# which saves memory and time during extraction.
with torch.no_grad():
    for batch in tqdm(valid_dataloader):
        imgs = batch["img"].to(device)
        # A distinct name, so the global `image_paths` list built earlier is
        # not overwritten with the last batch's paths.
        batch_paths = batch["image_path"]
        feature = feature_extract(model, imgs)
        feature_list.extend(feature.cpu().numpy())
        # Store plain ints rather than 0-d tensors.
        feature_index_list.extend(batch["idx"].tolist())
        feature_image_path_list.extend(batch_paths)
python
# Column-ordered insert payload: [image paths, feature vectors], in the same
# field order as the collection schema defined next.
entities = [feature_image_path_list, feature_list]
python
len(feature_list)
python
entities[0]
python
# Collection schema:
# - image_path: VARCHAR primary key supplied by us (auto_id=False);
#   description text means "image path".
# - embeddings: 512-dim float vector, matching the flattened ResNet-18
#   feature size; description text means "vector representing the image".
fields = [
    FieldSchema(
        name="image_path",
        dtype=DataType.VARCHAR,
        description="图片路径",
        max_length=512,
        is_primary=True,
        auto_id=False,
    ),
    FieldSchema(
        name="embeddings",
        dtype=DataType.FLOAT_VECTOR,
        description="向量表示图片",
        is_primary=False,
        dim=512,
    ),
]
# Description text means "table used for image-to-image search".
schema = CollectionSchema(fields, description="用于图生图的表")
python
# Register a connection under the alias "power_image_search".
connections.connect(
    alias="power_image_search",
    host="ljxwtl.cn",
    port=19530,
    db_name="power_image_search",
)
python
# Collection handle bound to the connection alias above; `schema` is passed so
# pymilvus can create the collection if it does not exist yet.
table = Collection(
    name="image_to_image",
    schema=schema,
    using="power_image_search",
    consistency_level="Strong",
)
python
# Insert in column-ordered batches. One request per row costs thousands of
# network round trips; 1000 rows of 512 float32 values is about 2 MB per request.
insert_batch_size = 1000
for start in range(0, len(feature_image_path_list), insert_batch_size):
    stop = start + insert_batch_size
    table.insert(
        [
            feature_image_path_list[start:stop],
            feature_list[start:stop],
        ]
    )
python
# Flush so the inserted rows are persisted and counted by num_entities.
table.flush()
python
table.num_entities
# Output: 6552
python
# IVF_FLAT index on the vector field using Euclidean (L2) distance, with the
# vectors partitioned into nlist=128 clusters.
index = dict(
    index_type="IVF_FLAT",
    metric_type="L2",
    params={"nlist": 128},
)
python
table.create_index(field_name="embeddings", index_params=index)
python
# Load the collection into memory so it can be searched.
table.load()
python
# Query with the embedding in row 1. The slice keeps it as a one-element list
# because search takes a batch of query vectors.
vectors_to_search = entities[-1][1:2]
# Same metric as the index; probe 10 of the 128 IVF clusters.
search_params = {
    "metric_type": "L2",
    "params": {"nprobe": 10},
}
python
# perf_counter is monotonic and high-resolution, so it suits interval timing
# better than time.time().
start_time = time.perf_counter()
result = table.search(
    vectors_to_search,
    "embeddings",
    search_params,
    limit=5,
    output_fields=["image_path"],
)
end_time = time.perf_counter()
print(f"search took {(end_time - start_time) * 1000:.2f} ms")
python
# `result` holds one list of hits per query vector; print every neighbour
# together with the image_path returned via output_fields.
for query_hits in result:
    for match in query_hits:
        path = match.entity.get('image_path')
        print(f"hit: {match}, image_path field: {path}")
python
# Show the query image: row 1 of the path column, the same row whose embedding
# was used as the search vector.
img_data = plt.imread(entities[0][1])
plt.imshow(img_data)
plt.show()
python
# Show a retrieved neighbour. os.path.join uses the native separator, so the
# path resolves on any OS (a literal "\\" separator only works on Windows).
img_data = plt.imread(os.path.join("./flower_data/train", "1", "image_06766.jpg"))
plt.imshow(img_data)
plt.show()