20.微调ResNet-18网络分类热狗数据集(失败版本)

python 复制代码
import os
import torch
import torchvision
from torch import nn
import torchvision.models as models
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
from tqdm import tqdm
from sklearn.metrics import accuracy_score
##########################################################################################################################
def plot_metrics(train_loss_list, train_acc_list, test_acc_list, title='Training Curve'):
    epochs = range(1, len(train_loss_list) + 1)
    plt.figure(figsize=(4, 3))
    plt.plot(epochs, train_loss_list, label='Train Loss')
    plt.plot(epochs, train_acc_list, label='Train Acc',linestyle='--')
    plt.plot(epochs, test_acc_list, label='Test Acc', linestyle='--')
    plt.xlabel('Epoch')
    plt.ylabel('Value')
    plt.title(title)
    plt.legend()
    plt.grid(True)
    plt.tight_layout()
    plt.show()
def train_model(model,train_data,test_data,num_epochs):
    train_loss_list = []
    train_acc_list = []
    test_acc_list = []
    for epoch in range(num_epochs):
        total_loss=0
        total_acc_sample=0
        total_samples=0
        loop1=tqdm(train_data,desc=f"EPOCHS[{epoch+1}/{num_epochs}]")
        for X,y in loop1:
            #X=X.reshape(X.shape[0],-1)
            #print(X.shape)
            X=X.to(device)
            y=y.to(device)
            y_hat=model(X)
            loss=CEloss(y_hat,y)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            #loss累加
            total_loss+=loss.item()*X.shape[0]
            y_pred=y_hat.argmax(dim=1).detach().cpu().numpy()
            y_true=y.detach().cpu().numpy()
            total_acc_sample+=accuracy_score(y_pred,y_true)*X.shape[0]#保存样本数
            total_samples+=X.shape[0]
        test_acc_samples=0
        test_samples=0
        loop2=tqdm(test_data,desc=f"EPOCHS[{epoch+1}/{num_epochs}]")
        for X,y in loop2:
            X=X.to(device)
            y=y.to(device)
            #X=X.reshape(X.shape[0],-1)
            y_hat=model(X)
            y_pred=y_hat.argmax(dim=1).detach().cpu().numpy()
            y_true=y.detach().cpu().numpy()
            test_acc_samples+=accuracy_score(y_pred,y_true)*X.shape[0]#保存样本数
            test_samples+=X.shape[0]
        avg_train_loss=total_loss/total_samples
        avg_train_acc=total_acc_sample/total_samples
        avg_test_acc=test_acc_samples/test_samples
        train_loss_list.append(avg_train_loss)
        train_acc_list.append(avg_train_acc)
        test_acc_list.append(avg_test_acc)
        print(f"Epoch {epoch+1}: Loss: {avg_train_loss:.4f},Trian Accuracy: {avg_train_acc:.4f},test Accuracy: {avg_test_acc:.4f}")
    plot_metrics(train_loss_list, train_acc_list, test_acc_list)
    return model
##########################################################################################################################
data_dir=r'./hotdog_dataset/hotdog'
# 使用RGB通道的均值和标准差,以标准化每个通道
train_augs = torchvision.transforms.Compose([
    torchvision.transforms.RandomResizedCrop(224),
    torchvision.transforms.RandomHorizontalFlip(),
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])

test_augs = torchvision.transforms.Compose([
    torchvision.transforms.Resize([256, 256]),
    torchvision.transforms.CenterCrop(224),
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])
train_imgs = torchvision.datasets.ImageFolder(os.path.join(data_dir, 'train'),transform=train_augs)
test_imgs = torchvision.datasets.ImageFolder(os.path.join(data_dir, 'test'),transform=test_augs)
train_data=DataLoader(train_imgs,batch_size=16,num_workers=4,shuffle=True)
test_data=DataLoader(test_imgs,batch_size=16,num_workers=4,shuffle=False)
##########################################################################################################################
pretrained_net = models.resnet18(pretrained=True)
#加载预训练模型并且更改最后层
finetune_net=models.resnet50(pretrained=True)
finetune_net.fc=nn.Linear(finetune_net.fc.in_features,2)
nn.init.xavier_normal(finetune_net.fc.weight)
##########################################################################################################################
#开始训练
device=torch.device('cuda' if torch.cuda.is_available() else 'cpu')
finetune_net.to(device)
CEloss=nn.CrossEntropyLoss()
params_1x = [param for name, param in finetune_net.named_parameters() if name not in ["fc.weight", "fc.bias"]]
optimizer = torch.optim.SGD([
        {'params': params_1x},
        {'params': finetune_net.fc.parameters(), 'lr':0.001 * 10}
    ], lr=0.001, weight_decay=1e-4)
model=train_model(finetune_net,train_data,test_data,num_epochs=10)
相关推荐
财经资讯数据_灵砚智能1 分钟前
基于全球经济类多源新闻的NLP情感分析与数据可视化(日间)2026年5月2日
人工智能·python·信息可视化·自然语言处理·ai编程
sali-tec2 分钟前
C# 基于OpenCv的视觉工作流-章58-相机标定
图像处理·人工智能·数码相机·opencv·算法·计算机视觉
一水鉴天3 分钟前
同构异质三表总装体系确立与入表机制闭环验证 20260502(腾讯元宝)
人工智能·算法·机器学习
kalvin_y_liu5 分钟前
人体动作理解和人机共享控制两个研究方向的核心内容
人工智能·具身数据模型
浔川python社5 分钟前
AI 生成视频盛行,会带来哪些利与弊
人工智能
AI科技星5 分钟前
《全域数学》第一部:数术本源·第二卷《算术原本》之十四附录(二)全域数学体系下三大数论猜想的本源推演与哲学阐释【乖乖数学】
人工智能·线性代数·机器学习·量子计算·agi
qcx235 分钟前
拆解 Warp AI Agent(一):类型即协议——23 种 Action 的编译期安全设计
人工智能·安全·ai·agent·源码解析·warp
蔡俊锋6 分钟前
AI进阶运营:从信息爆炸到智能掌控
人工智能·chatgpt·ai进阶运营
weixin_511840477 分钟前
2026年4月23日 Hermes Agent 与 OpenClaw 深度对比研究
人工智能
乱世刀疤8 分钟前
如何写好GPT Image 2 提示词
人工智能