20.微调ResNet-18网络分类热狗数据集(失败版本)

python 复制代码
import os
import torch
import torchvision
from torch import nn
import torchvision.models as models
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
from tqdm import tqdm
from sklearn.metrics import accuracy_score
##########################################################################################################################
def plot_metrics(train_loss_list, train_acc_list, test_acc_list, title='Training Curve'):
    epochs = range(1, len(train_loss_list) + 1)
    plt.figure(figsize=(4, 3))
    plt.plot(epochs, train_loss_list, label='Train Loss')
    plt.plot(epochs, train_acc_list, label='Train Acc',linestyle='--')
    plt.plot(epochs, test_acc_list, label='Test Acc', linestyle='--')
    plt.xlabel('Epoch')
    plt.ylabel('Value')
    plt.title(title)
    plt.legend()
    plt.grid(True)
    plt.tight_layout()
    plt.show()
def train_model(model,train_data,test_data,num_epochs):
    train_loss_list = []
    train_acc_list = []
    test_acc_list = []
    for epoch in range(num_epochs):
        total_loss=0
        total_acc_sample=0
        total_samples=0
        loop1=tqdm(train_data,desc=f"EPOCHS[{epoch+1}/{num_epochs}]")
        for X,y in loop1:
            #X=X.reshape(X.shape[0],-1)
            #print(X.shape)
            X=X.to(device)
            y=y.to(device)
            y_hat=model(X)
            loss=CEloss(y_hat,y)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            #loss累加
            total_loss+=loss.item()*X.shape[0]
            y_pred=y_hat.argmax(dim=1).detach().cpu().numpy()
            y_true=y.detach().cpu().numpy()
            total_acc_sample+=accuracy_score(y_pred,y_true)*X.shape[0]#保存样本数
            total_samples+=X.shape[0]
        test_acc_samples=0
        test_samples=0
        loop2=tqdm(test_data,desc=f"EPOCHS[{epoch+1}/{num_epochs}]")
        for X,y in loop2:
            X=X.to(device)
            y=y.to(device)
            #X=X.reshape(X.shape[0],-1)
            y_hat=model(X)
            y_pred=y_hat.argmax(dim=1).detach().cpu().numpy()
            y_true=y.detach().cpu().numpy()
            test_acc_samples+=accuracy_score(y_pred,y_true)*X.shape[0]#保存样本数
            test_samples+=X.shape[0]
        avg_train_loss=total_loss/total_samples
        avg_train_acc=total_acc_sample/total_samples
        avg_test_acc=test_acc_samples/test_samples
        train_loss_list.append(avg_train_loss)
        train_acc_list.append(avg_train_acc)
        test_acc_list.append(avg_test_acc)
        print(f"Epoch {epoch+1}: Loss: {avg_train_loss:.4f},Trian Accuracy: {avg_train_acc:.4f},test Accuracy: {avg_test_acc:.4f}")
    plot_metrics(train_loss_list, train_acc_list, test_acc_list)
    return model
##########################################################################################################################
data_dir=r'./hotdog_dataset/hotdog'
# 使用RGB通道的均值和标准差,以标准化每个通道
train_augs = torchvision.transforms.Compose([
    torchvision.transforms.RandomResizedCrop(224),
    torchvision.transforms.RandomHorizontalFlip(),
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])

test_augs = torchvision.transforms.Compose([
    torchvision.transforms.Resize([256, 256]),
    torchvision.transforms.CenterCrop(224),
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])
train_imgs = torchvision.datasets.ImageFolder(os.path.join(data_dir, 'train'),transform=train_augs)
test_imgs = torchvision.datasets.ImageFolder(os.path.join(data_dir, 'test'),transform=test_augs)
train_data=DataLoader(train_imgs,batch_size=16,num_workers=4,shuffle=True)
test_data=DataLoader(test_imgs,batch_size=16,num_workers=4,shuffle=False)
##########################################################################################################################
pretrained_net = models.resnet18(pretrained=True)
#加载预训练模型并且更改最后层
finetune_net=models.resnet50(pretrained=True)
finetune_net.fc=nn.Linear(finetune_net.fc.in_features,2)
nn.init.xavier_normal(finetune_net.fc.weight)
##########################################################################################################################
#开始训练
device=torch.device('cuda' if torch.cuda.is_available() else 'cpu')
finetune_net.to(device)
CEloss=nn.CrossEntropyLoss()
params_1x = [param for name, param in finetune_net.named_parameters() if name not in ["fc.weight", "fc.bias"]]
optimizer = torch.optim.SGD([
        {'params': params_1x},
        {'params': finetune_net.fc.parameters(), 'lr':0.001 * 10}
    ], lr=0.001, weight_decay=1e-4)
model=train_model(finetune_net,train_data,test_data,num_epochs=10)
相关推荐
枫叶林FYL1 分钟前
【Python高级工程与架构实战】项目四:生产级LLM Agent框架:基于PydanticAI的类型安全企业级实现
人工智能·python·自然语言处理
feasibility.4 分钟前
OpenClaw+LibTV视频生成实测(含安装+配置+分析):ai生成工作流很规范,但画面在“打架“
人工智能·aigc·音视频·内容运营·短剧·openclaw·libtv
I_Am_Zou11 分钟前
cloneman-ai技术解析:可落地的AI数字分身平台设计与实现
人工智能
老刘干货12 分钟前
Prompt工程全解·第二篇:骨架搭建——构建高可用Prompt的“四要素”模型
人工智能·技术人
夕小瑶12 分钟前
AI音乐的下半场,是 Vibe music!让Claude Code开口唱歌
人工智能
才盛智能科技15 分钟前
麦粒空间和元K聚合平台正式签约,全面启动流量合作
大数据·人工智能·元k聚合·麦粒空间
V搜xhliang024621 分钟前
基于¹⁸F-FDG PET/CT的深度学习-影像组学-临床模型预测非小细胞肺癌脉管侵犯的价值
大数据·人工智能·python·深度学习·机器学习·机器人
LaughingZhu21 分钟前
Product Hunt 每日热榜 | 2026-04-11
人工智能·chatgpt
XuecWu321 分钟前
原生多模态颠覆Scaling Law?解读语言“参数需求型”与视觉“数据需求型”核心差异
人工智能·深度学习·算法·计算机视觉·语言模型
华农DrLai22 分钟前
怎么用大模型生成推荐的训练数据?Data Augmentation怎么做?
数据库·人工智能·大模型·nlp·prompt