20.微调ResNet-18网络分类热狗数据集(失败版本)

python 复制代码
import os
import torch
import torchvision
from torch import nn
import torchvision.models as models
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
from tqdm import tqdm
from sklearn.metrics import accuracy_score
##########################################################################################################################
def plot_metrics(train_loss_list, train_acc_list, test_acc_list, title='Training Curve'):
    epochs = range(1, len(train_loss_list) + 1)
    plt.figure(figsize=(4, 3))
    plt.plot(epochs, train_loss_list, label='Train Loss')
    plt.plot(epochs, train_acc_list, label='Train Acc',linestyle='--')
    plt.plot(epochs, test_acc_list, label='Test Acc', linestyle='--')
    plt.xlabel('Epoch')
    plt.ylabel('Value')
    plt.title(title)
    plt.legend()
    plt.grid(True)
    plt.tight_layout()
    plt.show()
def train_model(model,train_data,test_data,num_epochs):
    train_loss_list = []
    train_acc_list = []
    test_acc_list = []
    for epoch in range(num_epochs):
        total_loss=0
        total_acc_sample=0
        total_samples=0
        loop1=tqdm(train_data,desc=f"EPOCHS[{epoch+1}/{num_epochs}]")
        for X,y in loop1:
            #X=X.reshape(X.shape[0],-1)
            #print(X.shape)
            X=X.to(device)
            y=y.to(device)
            y_hat=model(X)
            loss=CEloss(y_hat,y)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            #loss累加
            total_loss+=loss.item()*X.shape[0]
            y_pred=y_hat.argmax(dim=1).detach().cpu().numpy()
            y_true=y.detach().cpu().numpy()
            total_acc_sample+=accuracy_score(y_pred,y_true)*X.shape[0]#保存样本数
            total_samples+=X.shape[0]
        test_acc_samples=0
        test_samples=0
        loop2=tqdm(test_data,desc=f"EPOCHS[{epoch+1}/{num_epochs}]")
        for X,y in loop2:
            X=X.to(device)
            y=y.to(device)
            #X=X.reshape(X.shape[0],-1)
            y_hat=model(X)
            y_pred=y_hat.argmax(dim=1).detach().cpu().numpy()
            y_true=y.detach().cpu().numpy()
            test_acc_samples+=accuracy_score(y_pred,y_true)*X.shape[0]#保存样本数
            test_samples+=X.shape[0]
        avg_train_loss=total_loss/total_samples
        avg_train_acc=total_acc_sample/total_samples
        avg_test_acc=test_acc_samples/test_samples
        train_loss_list.append(avg_train_loss)
        train_acc_list.append(avg_train_acc)
        test_acc_list.append(avg_test_acc)
        print(f"Epoch {epoch+1}: Loss: {avg_train_loss:.4f},Trian Accuracy: {avg_train_acc:.4f},test Accuracy: {avg_test_acc:.4f}")
    plot_metrics(train_loss_list, train_acc_list, test_acc_list)
    return model
##########################################################################################################################
data_dir=r'./hotdog_dataset/hotdog'
# 使用RGB通道的均值和标准差,以标准化每个通道
train_augs = torchvision.transforms.Compose([
    torchvision.transforms.RandomResizedCrop(224),
    torchvision.transforms.RandomHorizontalFlip(),
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])

test_augs = torchvision.transforms.Compose([
    torchvision.transforms.Resize([256, 256]),
    torchvision.transforms.CenterCrop(224),
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])
train_imgs = torchvision.datasets.ImageFolder(os.path.join(data_dir, 'train'),transform=train_augs)
test_imgs = torchvision.datasets.ImageFolder(os.path.join(data_dir, 'test'),transform=test_augs)
train_data=DataLoader(train_imgs,batch_size=16,num_workers=4,shuffle=True)
test_data=DataLoader(test_imgs,batch_size=16,num_workers=4,shuffle=False)
##########################################################################################################################
pretrained_net = models.resnet18(pretrained=True)
#加载预训练模型并且更改最后层
finetune_net=models.resnet50(pretrained=True)
finetune_net.fc=nn.Linear(finetune_net.fc.in_features,2)
nn.init.xavier_normal(finetune_net.fc.weight)
##########################################################################################################################
#开始训练
device=torch.device('cuda' if torch.cuda.is_available() else 'cpu')
finetune_net.to(device)
CEloss=nn.CrossEntropyLoss()
params_1x = [param for name, param in finetune_net.named_parameters() if name not in ["fc.weight", "fc.bias"]]
optimizer = torch.optim.SGD([
        {'params': params_1x},
        {'params': finetune_net.fc.parameters(), 'lr':0.001 * 10}
    ], lr=0.001, weight_decay=1e-4)
model=train_model(finetune_net,train_data,test_data,num_epochs=10)
相关推荐
飞哥数智坊8 小时前
从CodeBuddy翻车到MasterGo救场,我的小程序UI终于焕然一新
人工智能
AKAMAI11 小时前
跳过复杂环节:Akamai应用平台让Kubernetes生产就绪——现已正式发布
人工智能·云原生·云计算
新智元12 小时前
阿里王牌 Agent 横扫 SOTA,全栈开源力压 OpenAI!博士级难题一键搞定
人工智能·openai
新智元13 小时前
刚刚,OpenAI/Gemini 共斩 ICPC 2025 金牌!OpenAI 满分碾压横扫全场
人工智能·openai
机器之心13 小时前
OneSearch,揭开快手电商搜索「一步到位」的秘技
人工智能·openai
阿里云大数据AI技术13 小时前
2025云栖大会·大数据AI参会攻略请查收!
大数据·人工智能
YourKing13 小时前
yolov11n.onnx格式模型转换与图像推理
人工智能
sans_13 小时前
NCCL的用户缓冲区注册
人工智能
sans_13 小时前
三种视角下的Symmetric Memory,下一代HPC内存模型
人工智能
算家计算14 小时前
模糊高清修复真王炸!ComfyUI-SeedVR2-Kontext(画质修复+P图)本地部署教程
人工智能·开源·aigc