20. Fine-tuning ResNet-18 to classify the hotdog dataset (failed version)

python
import os
import torch
import torchvision
from torch import nn
import torchvision.models as models
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
from tqdm import tqdm
from sklearn.metrics import accuracy_score
##########################################################################################################################
def plot_metrics(train_loss_list, train_acc_list, test_acc_list, title='Training Curve'):
    epochs = range(1, len(train_loss_list) + 1)
    plt.figure(figsize=(4, 3))
    plt.plot(epochs, train_loss_list, label='Train Loss')
    plt.plot(epochs, train_acc_list, label='Train Acc',linestyle='--')
    plt.plot(epochs, test_acc_list, label='Test Acc', linestyle='--')
    plt.xlabel('Epoch')
    plt.ylabel('Value')
    plt.title(title)
    plt.legend()
    plt.grid(True)
    plt.tight_layout()
    plt.show()
def train_model(model,train_data,test_data,num_epochs):
    train_loss_list = []
    train_acc_list = []
    test_acc_list = []
    for epoch in range(num_epochs):
        total_loss=0
        total_acc_sample=0
        total_samples=0
        loop1=tqdm(train_data,desc=f"TRAIN [{epoch+1}/{num_epochs}]")
        for X,y in loop1:
            #X=X.reshape(X.shape[0],-1)
            #print(X.shape)
            X=X.to(device)
            y=y.to(device)
            y_hat=model(X)
            loss=CEloss(y_hat,y)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            # accumulate the loss, weighted by the batch size
            total_loss+=loss.item()*X.shape[0]
            y_pred=y_hat.argmax(dim=1).detach().cpu().numpy()
            y_true=y.detach().cpu().numpy()
            total_acc_sample+=accuracy_score(y_true,y_pred)*X.shape[0]  # weight batch accuracy by sample count
            total_samples+=X.shape[0]
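        # Evaluate on the test set after each epoch. Note that the model is still in
        # training mode here and gradients are tracked (see the sketch after the listing).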
        test_acc_samples=0
        test_samples=0
        loop2=tqdm(test_data,desc=f"TEST [{epoch+1}/{num_epochs}]")
        for X,y in loop2:
            X=X.to(device)
            y=y.to(device)
            #X=X.reshape(X.shape[0],-1)
            y_hat=model(X)
            y_pred=y_hat.argmax(dim=1).detach().cpu().numpy()
            y_true=y.detach().cpu().numpy()
            test_acc_samples+=accuracy_score(y_true,y_pred)*X.shape[0]  # weight batch accuracy by sample count
            test_samples+=X.shape[0]
        avg_train_loss=total_loss/total_samples
        avg_train_acc=total_acc_sample/total_samples
        avg_test_acc=test_acc_samples/test_samples
        train_loss_list.append(avg_train_loss)
        train_acc_list.append(avg_train_acc)
        test_acc_list.append(avg_test_acc)
        print(f"Epoch {epoch+1}: Loss: {avg_train_loss:.4f},Trian Accuracy: {avg_train_acc:.4f},test Accuracy: {avg_test_acc:.4f}")
    plot_metrics(train_loss_list, train_acc_list, test_acc_list)
    return model
##########################################################################################################################
data_dir=r'./hotdog_dataset/hotdog'
# Normalize each channel with the RGB mean and standard deviation used for the ImageNet pretrained weights
train_augs = torchvision.transforms.Compose([
    torchvision.transforms.RandomResizedCrop(224),
    torchvision.transforms.RandomHorizontalFlip(),
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])

test_augs = torchvision.transforms.Compose([
    torchvision.transforms.Resize([256, 256]),
    torchvision.transforms.CenterCrop(224),
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])
train_imgs = torchvision.datasets.ImageFolder(os.path.join(data_dir, 'train'),transform=train_augs)
test_imgs = torchvision.datasets.ImageFolder(os.path.join(data_dir, 'test'),transform=test_augs)
train_data=DataLoader(train_imgs,batch_size=16,num_workers=4,shuffle=True)
test_data=DataLoader(test_imgs,batch_size=16,num_workers=4,shuffle=False)
##########################################################################################################################
# Load the pretrained ResNet-18 and replace its final fully connected layer
# with a new 2-class output layer (hotdog vs. not hotdog)
finetune_net=models.resnet18(pretrained=True)
finetune_net.fc=nn.Linear(finetune_net.fc.in_features,2)
nn.init.xavier_normal_(finetune_net.fc.weight)  # in-place Xavier init for the new layer
##########################################################################################################################
# Start training
device=torch.device('cuda' if torch.cuda.is_available() else 'cpu')
finetune_net.to(device)
CEloss=nn.CrossEntropyLoss()
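# Discriminative learning rates: the pretrained backbone keeps the base learning rate,
# while the freshly initialized fc layer is trained with a 10x larger learning rate.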
params_1x = [param for name, param in finetune_net.named_parameters() if name not in ["fc.weight", "fc.bias"]]
optimizer = torch.optim.SGD([
        {'params': params_1x},
        {'params': finetune_net.fc.parameters(), 'lr':0.001 * 10}
    ], lr=0.001, weight_decay=1e-4)
model=train_model(finetune_net,train_data,test_data,num_epochs=10)
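
One thing the listing above does not do is switch the network out of training mode for evaluation: the test loop keeps gradient tracking on and lets the batch-norm layers update their running statistics, which can distort the reported test accuracy. Whether that is the reason this run is labelled a "failed version" is only a guess, but below is a minimal sketch of an evaluation helper that sets eval mode and disables gradients; the name evaluate_accuracy and its signature are illustrative additions, not part of the original code.

python
import torch
from sklearn.metrics import accuracy_score

def evaluate_accuracy(model, data_loader, device):
    """Illustrative helper: compute accuracy with the model in eval mode
    and gradient tracking disabled."""
    model.eval()                       # freeze batch-norm / dropout behaviour
    acc_sum, n = 0.0, 0
    with torch.no_grad():              # no gradients are needed for evaluation
        for X, y in data_loader:
            X, y = X.to(device), y.to(device)
            y_pred = model(X).argmax(dim=1).cpu().numpy()
            acc_sum += accuracy_score(y.cpu().numpy(), y_pred) * X.shape[0]
            n += X.shape[0]
    model.train()                      # switch back for the next training epoch
    return acc_sum / n

Inside train_model, the inner test loop could then be replaced by a single call such as avg_test_acc = evaluate_accuracy(model, test_data, device).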