Clustering Videos with Deep Learning Models - PyTorch, Sklearn, Matplotlib

```python
from sklearn.datasets import make_circles
from sklearn.cluster import KMeans, DBSCAN, SpectralClustering, Birch, MeanShift, AgglomerativeClustering
from sklearn.metrics import silhouette_score, silhouette_samples
from sklearn.decomposition import PCA

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import time, os
import functools
import matplotlib.cm as cm

from torchvision.io.video import read_video
from torchvision.models.video import r3d_18, R3D_18_Weights
from torchvision.models.video import mvit_v1_b, MViT_V1_B_Weights
 
 
def cluster_test(model_name, model, X, pca_result, clusters_list = [2,3,4,5,6,7]):
    for n_clusters in clusters_list:
        if hasattr(model, "n_clusters"):
            model.set_params(n_clusters = n_clusters)
        elif len(clusters_list) >= 2 and n_clusters == clusters_list[1]:
            # Models without an 'n_clusters' parameter (e.g. DBSCAN, MeanShift) only need one pass.
            print("{} does not have an 'n_clusters' parameter, returning.".format(model_name))
            return
        
        fig, (ax1, ax2) = plt.subplots(1, 2)
        fig.set_size_inches(18, 7)
 
        ax1.set_xlim([-0.1, 1])
        ax1.set_ylim([0, X.shape[0] + (n_clusters + 1) * 10])
 
        clusterer, t = cluster_function(model_name, model, X)
        cluster_labels = clusterer.labels_
        silhouette_avg = silhouette_score(X, cluster_labels)
        print("For n_clusters =", n_clusters, "the average silhouette_score is", silhouette_avg)
        sample_silhouette_values = silhouette_samples(X, cluster_labels)
 
        y_lower = 10
        for i in range(n_clusters):
            ith_cluster_silhouette_values = sample_silhouette_values[cluster_labels == i]
            ith_cluster_silhouette_values.sort()
            size_cluster_i = ith_cluster_silhouette_values.shape[0]
            y_upper = y_lower + size_cluster_i
            color = cm.nipy_spectral(float(i) / n_clusters)
 
            ax1.fill_betweenx(np.arange(y_lower, y_upper), ith_cluster_silhouette_values, facecolor = color, alpha = 0.7)
            ax1.text(-0.05, y_lower + 0.5 * size_cluster_i, str(i))
            y_lower = y_upper + 10
 
        ax1.set_title("The silhouette plot for the various clusters")
        ax1.set_xlabel("The silhouette coefficient values")
        ax1.set_ylabel("Cluster label")
        ax1.axvline(x = silhouette_avg, color = 'red', linestyle = "--")
        ax1.set_yticks([])
        ax1.set_xticks([-0.1, 0, 0.2, 0.4, 0.6, 0.8, 1])
        
        colors = cm.nipy_spectral(cluster_labels.astype(float) / n_clusters)
        ax2.scatter(pca_result[:,0], pca_result[:,1], marker = 'o', s = 8, c = colors)
        if hasattr(clusterer, 'cluster_centers_'):
            # Note: cluster_centers_ live in the original feature space, so they only
            # align with this PCA scatter if the model was fitted on pca_result itself.
            centers = clusterer.cluster_centers_
            ax2.scatter(centers[:, 0], centers[:, 1], marker = 'x', c = 'red', alpha = 1, s = 200)
        ax2.text(.99, .01, ('%.2fs' % (t)).lstrip('0'), transform = ax2.transAxes, size = 12, horizontalalignment = 'right')
        ax2.set_title("The visualization of the clustered data")
        ax2.set_xlabel("Feature space for the 1st feature")
        ax2.set_ylabel("Feature space for the 2nd feature")
 
        plt.suptitle("Silhouette analysis for {} clustering on sample data with n_clusters = {} (avg = {:.3f})".format(model_name, n_clusters, silhouette_avg), fontsize = 14, fontweight="bold")
        plt.show()
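
# Caveat (assumed note, not in the original): silhouette_score() needs at least two
# distinct cluster labels, so a configuration that collapses everything into a single
# cluster (e.g. DBSCAN with an unsuitable eps) will make the call in cluster_test
# raise a ValueError.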
 
        
def time_cost(func):
    # Decorator: run the wrapped function once and return (result, elapsed seconds).
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        t0 = time.time()
        result = func(*args, **kwargs)
        t1 = time.time()
        return result, t1 - t0
    return wrapper
 
 
@time_cost
def cluster_function(model_name, model, data):
    model = model.fit(data)
    return model
 
 
def load_data(file_dir):
    assert file_dir != ''
    x = []
    for item in os.listdir(file_dir):
        # One feature vector per video file
        data = video_feature(os.path.join(file_dir, item))
        x.append(np.array(data))
    x = np.array(x)
    print(x.shape)
    # Project the features to 6 principal components; the first two are used for plotting
    pca = PCA(n_components=6)
    pca_result = pca.fit_transform(x)
    return x, pca_result
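
# Assumed sanity check (not in the original): inspect how much variance the six
# principal components actually retain before trusting the 2-D scatter in cluster_test, e.g.
#   print(PCA(n_components=6).fit(x).explained_variance_ratio_.cumsum())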
 
def video_feature(file_path):
    vid, _, _ = read_video(file_path, output_format="TCHW")
    vid = vid[:16]  # MViT-B expects 16-frame clips
    # Step 1: Initialize model with the best available weights
    # weights = R3D_18_Weights.DEFAULT
    # model = r3d_18(weights=weights)
    weights = MViT_V1_B_Weights.DEFAULT
    model = mvit_v1_b(weights=weights)
    model.eval()
    # Step 2: Initialize the inference transforms
    preprocess = weights.transforms()
    # Step 3: Apply inference preprocessing transforms
    batch = preprocess(vid).unsqueeze(0)
    # Use the Kinetics-400 class logits as the clip's feature vector
    prediction = model(batch).squeeze(0)
    prediction = prediction.cpu().detach().numpy()
    # print(prediction.shape)

    return prediction

    # Step 4 (optional; unreachable after the return above): print the predicted category
    # prediction = model(batch).squeeze(0).softmax(0)
    # label = prediction.argmax().item()
    # score = prediction[label].item()
    # category_name = weights.meta["categories"][label]
    # print(f"{category_name}: {100 * score}%")
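
# --- Assumed optimization sketch (not part of the original code) ---
# video_feature() above rebuilds the MViT model for every clip, which is slow when a
# directory contains many videos. The variant below loads the model and transforms
# once, caches them, and reuses them for every call.

@functools.lru_cache(maxsize=1)
def _load_mvit():
    weights = MViT_V1_B_Weights.DEFAULT
    model = mvit_v1_b(weights=weights)
    model.eval()
    return model, weights.transforms()


def video_feature_cached(file_path):
    model, preprocess = _load_mvit()
    vid, _, _ = read_video(file_path, output_format="TCHW")
    batch = preprocess(vid[:16]).unsqueeze(0)
    return model(batch).squeeze(0).detach().numpy()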
 

if __name__ == "__main__":
    path = r"/home/markjhon/Common/Dataset/Infant/Left_hand"

    # Load the data (video features and their PCA projection)
    x, pca_result = load_data(path)
    cluster_test("AgglomerativeClustering", AgglomerativeClustering(), x, pca_result)
```

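The imports at the top also bring in KMeans, Birch, SpectralClustering, DBSCAN, and MeanShift, but the main block only exercises AgglomerativeClustering. As an assumed usage sketch (not in the original post), the body of the `__main__` block could be extended to compare several of the imported clusterers on the same features:

```python
    x, pca_result = load_data(path)

    # Models exposing 'n_clusters' are swept over clusters_list by cluster_test;
    # DBSCAN and MeanShift would only get a single pass before it returns.
    for name, model in [
        ("KMeans", KMeans()),
        ("Birch", Birch()),
        ("SpectralClustering", SpectralClustering()),
        ("AgglomerativeClustering", AgglomerativeClustering()),
    ]:
        cluster_test(name, model, x, pca_result)
```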