本文介绍如何使用 Hugging Face 的 vit-base-patch16-224
模型进行图片分类任务:以 flower 分类数据集为例。
1.加载数据集
首先我们导入相关依赖,下载数据集以及模型文件,从本地文件中加载所需的数据集,加载数据处理器,并对数据集进行相关的处理以适应模型的输入,最终生成数据train_dataloader和val_dataloader,同时我们生成id2label和label2id。
python
from datasets import load_dataset
from torch.utils.data import DataLoader
from transformers import ViTImageProcessor
import torch
# Load the flower images from a local folder; "imagefolder" infers labels
# from the sub-directory names under `flower_images`.
datasets = load_dataset("imagefolder", data_dir='flower_images')
# Hold out 20% of the training split as a test set.
datasets = datasets['train'].train_test_split(0.2)
# Image processor that resizes/normalizes images to the ViT input format.
# NOTE(review): 'vit-base-patch16-224' here resolves to a local path or a
# pre-downloaded hub snapshot — confirm the checkpoint location.
feature_extractor = ViTImageProcessor.from_pretrained('vit-base-patch16-224')
def transforms(examples):
    """Preprocess one batch of dataset examples for the ViT model.

    Runs each PIL image through the module-level ``feature_extractor`` and
    converts the integer labels to a LongTensor.

    Args:
        examples: batch dict containing "image" and "label" columns.

    Returns:
        The same dict, augmented with "pixel_values" and "labels".
    """
    pixel_values = []
    for img in examples["image"]:
        processed = feature_extractor(images=img, return_tensors="pt")
        # Drop the leading batch dim added by the processor.
        pixel_values.append(processed.get("pixel_values")[0])
    examples["pixel_values"] = pixel_values
    examples["labels"] = torch.LongTensor(examples["label"])
    return examples
def data_collator(features):
    """Merge a list of processed samples into a single batch.

    Args:
        features: list of dicts, one per sample, each holding processed
            "pixel_values" and an integer "labels" entry.

    Returns:
        A dict with "pixel_values" stacked into one float tensor and
        "labels" collected into a long tensor.
    """
    pixel_batch = []
    label_batch = []
    for sample in features:
        # Each sample's pixel values become one row of the stacked batch.
        pixel_batch.append(torch.Tensor(sample["pixel_values"]))
        label_batch.append(sample["labels"])
    return {
        "pixel_values": torch.stack(pixel_batch),
        "labels": torch.tensor(label_batch, dtype=torch.long),
    }
# Apply the preprocessing transform to every split; the raw image/label
# columns are dropped once pixel_values/labels have been produced.
encoded_dataset = datasets.map(transforms, remove_columns=['image', 'label'], batched=True)
# Batched loaders; only the training loader is shuffled.
train_loader = DataLoader(encoded_dataset['train'], shuffle=True, batch_size=8, collate_fn=data_collator)
val_loader = DataLoader(encoded_dataset['test'], batch_size=8, collate_fn=data_collator)
# Human-readable class names, in label-index order.
fine_labels = datasets['train'].features['label'].names
# Build the label2id and id2label mappings for the model config.
label2id = {label: idx for idx, label in enumerate(fine_labels)}
id2label = {idx: label for idx, label in enumerate(fine_labels)}
编写训练循环
定义损失函数、优化器等,并设置在 GPU 上进行训练,保存在验证集上表现最好的模型。
python
from torch.utils.data import DataLoader
import torch.optim as optim
from transformers import ViTForImageClassification
from torch import nn
# Load the pretrained ViT and replace its classification head for this
# dataset. Derive num_labels from the label mapping instead of hard-coding 5,
# so the script keeps working if the number of flower classes changes.
model = ViTForImageClassification.from_pretrained(
    'vit-base-patch16-224',
    num_labels=len(label2id),
    ignore_mismatched_sizes=True,  # the new head's size differs from the checkpoint's
    label2id=label2id,
    id2label=id2label,
)
# Hyperparameters
num_epochs = 5       # number of training epochs
batch_size = 8       # NOTE(review): unused here — the DataLoaders above already use 8
learning_rate = 1e-4
# Optimizer over all model parameters.
optimizer = optim.Adam(model.parameters(), lr=learning_rate)
# Cross-entropy loss for multi-class classification.
criterion = nn.CrossEntropyLoss()
# Train on GPU when available, otherwise fall back to CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
# 训练函数
def train_one_epoch(model, data_loader, criterion, optimizer, device):
    """Run one training epoch and report per-sample loss and accuracy.

    Args:
        model: classifier whose forward output exposes ``.logits``.
        data_loader: yields dict batches with "pixel_values" and "labels";
            must expose ``.dataset`` so metrics can be averaged per sample.
        criterion: loss function applied to (logits, labels).
        optimizer: optimizer updating ``model``'s parameters.
        device: torch device to run on.

    Returns:
        Tuple ``(epoch_loss, epoch_acc)``: mean per-sample loss (float) and
        accuracy (0-dim double tensor) over the whole dataset.
    """
    model.train()
    running_loss = 0.0
    correct_predictions = 0
    for batch in data_loader:
        inputs, labels = batch["pixel_values"], batch["labels"]
        inputs, labels = inputs.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = model(inputs)
        # The model returns a ModelOutput object; raw scores are in .logits.
        loss = criterion(outputs.logits, labels)
        loss.backward()
        optimizer.step()
        # BUG FIX: loss.item() is the batch MEAN, so weight it by the batch
        # size before dividing by the dataset size below. The old code summed
        # batch means and divided by len(dataset), under-reporting the loss
        # by roughly a factor of the batch size.
        running_loss += loss.item() * inputs.size(0)
        _, preds = torch.max(outputs.logits, dim=1)
        correct_predictions += torch.sum(preds == labels)
    epoch_loss = running_loss / len(data_loader.dataset)
    epoch_acc = correct_predictions.double() / len(data_loader.dataset)
    return epoch_loss, epoch_acc
def evaluate(model, data_loader, criterion, device):
    """Evaluate the model and report per-sample loss and accuracy.

    Args:
        model: classifier whose forward output exposes ``.logits``.
        data_loader: yields dict batches with "pixel_values" and "labels";
            must expose ``.dataset`` so metrics can be averaged per sample.
        criterion: loss function applied to (logits, labels).
        device: torch device to run on.

    Returns:
        Tuple ``(epoch_loss, epoch_acc)``: mean per-sample loss (float) and
        accuracy (0-dim double tensor) over the whole dataset.
    """
    model.eval()
    running_loss = 0.0
    correct_predictions = 0
    with torch.no_grad():
        for batch in data_loader:
            inputs, labels = batch["pixel_values"], batch["labels"]
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs)
            # Raw class scores are in .logits of the ModelOutput.
            loss = criterion(outputs.logits, labels)
            # BUG FIX: weight the batch-mean loss by the batch size so the
            # division by len(dataset) below yields the true per-sample mean
            # (the old code under-reported by roughly the batch size).
            running_loss += loss.item() * inputs.size(0)
            _, preds = torch.max(outputs.logits, dim=1)
            correct_predictions += torch.sum(preds == labels)
    epoch_loss = running_loss / len(data_loader.dataset)
    epoch_acc = correct_predictions.double() / len(data_loader.dataset)
    return epoch_loss, epoch_acc
# Track the best validation accuracy seen so far; starting at 0 guarantees
# the first epoch produces a checkpoint.
best_val_acc = 0.0
# Fixed checkpoint path, hoisted out of the loop (the original used a no-op
# f-string with no placeholders and redefined it every improving epoch).
best_model_path = "best_model.pth"
for epoch in range(num_epochs):
    train_loss, train_acc = train_one_epoch(
        model, train_loader, criterion, optimizer, device
    )
    val_loss, val_acc = evaluate(model, val_loader, criterion, device)
    print(
        f"Epoch {epoch+1}/{num_epochs}, Training Loss: {train_loss:.4f}, Training Accuracy: {train_acc:.4f}, Validation Loss: {val_loss:.4f}, Validation Accuracy: {val_acc:.4f}"
    )
    # Checkpoint whenever validation accuracy improves.
    if val_acc > best_val_acc:
        best_val_acc = val_acc
        torch.save(model.state_dict(), best_model_path)
        print(f"New best model saved to {best_model_path} with Validation Accuracy: {best_val_acc:.4f}")
以上为全部代码,之后会介绍如何对模型进行微调。