使用vision transformer进行花朵图片分类

本文介绍如何使用huggingface的 vit-base-patch16-224模型进行图片分类任务:以flower分类数据集为例。

1.加载数据集

首先我们导入相关依赖,下载数据集以及模型文件,从本地文件中加载所需的数据集,加载数据处理器,并对数据集进行相关的处理以适应模型的输入,最终生成数据train_dataloader和val_dataloader,同时我们生成id2label和label2id。

python 复制代码
from datasets import load_dataset
from torch.utils.data import DataLoader
from transformers import ViTImageProcessor
import torch

datasets = load_dataset("imagefolder", data_dir='flower_images')
datasets = datasets['train'].train_test_split(0.2)


feature_extractor = ViTImageProcessor.from_pretrained('vit-base-patch16-224')


def transforms(examples):
    # 处理图像数据
    examples["pixel_values"] = [feature_extractor(images=image, return_tensors="pt").get("pixel_values")[0] for image in
                                examples["image"]]

    # 处理标签
    examples["labels"] = torch.LongTensor(examples["label"])

    return examples


def data_collator(features):
    """
    将给定的样本列表合并成一个批次。

    参数:
    features: 一个字典列表,每个字典对应一个样本,包含处理过的像素值和标签。

    返回:
    一个字典,包含合并后的批次数据。
    """
    # 从 features 中提取 pixel_values,然后使用 torch.stack 将它们堆叠成一个批次
    pixel_values = torch.stack([torch.Tensor(feature['pixel_values']) for feature in features])
    # 对于标签,我们同样堆叠它们,但需要确保标签是张量形式
    labels = torch.tensor([feature['labels'] for feature in features], dtype=torch.long)

    # 返回一个字典,包含合并后的批次数据
    return {
        'pixel_values': pixel_values,
        'labels': labels
    }


# 对数据集应用转换
encoded_dataset = datasets.map(transforms, remove_columns=['image', 'label'], batched=True)

train_loader = DataLoader(encoded_dataset['train'], shuffle=True, batch_size=8, collate_fn=data_collator)
val_loader = DataLoader(encoded_dataset['test'], batch_size=8, collate_fn=data_collator)


fine_labels = datasets['train'].features['label'].names

# 构建 label2id 和 id2label 映射
label2id = {label: idx for idx, label in enumerate(fine_labels)}
id2label = {idx: label for idx, label in enumerate(fine_labels)}

编写训练循环

定义损失函数,优化器等并设置在gpu进行训练,保存在验证集上表现最好的模型。

python 复制代码
from torch.utils.data import DataLoader
import torch.optim as optim
from transformers import ViTForImageClassification
from torch import nn
model = ViTForImageClassification.from_pretrained('vit-base-patch16-224', num_labels=5, ignore_mismatched_sizes=True,label2id=label2id, id2label=id2label)
# 定义参数
num_epochs = 5  # 训练的轮数
batch_size = 8  # 批量大小
learning_rate = 1e-4  # 学习率

# 优化器
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

# 损失函数
criterion = nn.CrossEntropyLoss()

# 设备配置
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)


# 训练函数
def train_one_epoch(model, data_loader, criterion, optimizer, device):
    model.train()
    running_loss = 0.0
    correct_predictions = 0

    for batch in data_loader:
        inputs, labels = batch["pixel_values"], batch["labels"]
        inputs, labels = inputs.to(device), labels.to(device)

        optimizer.zero_grad()

        outputs = model(inputs)
        loss = criterion(outputs.logits, labels)  # 使用outputs.logits代替outputs
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        _, preds = torch.max(outputs.logits, dim=1)  # 同样使用outputs.logits获取预测结果
        correct_predictions += torch.sum(preds == labels)

    epoch_loss = running_loss / len(data_loader.dataset)
    epoch_acc = correct_predictions.double() / len(data_loader.dataset)

    return epoch_loss, epoch_acc


def evaluate(model, data_loader, criterion, device):
    model.eval()
    running_loss = 0.0
    correct_predictions = 0

    with torch.no_grad():
        for batch in data_loader:
            inputs, labels = batch["pixel_values"], batch["labels"]
            inputs, labels = inputs.to(device), labels.to(device)

            outputs = model(inputs)
            loss = criterion(outputs.logits, labels)  # 使用outputs.logits代替outputs进行损失计算

            running_loss += loss.item()
            _, preds = torch.max(outputs.logits, dim=1)  # 使用outputs.logits获取预测结果
            correct_predictions += torch.sum(preds == labels)

    epoch_loss = running_loss / len(data_loader.dataset)
    epoch_acc = correct_predictions.double() / len(data_loader.dataset)

    return epoch_loss, epoch_acc





# 初始化最高验证准确率为0,用于比较
best_val_acc = 0.0

for epoch in range(num_epochs):
    train_loss, train_acc = train_one_epoch(
        model, train_loader, criterion, optimizer, device
    )
    val_loss, val_acc = evaluate(model, val_loader, criterion, device)

    print(
        f"Epoch {epoch+1}/{num_epochs}, Training Loss: {train_loss:.4f}, Training Accuracy: {train_acc:.4f}, Validation Loss: {val_loss:.4f}, Validation Accuracy: {val_acc:.4f}"
    )

    # 检查当前周期的验证准确率是否为最高
    if val_acc > best_val_acc:
        best_val_acc = val_acc  # 更新最高验证准确率
        best_model_path = f"best_model.pth"  # 定义保存最佳模型的路径
        torch.save(model.state_dict(), best_model_path)  # 保存最佳模型的状态字典
        print(f"New best model saved to {best_model_path} with Validation Accuracy: {best_val_acc:.4f}")

以上为全部代码,之后会介绍如何对模型进行微调。

相关推荐
浩浩乎@25 分钟前
【openGLES】着色器语言(GLSL)
人工智能·算法·着色器
智慧地球(AI·Earth)1 小时前
DeepSeek V3.1 横空出世:重新定义大语言模型的边界与可能
人工智能·语言模型·自然语言处理
金井PRATHAMA1 小时前
语义普遍性与形式化:构建深层语义理解的统一框架
人工智能·自然语言处理·知识图谱
lucky_lyovo2 小时前
大模型部署
开发语言·人工智能·云计算·lua
聚客AI2 小时前
📈超越Prompt Engineering:揭秘高并发AI系统的上下文工程实践
人工智能·llm·agent
北极光SD-WAN组网3 小时前
某电器5G智慧工厂网络建设全解析
人工智能·物联网·5g
十八岁牛爷爷3 小时前
通过官方文档详解Ultralytics YOLO 开源工程-熟练使用 YOLO11实现分割、分类、旋转框检测和姿势估计(附测试代码)
人工智能·yolo·目标跟踪
阿杜杜不是阿木木3 小时前
什么?OpenCV调用cv2.putText()乱码?寻找支持中文的方法之旅
人工智能·opencv·计算机视觉
赴3353 小时前
图像边缘检测
人工智能·python·opencv·计算机视觉
机器视觉知识推荐、就业指导4 小时前
如何消除工业视觉检测中的反光问题
人工智能·计算机视觉·视觉检测