# Gradio聚类

为了增加页面输出聚类后的结果并方便可视化分析聚类效果,下面是更新后的代码。将Gradio界面中的输出类型改为`gr.outputs.HTML`,并在返回结果时生成HTML格式的聚类结果。

```python
import gradio as gr
from transformers import AutoTokenizer, AutoModel
import torch
import torch.nn.functional as F
from sklearn.cluster import KMeans, DBSCAN, AgglomerativeClustering

def embedding(sentences):
    """Embed a list of sentences with BAAI/bge-large-zh.

    Returns L2-normalized CLS-token embeddings, one row per input sentence.

    The tokenizer and model are loaded once and cached on the function
    object, so repeated calls (one per Gradio request) do not re-load the
    weights every time as the original code did.
    """
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    # Lazy one-time load; subsequent calls reuse the cached pair.
    if not hasattr(embedding, '_cache'):
        tokenizer = AutoTokenizer.from_pretrained('BAAI/bge-large-zh')
        model = AutoModel.from_pretrained('BAAI/bge-large-zh').to(device)
        embedding._cache = (tokenizer, model)
    tokenizer, model = embedding._cache

    encoded_input = tokenizer(sentences, padding=True, truncation=True, return_tensors='pt').to(device)

    with torch.no_grad():
        model_output = model(**encoded_input)

    # CLS pooling (first token) followed by L2 normalization. The result is
    # already on `device`, so the original's trailing .to(device) was a no-op.
    sentence_embeddings = model_output.last_hidden_state[:, 0]
    return F.normalize(sentence_embeddings, p=2, dim=1)

def cluster_sentences(sentences, cluster_num, eps, min_samples, method, distance_threshold):
    """Cluster `sentences` and render the result as an HTML fragment.

    Args:
        sentences: list of sentence strings to embed and cluster.
        cluster_num: number of clusters (KMeans only; may arrive as a float
            from the UI slider, hence the int() cast).
        eps, min_samples: DBSCAN parameters.
        method: one of 'KMeans', 'DBSCAN', 'Agglomerative'.
        distance_threshold: Agglomerative linkage cut-off.

    Returns:
        HTML string with one <h3>/<ul> section per cluster.

    Raises:
        ValueError: if `method` is not a supported algorithm name.
    """
    from html import escape  # local import: `html` is used as a variable name in sibling code

    features = embedding(sentences).cpu()

    if method == 'KMeans':
        # n_init made explicit: its default changed across sklearn versions.
        model = KMeans(n_clusters=int(cluster_num), n_init=10)
    elif method == 'DBSCAN':
        model = DBSCAN(eps=eps, min_samples=min_samples, metric='cosine')
    elif method == 'Agglomerative':
        model = AgglomerativeClustering(n_clusters=None, distance_threshold=distance_threshold)
    else:
        # The original fell through with `labels` undefined (NameError); fail loudly.
        raise ValueError(f"Unknown clustering method: {method}")
    labels = model.fit(features).labels_

    cluster_result = {}
    for sentence, label in zip(sentences, labels):
        cluster_result.setdefault(label, []).append(sentence)

    # Build the HTML report; escape() prevents user-supplied text from
    # injecting markup into the rendered page.
    parts = []
    for label, clustered in cluster_result.items():
        parts.append(f"<h3>Cluster {label}</h3><ul>")
        parts.extend(f"<li>{escape(s)}</li>" for s in clustered)
        parts.append("</ul>")
    return "".join(parts)

def main_interface(sentence_input, cluster_num, eps, min_samples, method, distance_threshold):
    """Gradio callback: split the comma-separated textbox input and cluster it."""
    parsed = [piece.strip() for piece in sentence_input.split(',')]
    return cluster_sentences(parsed, cluster_num, eps, min_samples, method, distance_threshold)

# Build the Gradio UI.
# NOTE: `inputs` are passed to `fn` positionally, so their order must match
# main_interface's parameter order. The original listed the distance_threshold
# slider 5th and the method radio 6th, so `method` received a float and
# `distance_threshold` received the method name. Also migrated off the
# removed gr.inputs/gr.outputs namespaces (default= -> value=).
iface = gr.Interface(
    fn=main_interface,
    inputs=[
        gr.Textbox(lines=5, placeholder="请输入句子列表,每个句子用逗号分隔"),
        gr.Slider(minimum=2, maximum=30, step=1, value=15, label="Cluster Number (KMeans)"),
        gr.Slider(minimum=0.01, maximum=1.0, step=0.01, value=0.15, label="eps (DBSCAN)"),
        gr.Slider(minimum=1, maximum=10, step=1, value=2, label="min_samples (DBSCAN)"),
        gr.Radio(choices=["KMeans", "DBSCAN", "Agglomerative"], label="Clustering Method", value="KMeans"),
        gr.Slider(minimum=0.01, maximum=1.0, step=0.01, value=0.85, label="distance_threshold (Agglomerative)"),
    ],
    outputs=gr.HTML(),
    title="句子嵌入和聚类",
    description="输入一组用逗号分隔的句子,选择聚类方法和相应参数,进行嵌入和聚类。"
)

# Launch the Gradio app.
iface.launch()

```

### 代码解释

1. **增加HTML格式输出**:
    - 在`cluster_sentences`函数中构建HTML格式的输出,以便在界面上直观显示聚类结果。
2. **更新输出类型**:
    - 将Gradio界面的输出类型改为`gr.outputs.HTML()`,以支持HTML格式的聚类结果展示。
3. **构建HTML输出**:
    - 在`cluster_sentences`函数中,根据聚类结果生成HTML格式的字符串,包含每个类别的标题和句子列表。

这样,用户可以在Gradio界面中查看聚类后的详细结果,并通过调整参数进行可视化分析。

为了增加统计每个聚类下的句子数量以及每个聚类下重复的句子数量,我更新了代码。以下是改进后的版本:

```python
import gradio as gr
from transformers import AutoTokenizer, AutoModel
import torch
import torch.nn.functional as F
from sklearn.cluster import KMeans, DBSCAN, AgglomerativeClustering
from collections import Counter

def embedding(sentences):
    """Return L2-normalized CLS-token embeddings for `sentences` (BAAI/bge-large-zh)."""
    use_cuda = torch.cuda.is_available()
    device = torch.device('cuda:0' if use_cuda else 'cpu')
    tokenizer = AutoTokenizer.from_pretrained('BAAI/bge-large-zh')
    model = AutoModel.from_pretrained('BAAI/bge-large-zh').to(device)

    batch = tokenizer(sentences, padding=True, truncation=True, return_tensors='pt').to(device)

    with torch.no_grad():
        outputs = model(**batch)

    # CLS pooling: the first token's hidden state represents the sentence.
    cls_vectors = outputs.last_hidden_state[:, 0]
    return F.normalize(cls_vectors, p=2, dim=1).to(device)

def cluster_sentences(sentences, cluster_num, eps, min_samples, method, distance_threshold):
    """Cluster `sentences` and return an HTML report with per-cluster stats.

    Each cluster section shows the total sentence count and a duplicate
    count, followed by the sentences themselves.

    Raises:
        ValueError: if `method` is not a supported algorithm name.
    """
    from html import escape  # local import: `html` is used as a variable name in sibling code

    features = embedding(sentences).cpu()

    if method == 'KMeans':
        # int() guards against a float coming back from the UI slider;
        # n_init made explicit because its sklearn default changed.
        model = KMeans(n_clusters=int(cluster_num), n_init=10)
    elif method == 'DBSCAN':
        model = DBSCAN(eps=eps, min_samples=min_samples, metric='cosine')
    elif method == 'Agglomerative':
        model = AgglomerativeClustering(n_clusters=None, distance_threshold=distance_threshold)
    else:
        # The original fell through with `labels` undefined (NameError); fail loudly.
        raise ValueError(f"Unknown clustering method: {method}")
    labels = model.fit(features).labels_

    cluster_result = {}
    for sentence, label in zip(sentences, labels):
        cluster_result.setdefault(label, []).append(sentence)

    # Build the HTML report; escape() keeps user text from injecting markup.
    parts = []
    for label, clustered in cluster_result.items():
        counter = Counter(clustered)
        # "Duplicates" = total occurrences of sentences appearing more than once
        # (original semantics preserved).
        duplicate_count = sum(c for c in counter.values() if c > 1)
        parts.append(
            f"<h3>Cluster {label} - Total Sentences: {len(clustered)}, "
            f"Duplicates: {duplicate_count}</h3><ul>"
        )
        parts.extend(f"<li>{escape(s)}</li>" for s in clustered)
        parts.append("</ul>")
    return "".join(parts)

def main_interface(sentence_input, cluster_num, eps, min_samples, method, distance_threshold):
    """Gradio callback: parse the comma-separated input and run clustering."""
    items = sentence_input.split(',')
    cleaned = [item.strip() for item in items]
    result_html = cluster_sentences(cleaned, cluster_num, eps, min_samples, method, distance_threshold)
    return result_html

# Build the Gradio UI.
# NOTE: `inputs` are passed to `fn` positionally, so their order must match
# main_interface's parameter order. The original listed the distance_threshold
# slider 5th and the method radio 6th, so `method` received a float and
# `distance_threshold` received the method name. Also migrated off the
# removed gr.inputs/gr.outputs namespaces (default= -> value=).
iface = gr.Interface(
    fn=main_interface,
    inputs=[
        gr.Textbox(lines=5, placeholder="请输入句子列表,每个句子用逗号分隔"),
        gr.Slider(minimum=2, maximum=30, step=1, value=15, label="Cluster Number (KMeans)"),
        gr.Slider(minimum=0.01, maximum=1.0, step=0.01, value=0.15, label="eps (DBSCAN)"),
        gr.Slider(minimum=1, maximum=10, step=1, value=2, label="min_samples (DBSCAN)"),
        gr.Radio(choices=["KMeans", "DBSCAN", "Agglomerative"], label="Clustering Method", value="KMeans"),
        gr.Slider(minimum=0.01, maximum=1.0, step=0.01, value=0.85, label="distance_threshold (Agglomerative)"),
    ],
    outputs=gr.HTML(),
    title="句子嵌入和聚类",
    description="输入一组用逗号分隔的句子,选择聚类方法和相应参数,进行嵌入和聚类。"
)

# Launch the Gradio app.
iface.launch()
```

### 代码解释

1. **增加HTML格式输出**:
    - 在`cluster_sentences`函数中构建HTML格式的输出,以便在界面上直观显示聚类结果。

2. **更新输出类型**:
    - 将Gradio界面的输出类型改为`gr.outputs.HTML()`,以支持HTML格式的聚类结果展示。

3. **统计句子数量和重复句子数量**:
    - 使用`Counter`统计每个聚类中的句子数量和重复句子数量,并在HTML输出中显示这些统计信息。

这样,用户可以在Gradio界面中查看详细的聚类结果,包括每个聚类的句子数量和重复句子数量。通过调整参数,用户可以直观地分析和可视化聚类效果。
为了满足输出HTML格式的聚类结果的需求,包括每个聚类下包含的句子的详细信息,以及每个聚类的句子数量和每个聚类里重复句子数量统计,我对代码进行了更新:

```python
import gradio as gr
from transformers import AutoTokenizer, AutoModel
import torch
import torch.nn.functional as F
from sklearn.cluster import KMeans, DBSCAN, AgglomerativeClustering
from collections import Counter

def embedding(sentences):
    """Embed a list of sentences with the BAAI/bge-large-zh model.

    Returns L2-normalized CLS-token vectors, one row per input sentence.

    NOTE(review): the tokenizer and model are re-loaded on every call;
    consider caching them since this runs once per Gradio request.
    """
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    tokenizer = AutoTokenizer.from_pretrained('BAAI/bge-large-zh')
    model = AutoModel.from_pretrained('BAAI/bge-large-zh').to(device)

    encoded_input = tokenizer(sentences, padding=True, truncation=True, return_tensors='pt').to(device)

    with torch.no_grad():
        model_output = model(**encoded_input)

    # CLS pooling: use the first token's hidden state as the sentence vector.
    sentence_embeddings = model_output.last_hidden_state[:, 0]
    # L2-normalize; the trailing .to(device) is a no-op (already on device).
    sentence_embeddings = F.normalize(sentence_embeddings, p=2, dim=1).to(device)
    return sentence_embeddings

def cluster_sentences(sentences, cluster_num, eps, min_samples, method, distance_threshold):
    """Cluster `sentences` and return an HTML report.

    Each cluster is rendered as a heading with its size and duplicate count,
    followed by its sentences annotated with per-sentence occurrence counts.

    Raises:
        ValueError: if `method` is not a supported algorithm name.
    """
    from html import escape  # local import: `html` was a variable name in the original

    features = embedding(sentences).cpu()

    if method == 'KMeans':
        # int() guards against a float coming back from the UI slider;
        # n_init made explicit because its sklearn default changed.
        model = KMeans(n_clusters=int(cluster_num), n_init=10)
    elif method == 'DBSCAN':
        model = DBSCAN(eps=eps, min_samples=min_samples, metric='cosine')
    elif method == 'Agglomerative':
        model = AgglomerativeClustering(n_clusters=None, distance_threshold=distance_threshold)
    else:
        # The original fell through with `labels` undefined (NameError); fail loudly.
        raise ValueError(f"Unknown clustering method: {method}")
    labels = model.fit(features).labels_

    cluster_result = {}
    for sentence, label in zip(sentences, labels):
        cluster_result.setdefault(label, []).append(sentence)

    # Build the HTML report; escape() keeps user text from injecting markup.
    parts = ["<div>"]
    for label, clustered in cluster_result.items():
        counter = Counter(clustered)
        # "Duplicates" = total occurrences of sentences appearing more than once
        # (original semantics preserved).
        duplicate_count = sum(c for c in counter.values() if c > 1)
        parts.append(
            f"<h3>Cluster {label} - Total Sentences: {len(clustered)}, "
            f"Duplicates: {duplicate_count}</h3><ul>"
        )
        parts.extend(f"<li>{escape(s)} (Count: {counter[s]})</li>" for s in clustered)
        parts.append("</ul>")
    parts.append("</div>")
    return "".join(parts)

def main_interface(sentence_input, cluster_num, eps, min_samples, method, distance_threshold):
    """Split the textbox input on commas and return the clustering HTML."""
    sentences = list(map(str.strip, sentence_input.split(',')))
    return cluster_sentences(sentences, cluster_num, eps, min_samples, method, distance_threshold)

# Build the Gradio UI.
# NOTE: `inputs` are passed to `fn` positionally, so their order must match
# main_interface's parameter order. The original listed the distance_threshold
# slider 5th and the method radio 6th, so `method` received a float and
# `distance_threshold` received the method name. Also migrated off the
# removed gr.inputs/gr.outputs namespaces (default= -> value=) and dropped
# the legacy theme="compact" string, which current Gradio no longer accepts.
iface = gr.Interface(
    fn=main_interface,
    inputs=[
        gr.Textbox(lines=5, placeholder="请输入句子列表,每个句子用逗号分隔", label="句子输入"),
        gr.Slider(minimum=2, maximum=30, step=1, value=15, label="KMeans 聚类数"),
        gr.Slider(minimum=0.01, maximum=1.0, step=0.01, value=0.15, label="DBSCAN eps"),
        gr.Slider(minimum=1, maximum=10, step=1, value=2, label="DBSCAN min_samples"),
        gr.Radio(choices=["KMeans", "DBSCAN", "Agglomerative"], label="聚类方法", value="KMeans"),
        gr.Slider(minimum=0.01, maximum=1.0, step=0.01, value=0.85, label="Agglomerative distance_threshold"),
    ],
    outputs=gr.HTML(label="聚类结果"),
    title="句子嵌入和聚类",
    description="输入一组用逗号分隔的句子,选择聚类方法和相应参数,进行嵌入和聚类。"
)

# Launch the Gradio app.
iface.launch()
```

### 主要更新内容
1. **嵌入生成**:
    - 使用 `embedding` 函数生成句子嵌入。
2. **聚类方法**:
    - 实现了 `KMeans`、`DBSCAN` 和 `Agglomerative` 三种聚类方法,并允许用户调节相应参数。
3. **输出HTML格式聚类结果**:
    - 生成HTML格式的聚类结果,包括每个聚类的句子数量和重复句子数量统计,并在每个句子旁边显示出现次数。
4. **用户界面**:
    - 使用Gradio创建用户界面,允许用户输入句子列表并调节聚类参数。
    - 界面更加紧凑和美观,便于用户操作。

### 运行方式
将代码保存为Python文件(例如 `clustering_app.py`),然后在终端运行:

```bash
python clustering_app.py
```

这将启动Gradio应用,并在浏览器中打开用户界面,用户可以输入句子并进行聚类分析。
相关推荐
FL16238631295 分钟前
钢材缺陷识别分割数据集labelme格式693张4类别
深度学习
cdut_suye8 分钟前
Linux工具使用指南:从apt管理、gcc编译到makefile构建与gdb调试
java·linux·运维·服务器·c++·人工智能·python
开发者每周简报27 分钟前
微软的AI转型故事
人工智能·microsoft
dundunmm30 分钟前
机器学习之scikit-learn(简称 sklearn)
python·算法·机器学习·scikit-learn·sklearn·分类算法
古希腊掌管学习的神31 分钟前
[机器学习]sklearn入门指南(1)
人工智能·python·算法·机器学习·sklearn
普密斯科技1 小时前
手机外观边框缺陷视觉检测智慧方案
人工智能·计算机视觉·智能手机·自动化·视觉检测·集成测试
四口鲸鱼爱吃盐1 小时前
Pytorch | 利用AI-FGTM针对CIFAR10上的ResNet分类器进行对抗攻击
人工智能·pytorch·python
lishanlu1361 小时前
Pytorch分布式训练
人工智能·ddp·pytorch并行训练
Schwertlilien1 小时前
图像处理-Ch5-图像复原与重建
c语言·开发语言·机器学习
日出等日落1 小时前
从零开始使用MaxKB打造本地大语言模型智能问答系统与远程交互
人工智能·语言模型·自然语言处理