使用 SQLite 数据库筛选人脸信息

主要用于筛选人脸信息。例如:编号为 0 的人物文件夹里有很多张属于 0 的人脸照片,但同时也掺杂了一些非常模糊的照片或其他人的照片。通过这个方法,可以把这些模糊照片和其他人的人脸排序到最后,这样清理时就不需要到处寻找不合格的照片了。

import os
import shutil

import numpy as np
from sklearn.metrics.pairwise import euclidean_distances
from PIL import Image
import torch
import torchvision.transforms as transforms
from facenet_pytorch import InceptionResnetV1
import sqlite3
import threading

# 1. Load the pretrained face feature-extraction model (facenet_pytorch
#    InceptionResnetV1 with 'vggface2' weights), put it in eval mode and move
#    it to the GPU when CUDA is available, otherwise the CPU.
# NOTE(review): this runs at import time, so importing this module loads the
# model weights as a side effect (presumably downloading them on first use —
# confirm against facenet_pytorch behavior).
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = InceptionResnetV1(pretrained='vggface2').eval().to(device)

# 2. Image preprocessing applied to every image before feature extraction.
transform = transforms.Compose([
    transforms.Resize((160, 160)),  # FaceNet input size
    transforms.ToTensor(),
    # (x - 0.5) / 0.5 per channel: maps ToTensor's [0, 1] range to [-1, 1].
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
])


# 3. Extract the feature vector of a single image.
def extract_feature(image_path):
    """Return the model's embedding of the image at ``image_path``.

    The image is converted to RGB, preprocessed with the module-level
    ``transform``, and passed through ``model`` on ``device`` with gradient
    tracking disabled.

    Args:
        image_path: Path to an image file readable by PIL.

    Returns:
        A 1-D numpy array: the flattened model output for this image.

    Raises:
        Whatever ``PIL.Image.open`` / ``convert`` raise for missing or
        unreadable files (e.g. ``FileNotFoundError``,
        ``PIL.UnidentifiedImageError``).
    """
    # PIL opens files lazily and keeps the handle open; the context manager
    # releases it. convert() decodes into a new, independent image, so it
    # stays usable after the source file is closed.
    with Image.open(image_path) as img:
        rgb = img.convert('RGB')
    batch = transform(rgb).unsqueeze(0).to(device)  # add a batch dimension
    with torch.no_grad():
        feature = model(batch).cpu().numpy().flatten()
    return feature


# 4. Create the SQLite database.
def create_database(db_path):
    """Create the ``features`` table in the SQLite file at ``db_path``.

    Safe to call repeatedly: the table is only created if it does not exist.
    Rows are keyed by the (person_id, image_path) pair.

    Args:
        db_path: Path of the SQLite database file. The file is created if
            missing; its parent directory must already exist.
    """
    conn = sqlite3.connect(db_path)
    try:
        conn.execute('''
            CREATE TABLE IF NOT EXISTS features (
                person_id TEXT,
                image_path TEXT,
                feature_vector BLOB,
                PRIMARY KEY (person_id, image_path)
            )
        ''')
        conn.commit()
    finally:
        # Release the connection even if the DDL or commit fails.
        conn.close()


# 5. Save a feature vector to the database.
def save_feature_to_db(db_path, person_id, image_path, feature):
    """Insert one image's feature vector unless that row already exists.

    Re-running the script after an interruption must not fail on the
    (person_id, image_path) primary key. Existing rows are therefore skipped
    with a message instead of being re-inserted.

    Args:
        db_path: Path of the SQLite database holding the ``features`` table.
        person_id: Identifier of the person the image belongs to.
        image_path: Path of the source image.
        feature: Array-like feature vector. It is stored as raw float32 bytes
            so it round-trips with ``np.frombuffer(..., dtype=np.float32)``.
    """
    conn = sqlite3.connect(db_path)
    try:
        cursor = conn.cursor()
        # Check for an existing (person_id, image_path) row first, to avoid
        # primary-key conflicts when the script is re-run after an interruption.
        cursor.execute("""
           SELECT COUNT(*) FROM features
           WHERE person_id = ? AND image_path = ? """, (person_id, image_path))
        if cursor.fetchone()[0]:
            print(f"Feature for {person_id} - {image_path} already exists,  skipping")
            return
        # Pin the on-disk dtype to what the readers decode.
        feature_blob = np.asarray(feature, dtype=np.float32).tobytes()
        cursor.execute('''
            INSERT INTO features (person_id, image_path, feature_vector)
            VALUES (?, ?, ?)
        ''', (person_id, image_path, feature_blob))
        conn.commit()
    finally:
        # Close on every path: after an insert, after a skip, and on errors.
        conn.close()


# 6. Process one person's folder: extract features and save them to the database.
def process_folder(db_path, folder_path, person_id):
    """Extract and store the feature vector of every image file in a folder.

    Only files whose names end in .png, .jpg, .jpeg, .bmp or .gif
    (case-insensitive) are processed. A failure on one file (e.g. a corrupt
    image) is reported together with its path and skipped, so the rest of the
    folder is still processed.

    Args:
        db_path: Path of the SQLite feature database.
        folder_path: Directory containing this person's images.
        person_id: Identifier stored with each feature row.
    """
    image_exts = ('.png', '.jpg', '.jpeg', '.bmp', '.gif')
    for image_name in os.listdir(folder_path):
        if not image_name.lower().endswith(image_exts):
            continue  # not an image file
        image_path = os.path.join(folder_path, image_name)
        try:
            feature = extract_feature(image_path)
            save_feature_to_db(db_path, person_id, image_path, feature)
        except Exception as e:  # best-effort: one bad file must not stop the run
            # Report which file failed so it can be located and cleaned up.
            print(f"Failed to process {image_path}: {e}")


# 7. Get a person's mean feature vector from the database.
def get_avg_feature(db_path, person_id):
    """Return the element-wise mean of all stored feature vectors of a person.

    Args:
        db_path: Path of the SQLite feature database.
        person_id: Identifier whose rows are averaged.

    Returns:
        A 1-D numpy array (float32). Returns ``None`` when the person has no
        stored rows; averaging an empty list would give NaN plus a
        RuntimeWarning.
    """
    conn = sqlite3.connect(db_path)
    try:
        rows = conn.execute(
            'SELECT feature_vector FROM features WHERE person_id = ?',
            (person_id,),
        ).fetchall()
    finally:
        conn.close()

    if not rows:
        return None
    # Decode each blob as raw float32 bytes into one (n_rows, dim) matrix.
    features = np.stack([np.frombuffer(blob, dtype=np.float32) for (blob,) in rows])
    return features.mean(axis=0)


# 8. Sort a person's images by Euclidean distance and copy them under ranked names.
def sort_and_rename_images(db_path, out_path, person_id):
    """Copy a person's images to ``out_path/<person_id>`` ranked by distance.

    Each image's feature vector is compared with the mean feature vector of
    all of that person's images. Images are copied as ``0000<ext>``,
    ``0001<ext>``, ... in order of increasing Euclidean distance, so images
    far from the mean (e.g. blurry faces or other people) end up last.

    The source file extension is kept, so a PNG is not saved under a ``.jpg``
    name. Files without an extension fall back to ``.jpg``. Sources are
    copied, not moved.

    Args:
        db_path: Path of the SQLite feature database.
        out_path: Root output directory. A sub-folder named ``person_id`` is
            created when the person has at least one stored image.
        person_id: Identifier whose images are sorted.
    """
    conn = sqlite3.connect(db_path)
    try:
        rows = conn.execute(
            'SELECT image_path, feature_vector FROM features WHERE person_id = ?',
            (person_id,),
        ).fetchall()
    finally:
        conn.close()
    if not rows:
        return  # nothing stored for this person

    image_paths = [path for path, _ in rows]
    # One (n_images, dim) matrix. The mean is taken from these same rows
    # rather than re-querying the database.
    features = np.stack([np.frombuffer(blob, dtype=np.float32) for _, blob in rows])
    avg_feature = features.mean(axis=0)
    distances = np.linalg.norm(features - avg_feature, axis=1)

    person_dir = os.path.join(out_path, str(person_id))
    os.makedirs(person_dir, exist_ok=True)
    # Stable sort: images with equal distance keep their database order.
    for idx, order in enumerate(np.argsort(distances, kind='stable')):
        image_path = image_paths[order]
        ext = os.path.splitext(image_path)[1] or '.jpg'
        shutil.copy(image_path, os.path.join(person_dir, f"{idx:04d}{ext}"))


# 9. Main entry point.
def main(db_path=r'D:\FS_project2\Feature_extraction\sql_database\features.db2',
         base_path=r'D:\FS_project2\Feature_extraction\peopel_crop',
         out_path=r'D:\FS_project2\Feature_extraction\out'):
    """Extract features for every person folder, then write ranked copies.

    ``base_path`` must contain one sub-folder per person; the folder name is
    used as ``person_id``.

    Step 1 stores each image's feature vector in the SQLite database at
    ``db_path``. Step 2 copies each person's images to
    ``out_path/<person_id>``, ordered by distance to that person's mean
    feature vector.

    Args:
        db_path: SQLite database file. Its parent directory is created if
            missing. Defaults to the project's hard-coded location.
        base_path: Root folder containing one sub-folder per person.
        out_path: Root folder for the sorted copies.
    """
    db_dir = os.path.dirname(db_path)
    if db_dir:
        # sqlite3.connect cannot create missing parent directories.
        os.makedirs(db_dir, exist_ok=True)
    create_database(db_path)

    # List the person folders once and reuse the list for both passes.
    person_folders = [
        name for name in os.listdir(base_path)
        if os.path.isdir(os.path.join(base_path, name))
    ]

    # Step 1: extract features and save them to the database.
    for folder in person_folders:
        process_folder(db_path, os.path.join(base_path, folder), folder)
        print(f"Processed folder: {folder}")

    # Step 2: sort the images and copy them under rank-ordered names.
    for folder in person_folders:
        sort_and_rename_images(db_path, out_path, folder)
        print(f"Sorted and renamed folder: {folder}")


# Run the pipeline only when this file is executed as a script.
if __name__ == "__main__":
    main()
相关推荐
库库林_沙琪马1 小时前
Redis 持久化:从零到掌握
数据库·redis·缓存
牵牛老人3 小时前
Qt中使用QPdfWriter类结合QPainter类绘制并输出PDF文件
数据库·qt·pdf
卡西里弗斯奥4 小时前
【达梦数据库】dblink连接[SqlServer/Mysql]报错处理
数据库·mysql·sqlserver·达梦
m0_748255414 小时前
vscode配置django环境并创建django项目(全图文操作)
vscode·django·sqlite
温柔小胖5 小时前
sql注入之python脚本进行时间盲注和布尔盲注
数据库·sql·网络安全
杨俊杰-YJ5 小时前
MySQL 主从复制原理及其工作过程
数据库·mysql
一个儒雅随和的男子6 小时前
MySQL的聚簇索引与非聚簇索引
数据库·mysql
V+zmm101348 小时前
基于微信小程序的家政服务预约系统的设计与实现(php论文源码调试讲解)
java·数据库·微信小程序·小程序·毕业设计
roman_日积跬步-终至千里8 小时前
【分布式理论14】分布式数据库存储:分表分库、主从复制与数据扩容策略
数据库·分布式
hadage2338 小时前
--- Mysql事务 ---
数据库·mysql