使用 SQLite 数据库筛选人脸照片

主要用于筛选人脸照片。例如:编号为 0 的人的文件夹里有很多张属于 0 的人脸照片,但其中掺杂了一些非常模糊的照片或其他人的照片。本方法先为每张照片提取人脸特征向量,再计算它与该人所有照片平均特征之间的欧氏距离,并按距离从小到大排序。这样模糊的照片和其他人的照片就会排到最后,清理时不需要到处寻找不合格的照片。

复制代码
import os
import shutil

import numpy as np
from sklearn.metrics.pairwise import euclidean_distances
from PIL import Image
import torch
import torchvision.transforms as transforms
from facenet_pytorch import InceptionResnetV1
import sqlite3
import threading

# 1. Load the pretrained face feature-extraction model (InceptionResnetV1 with
#    VGGFace2 weights) in inference mode, on the GPU when one is available.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
model = InceptionResnetV1(pretrained='vggface2')
model = model.eval().to(device)

# 2. Image preprocessing: resize to the 160x160 FaceNet input size, convert to
#    a tensor, then normalize every channel with mean 0.5 / std 0.5.
transform = transforms.Compose(
    [
        transforms.Resize((160, 160)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5] * 3, std=[0.5] * 3),
    ]
)


# 3. Extract the feature vector of a single image
def extract_feature(image_path):
    """Return the model embedding of the image at *image_path*.

    The image is converted to RGB, preprocessed with the module-level
    ``transform``, and passed through ``model`` on ``device`` with gradient
    tracking disabled.

    Args:
        image_path: Path of an image file that PIL can open.

    Returns:
        The model output for the image, flattened to a 1-D numpy array.

    Raises:
        Whatever PIL raises when the file is missing or cannot be decoded
        (callers such as ``process_folder`` catch and report it).
    """
    # Close the file handle as soon as the pixels are decoded instead of
    # leaving it to the garbage collector; this script opens thousands of
    # images, and on Windows an open handle also blocks deleting the file.
    with Image.open(image_path) as img:
        rgb = img.convert('RGB')  # convert() loads the pixel data into a new image
    batch = transform(rgb).unsqueeze(0).to(device)  # add a batch dimension of 1
    with torch.no_grad():
        feature = model(batch).cpu().numpy().flatten()
    return feature


# 4. Create the SQLite database
def create_database(db_path):
    """Create the ``features`` table in the SQLite database at *db_path*.

    The table holds one row per (person_id, image_path) pair with the raw
    bytes of that image's feature vector. Because of ``IF NOT EXISTS`` the
    call is safe on an already-initialized database.

    The parent directory of *db_path* is created when missing; otherwise
    ``sqlite3.connect`` fails with "unable to open database file".

    Args:
        db_path: Filesystem path of the database file.
    """
    parent = os.path.dirname(db_path)
    if parent:  # empty for bare file names or ":memory:"
        os.makedirs(parent, exist_ok=True)

    conn = sqlite3.connect(db_path)
    try:
        conn.execute('''
        CREATE TABLE IF NOT EXISTS features (
            person_id TEXT,
            image_path TEXT,
            feature_vector BLOB,
            PRIMARY KEY (person_id, image_path)
        )
    ''')
        conn.commit()
    finally:
        # Release the connection even if table creation fails.
        conn.close()


# 5. Save a feature vector to the database
def save_feature_to_db(db_path, person_id, image_path, feature):
    """Store *feature* for (person_id, image_path) unless a row already exists.

    Existing rows are left untouched and a "skipping" message is printed.
    This lets the extraction step be re-run after an interruption without
    violating the primary key.

    Args:
        db_path: SQLite database containing the ``features`` table.
        person_id: Identifier of the person (the folder name in ``main``).
        image_path: Path of the source image; together with *person_id* it
            forms the primary key.
        feature: Array-like feature vector. It is stored as float32 bytes,
            the dtype this module's readers pass to ``np.frombuffer``.
    """
    # Force float32: the blob is decoded with dtype=np.float32 when read back,
    # so storing e.g. float64 bytes would silently produce garbage vectors.
    feature_blob = np.asarray(feature, dtype=np.float32).tobytes()

    conn = sqlite3.connect(db_path)
    try:
        # One statement does both the existence check and the insert; rowcount
        # is 0 when the (person_id, image_path) key was already present.
        cursor = conn.execute('''
            INSERT OR IGNORE INTO features (person_id, image_path, feature_vector)
            VALUES (?, ?, ?)
        ''', (person_id, image_path, feature_blob))
        conn.commit()
        if cursor.rowcount == 0:
            print(f"Feature for {person_id} - {image_path} already exists,  skipping")
    finally:
        # Close on every path (the original leaked the connection on skip).
        conn.close()


# 6. Process one person's folder: extract features and store them in the database
def process_folder(db_path, folder_path, person_id):
    """Extract and store a feature for every image file directly in *folder_path*.

    Files are selected by extension (.png/.jpg/.jpeg/.bmp/.gif, any case)
    and visited in sorted order. Images that already have a row for
    *person_id* are skipped *before* running the model, so resuming after an
    interruption does not redo the expensive extraction. A failure on one
    image (e.g. a corrupt file) is reported with its path and does not stop
    the rest of the folder.

    Args:
        db_path: SQLite database containing the ``features`` table.
        folder_path: Directory holding the images of one person.
        person_id: Identifier stored with each row (the folder name in ``main``).
    """
    # Load the already-processed paths once; set membership is O(1) per file.
    conn = sqlite3.connect(db_path)
    try:
        done = {
            row[0]
            for row in conn.execute(
                'SELECT image_path FROM features WHERE person_id = ?', (person_id,)
            )
        }
    finally:
        conn.close()

    for image_name in sorted(os.listdir(folder_path)):
        image_path = os.path.join(folder_path, image_name)
        # Skip non-image files.
        if not image_path.lower().endswith(('.png', '.jpg', '.jpeg', '.bmp', '.gif')):
            continue
        if image_path in done:
            continue
        # Deliberately best-effort: one unreadable image must not abort the
        # whole run, but report which file failed.
        try:
            feature = extract_feature(image_path)
            save_feature_to_db(db_path, person_id, image_path, feature)
        except Exception as e:
            print(f"Failed to process {image_path}: {e}")


# 7. Read a person's average feature vector from the database
def get_avg_feature(db_path, person_id):
    """Return the element-wise mean of all feature vectors stored for *person_id*.

    Each stored blob is decoded as float32 with ``np.frombuffer`` and the
    vectors are averaged with ``np.mean(..., axis=0)``.

    Args:
        db_path: SQLite database containing the ``features`` table.
        person_id: Person whose rows are averaged.

    Returns:
        The mean vector as a numpy array.
    """
    connection = sqlite3.connect(db_path)
    blobs = connection.execute(
        'SELECT feature_vector FROM features WHERE person_id = ?', (person_id,)
    ).fetchall()
    connection.close()

    vectors = [np.frombuffer(blob, dtype=np.float32) for (blob,) in blobs]
    return np.mean(vectors, axis=0)


# 8. Rank a person's images by Euclidean distance to their mean feature and copy them out in that order
def sort_and_rename_images(db_path, out_path, person_id):
    """Copy *person_id*'s images to ``<out_path>/<person_id>/`` ordered by similarity.

    Every stored feature vector of the person is compared, by Euclidean
    distance, with the mean of those vectors. The images are copied as
    ``0000<ext>``, ``0001<ext>``, ... in ascending distance, so outliers
    (blurry faces, other people) end up last. The original extension is
    kept, and the source files are not modified.

    Rows whose image file no longer exists (e.g. deleted during a previous
    clean-up pass) are reported and excluded from both the mean and the
    output, instead of crashing in ``shutil.copy``. Nothing is created when
    no usable rows remain.

    Args:
        db_path: SQLite database containing the ``features`` table.
        out_path: Root output directory.
        person_id: Person whose images are ranked.
    """
    conn = sqlite3.connect(db_path)
    try:
        rows = conn.execute(
            'SELECT image_path, feature_vector FROM features WHERE person_id = ?',
            (person_id,),
        ).fetchall()
    finally:
        conn.close()

    image_paths = []
    features = []
    for image_path, feature_blob in rows:
        if not os.path.isfile(image_path):
            print(f"Image {image_path} no longer exists, skipping")
            continue
        image_paths.append(image_path)
        features.append(np.frombuffer(feature_blob, dtype=np.float32))
    if not features:
        return

    # One (n_images, dim) matrix: the mean and all distances in a single pass,
    # computed from the rows already fetched (no second DB query).
    feature_matrix = np.vstack(features)
    avg_feature = feature_matrix.mean(axis=0)
    distances = np.linalg.norm(feature_matrix - avg_feature, axis=1)

    person_dir = os.path.join(out_path, person_id)
    os.makedirs(person_dir, exist_ok=True)
    # Stable sort: equal distances keep the database row order.
    order = np.argsort(distances, kind='stable')
    for idx, row_index in enumerate(order):
        src = image_paths[row_index]
        # Keep the source extension so PNG/BMP/GIF files are not mislabelled as .jpg.
        ext = os.path.splitext(src)[1].lower()
        shutil.copy(src, os.path.join(person_dir, f"{idx:04d}{ext}"))

        # os.rename(image_path, new_path)


# 9. Entry point
def main(
    db_path=r'D:\FS_project2\Feature_extraction\sql_database\features.db2',
    base_path=r'D:\FS_project2\Feature_extraction\peopel_crop',
    out_path=r'D:\FS_project2\Feature_extraction\out',
):
    """Run the two-stage pipeline over every person folder in *base_path*.

    Stage 1 extracts a feature for each image in ``<base_path>/<person_id>/``
    and stores it in the SQLite database at *db_path*. Stage 2 copies each
    person's images to ``<out_path>/<person_id>/`` ordered by distance to
    that person's mean feature.

    The defaults are the original hard-coded paths. Pass other values to
    process a different dataset without editing the source.

    Args:
        db_path: SQLite database file (created if missing).
        base_path: Directory whose sub-directories each hold one person's images.
        out_path: Root directory for the sorted copies.
    """
    create_database(db_path)

    # Each sub-directory name is used as the person_id. The folders are listed
    # once, in sorted order, so both stages visit them deterministically.
    person_ids = sorted(
        name for name in os.listdir(base_path)
        if os.path.isdir(os.path.join(base_path, name))
    )

    # Stage 1: extract features and save them to the database.
    for folder in person_ids:
        process_folder(db_path, os.path.join(base_path, folder), folder)
        print(f"Processed folder: {folder}")

    # Stage 2: sort the images and copy them out under their new names.
    for folder in person_ids:
        sort_and_rename_images(db_path, out_path, folder)
        print(f"Sorted and renamed folder: {folder}")


# Run the pipeline only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
相关推荐
-SGlow-4 小时前
MySQL相关概念和易错知识点(2)(表结构的操作、数据类型、约束)
linux·运维·服务器·数据库·mysql
明月5665 小时前
Oracle 误删数据恢复
数据库·oracle
♡喜欢做梦7 小时前
【MySQL】深入浅出事务:保证数据一致性的核心武器
数据库·mysql
遇见你的雩风7 小时前
MySQL的认识与基本操作
数据库·mysql
dblens 数据库管理和开发工具7 小时前
MySQL新增字段DDL:锁表全解析、避坑指南与实战案例
数据库·mysql·dblens·dblens mysql·数据库连接管理
weixin_419658317 小时前
MySQL的基础操作
数据库·mysql
不辉放弃8 小时前
ZooKeeper 是什么?
数据库·大数据开发
Goona_8 小时前
拒绝SQL恐惧:用Python+pyqt打造任意Excel数据库查询系统
数据库·python·sql·excel·pyqt
程序员编程指南9 小时前
Qt 数据库连接池实现与管理
c语言·数据库·c++·qt·oracle
幼儿园老大*11 小时前
数据中心-时序数据库InfluxDB
数据库·时序数据库