使用vision transformer进行花朵图片分类

本文介绍如何使用huggingface的 vit-base-patch16-224模型进行图片分类任务:以flower分类数据集为例。

1.加载数据集

首先我们导入相关依赖,下载数据集以及模型文件,从本地文件中加载所需的数据集,加载数据处理器,并对数据集进行相关的处理以适应模型的输入,最终生成数据train_dataloader和val_dataloader,同时我们生成id2label和label2id。

python 复制代码
from datasets import load_dataset
from torch.utils.data import DataLoader
from transformers import ViTImageProcessor
import torch

datasets = load_dataset("imagefolder", data_dir='flower_images')
datasets = datasets['train'].train_test_split(0.2)


feature_extractor = ViTImageProcessor.from_pretrained('vit-base-patch16-224')


def transforms(examples):
    # 处理图像数据
    examples["pixel_values"] = [feature_extractor(images=image, return_tensors="pt").get("pixel_values")[0] for image in
                                examples["image"]]

    # 处理标签
    examples["labels"] = torch.LongTensor(examples["label"])

    return examples


def data_collator(features):
    """
    将给定的样本列表合并成一个批次。

    参数:
    features: 一个字典列表,每个字典对应一个样本,包含处理过的像素值和标签。

    返回:
    一个字典,包含合并后的批次数据。
    """
    # 从 features 中提取 pixel_values,然后使用 torch.stack 将它们堆叠成一个批次
    pixel_values = torch.stack([torch.Tensor(feature['pixel_values']) for feature in features])
    # 对于标签,我们同样堆叠它们,但需要确保标签是张量形式
    labels = torch.tensor([feature['labels'] for feature in features], dtype=torch.long)

    # 返回一个字典,包含合并后的批次数据
    return {
        'pixel_values': pixel_values,
        'labels': labels
    }


# 对数据集应用转换
encoded_dataset = datasets.map(transforms, remove_columns=['image', 'label'], batched=True)

train_loader = DataLoader(encoded_dataset['train'], shuffle=True, batch_size=8, collate_fn=data_collator)
val_loader = DataLoader(encoded_dataset['test'], batch_size=8, collate_fn=data_collator)


fine_labels = datasets['train'].features['label'].names

# 构建 label2id 和 id2label 映射
label2id = {label: idx for idx, label in enumerate(fine_labels)}
id2label = {idx: label for idx, label in enumerate(fine_labels)}

编写训练循环

定义损失函数,优化器等并设置在gpu进行训练,保存在验证集上表现最好的模型。

python 复制代码
from torch.utils.data import DataLoader
import torch.optim as optim
from transformers import ViTForImageClassification
from torch import nn
model = ViTForImageClassification.from_pretrained('vit-base-patch16-224', num_labels=5, ignore_mismatched_sizes=True,label2id=label2id, id2label=id2label)
# 定义参数
num_epochs = 5  # 训练的轮数
batch_size = 8  # 批量大小
learning_rate = 1e-4  # 学习率

# 优化器
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

# 损失函数
criterion = nn.CrossEntropyLoss()

# 设备配置
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)


# 训练函数
def train_one_epoch(model, data_loader, criterion, optimizer, device):
    model.train()
    running_loss = 0.0
    correct_predictions = 0

    for batch in data_loader:
        inputs, labels = batch["pixel_values"], batch["labels"]
        inputs, labels = inputs.to(device), labels.to(device)

        optimizer.zero_grad()

        outputs = model(inputs)
        loss = criterion(outputs.logits, labels)  # 使用outputs.logits代替outputs
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        _, preds = torch.max(outputs.logits, dim=1)  # 同样使用outputs.logits获取预测结果
        correct_predictions += torch.sum(preds == labels)

    epoch_loss = running_loss / len(data_loader.dataset)
    epoch_acc = correct_predictions.double() / len(data_loader.dataset)

    return epoch_loss, epoch_acc


def evaluate(model, data_loader, criterion, device):
    model.eval()
    running_loss = 0.0
    correct_predictions = 0

    with torch.no_grad():
        for batch in data_loader:
            inputs, labels = batch["pixel_values"], batch["labels"]
            inputs, labels = inputs.to(device), labels.to(device)

            outputs = model(inputs)
            loss = criterion(outputs.logits, labels)  # 使用outputs.logits代替outputs进行损失计算

            running_loss += loss.item()
            _, preds = torch.max(outputs.logits, dim=1)  # 使用outputs.logits获取预测结果
            correct_predictions += torch.sum(preds == labels)

    epoch_loss = running_loss / len(data_loader.dataset)
    epoch_acc = correct_predictions.double() / len(data_loader.dataset)

    return epoch_loss, epoch_acc





# 初始化最高验证准确率为0,用于比较
best_val_acc = 0.0

for epoch in range(num_epochs):
    train_loss, train_acc = train_one_epoch(
        model, train_loader, criterion, optimizer, device
    )
    val_loss, val_acc = evaluate(model, val_loader, criterion, device)

    print(
        f"Epoch {epoch+1}/{num_epochs}, Training Loss: {train_loss:.4f}, Training Accuracy: {train_acc:.4f}, Validation Loss: {val_loss:.4f}, Validation Accuracy: {val_acc:.4f}"
    )

    # 检查当前周期的验证准确率是否为最高
    if val_acc > best_val_acc:
        best_val_acc = val_acc  # 更新最高验证准确率
        best_model_path = f"best_model.pth"  # 定义保存最佳模型的路径
        torch.save(model.state_dict(), best_model_path)  # 保存最佳模型的状态字典
        print(f"New best model saved to {best_model_path} with Validation Accuracy: {best_val_acc:.4f}")

以上为全部代码,之后会介绍如何对模型进行微调。

相关推荐
肖遥Janic16 分钟前
Stable Diffusion绘画 | 插件-Deforum:动态视频生成(上篇)
人工智能·ai·ai作画·stable diffusion
robinfang201924 分钟前
AI在医学领域:Arges框架在溃疡性结肠炎上的应用
人工智能
给自己一个 smile28 分钟前
如何高效使用Prompt与AI大模型对话
人工智能·ai·prompt
魔力之心1 小时前
人工智能与机器学习原理精解【30】
人工智能·机器学习
Hiweir ·1 小时前
NLP任务之文本分类(情感分析)
人工智能·自然语言处理·分类·huggingface
百里香酚兰1 小时前
【AI学习笔记】基于Unity+DeepSeek开发的一些BUG记录&解决方案
人工智能·学习·unity·大模型·deepseek
sp_fyf_20243 小时前
[大语言模型-论文精读] 更大且更可指导的语言模型变得不那么可靠
人工智能·深度学习·神经网络·搜索引擎·语言模型·自然语言处理
肖遥Janic3 小时前
Stable Diffusion绘画 | 插件-Deforum:商业LOGO广告视频
人工智能·ai·ai作画·stable diffusion
我就是全世界4 小时前
一起了解AI的发展历程和AGI的未来展望
人工智能·agi
colorknight4 小时前
1.2.3 HuggingFists安装说明-MacOS安装
人工智能·低代码·macos·huggingface·数据科学·ai agent