基于自定义数据集微调SigLIP2-分类任务

本项目基于Google的SigLIP2模型,构建了一个智能xx等级分类系统。通过联合训练策略(对比学习+分类学习),实现了对xx图像的精确等级分类(Grade 2-5),提供AI辅助支持。

一、任务背景

xx等级分类的重要性

xx等级的准确判断对后续方案制定和预后评估至关重要:

- Grade 1 :正常情况,目前不纳入分类内容。

- Grade 2 :等级2

- Grade 3 :等级3

- Grade 4 :等级4

- Grade 5:等级5

技术挑战

图像特征复杂,等级边界模糊,传统方法依赖专家经验,主观性强,需要同时理解图像内容和xx描述。

二、数据集构建

  • 标注格式:JSONL格式,包含图像路径和文本描述
  • 数据分布:多等级xx图像,每张图像配有详细的描述

数据标注

CVAT对图片进行tag标注,自定义标注工具进行jsonl标注文件生成。

复制代码
{"image_path": "images/36_2650.jpg", "text": "Grade 3: Presence of ANY of the following: description."}
{"image_path": "images/36_2675.jpg", "text": "Grade 3: Presence of ANY of the following: description."}
dataset/
├── images/
│   ├── img001.jpg
│   ├── img002.jpg
├── labels.jsonl

数据预处理

复制代码
python
# 数据加载和标签提取
def load_data(jsonl_path, image_dir):
    with open(jsonl_path, "r") as f:
        entries = [json.loads(line.strip()) for line in f]
    data = []
    for entry in entries:
        image_path = os.path.join(image_dir, os.path.basename(entry["image_path"]))
        text = entry["text"]
        label = extract_label(text)  # 从文本中提取等级标签
        if label != -1:
            data.append((image_path, text, label))
    return data
 # 标签映射
 id2grade = {0: "Grade 2", 1: "Grade 3", 2: "Grade 4", 3: "Grade 5"}

三、模型架构

核心模型

  • 基础模型:SigLIP2-Base-Patch16-384
  • 输入尺寸:384×384像素
  • 预训练权重 :Google官方预训练模型,自行到hugging face下载

联合训练架构

复制代码
python
class SigLIP2WithClassifier(nn.Module):
    def __init__(self, base_model, processor, num_classes=4):
        self.siglip = base_model          # SigLIP2主干网络
        self.classifier = nn.Linear(embed_dim, num_classes)  # 分类头
        self.temperature = 0.07           # 对比学习温度参数

损失函数设计

  1. 对比损失(Contrastive Loss)

    • 目标:学习图像-文本对应关系
    • 公式:logits_per_image = (image_embeds @ text_embeds.T) / temperature
  2. 分类损失(Classification Loss)

    • 目标:精确预测烧伤等级
    • 公式:CrossEntropy(classifier(image_embeds), class_labels)
  3. 联合损失

    • 总损失 = 对比损失 + 分类损失

训练配置

复制代码
python
# 训练机器:H100服务器
# 训练参数
epochs = 20
learning_rate = 2e-5
batch_size = 16
device = "cuda"  # GPU训练

模型保存

复制代码
torch.save(model.state_dict(), os.path.join(save_dir, "parkland_siglip2.pt"))

三、推理部署

模型加载

复制代码
python
# 加载训练好的模型
model = SigLIP2WithClassifier(base_model, processor, num_classes=4)
model.load_state_dict(torch.load("parkland_siglip2.pt"))
model.eval()

推理流程

  1. 图像预处理:调整尺寸、标准化

  2. 特征提取:通过SigLIP2获取图像嵌入

  3. 分类预测:通过分类头预测烧伤等级

  4. 结果输出:返回等级概率分布

    加载和预处理图像

    image = Image.open(image_path).convert("RGB")
    inputs = processor(images=image, return_tensors="pt").to(device)

    推理

    image_features = model.siglip.vision_model(pixel_values=inputs["pixel_values"]).pooler_output
    logits = model.classifier(image_features)
    probs = torch.softmax(logits, dim=-1)
    pred = probs.argmax(dim=-1).item()

相关推荐
FairyGirlhub18 分钟前
神经网络的初始化:权重与偏置的数学策略
人工智能·深度学习·神经网络
大写-凌祁4 小时前
零基础入门深度学习:从理论到实战,GitHub+开源资源全指南(2025最新版)
人工智能·深度学习·开源·github
焦耳加热5 小时前
阿德莱德大学Nat. Commun.:盐模板策略实现废弃塑料到单原子催化剂的高值转化,推动环境与能源催化应用
人工智能·算法·机器学习·能源·材料工程
深空数字孪生5 小时前
储能调峰新实践:智慧能源平台如何保障风电消纳与电网稳定?
大数据·人工智能·物联网
wan5555cn5 小时前
多张图片生成视频模型技术深度解析
人工智能·笔记·深度学习·算法·音视频
格林威6 小时前
机器视觉检测的光源基础知识及光源选型
人工智能·深度学习·数码相机·yolo·计算机视觉·视觉检测
今天也要学习吖6 小时前
谷歌nano banana官方Prompt模板发布,解锁六大图像生成风格
人工智能·学习·ai·prompt·nano banana·谷歌ai
Hello123网站6 小时前
glean-企业级AI搜索和知识发现平台
人工智能·产品运营·ai工具
AKAMAI6 小时前
Queue-it 为数十亿用户增强在线体验
人工智能·云原生·云计算
索迪迈科技7 小时前
INDEMIND亮相2025科技创变者大会,以机器人空间智能技术解锁具身智能新边界
人工智能·机器人·扫地机器人·空间智能·陪伴机器人