- resnet分类器训练
import torch import torchvision from torchvision import transforms from torch.utils.data import random_split import torch.nn as nn import torch.optim as optim from torchvision.models import resnet50 # Define the transformation transform = transforms.Compose([ transforms.Resize(256), transforms.CenterCrop(224), transforms.ToTensor(), transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) ]) # Load the dataset data = torchvision.datasets.ImageFolder(root=r"D:\train_model\train_data_set", transform=transform) classes_set = data.classes # 保存类别信息到 classes.txt with open('classes.txt', 'w') as f: for class_name in classes_set: f.write(class_name + '\n') # Split the data into train and test sets train_size = int(0.8 * len(data)) test_size = len(data) - train_size train_data, test_data = random_split(data, [train_size, test_size]) # Optionally, you can load the train and test data into data loaders train_loader = torch.utils.data.DataLoader(train_data, batch_size=32, shuffle=True) test_loader = torch.utils.data.DataLoader(test_data, batch_size=32, shuffle=False) # Define the model model = resnet50(pretrained=True) # Replace the last layer num_features = model.fc.in_features model.fc = nn.Linear(num_features, len(classes_set)) # Define the loss function and optimizer criterion = nn.CrossEntropyLoss() optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9) # Move the model to the device device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") model = model.to(device) # Define the number of epochs num_epochs = 10 # Train the model for epoch in range(num_epochs): # Train the model on the training set model.train() train_loss = 0.0 for i, (inputs, labels) in enumerate(train_loader): # Move the data to the device inputs = inputs.to(device) # inputs = inputs.float() labels = labels.to(device) # labels = labels.long() # Zero the parameter gradients optimizer.zero_grad() # Forward + backward + optimize outputs = model(inputs) loss = criterion(outputs, labels) loss.backward() optimizer.step() # Update the training loss train_loss += loss.item() * inputs.size(0) # Evaluate the model on the test set model.eval() test_loss = 0.0 test_acc = 0.0 with torch.no_grad(): for i, (inputs, labels) in enumerate(test_loader): # Move the data to the device inputs = inputs.to(device) labels = labels.to(device) # Forward outputs = model(inputs) loss = criterion(outputs, labels) # Update the test loss and accuracy test_loss += loss.item() * inputs.size(0) _, preds = torch.max(outputs, 1) test_acc += torch.sum(preds == labels.data) # Print the training and test loss and accuracy train_loss /= len(train_data) test_loss /= len(test_data) test_acc = test_acc.double() / len(test_data) print(f"Epoch [{epoch + 1}/{num_epochs}] Train Loss: {train_loss:.4f} Test Loss: {test_loss:.4f} Test Acc: {test_acc:.4f}") # 保存模型参数 torch.save(model.state_dict(), './model/trained_model.pth')
resnet分类训练
E.K.江湖念书人2025-01-03 14:57
相关推荐
bst@微胖子1 天前
PyTorch深度学习框架之基础实战二狮子座明仔1 天前
体验式强化学习:让模型学会“吃一堑长一智“童园管理札记1 天前
【记录模板】大班科学小游戏观察记录(盐主题:《会变魔术的盐》)CelestialYuxin1 天前
A.R.I.S.系统:YOLOx在破碎电子废料分拣中的新探索勾股导航1 天前
蚁群优化算法ppppppatrick1 天前
【深度学习基础篇】手算卷积神经网络:13道经典题全解析(考研/面试必备)石去皿1 天前
文本分类常见面试篇:从 fastText 到 TextCNN 的核心考点全解析狮子座明仔1 天前
REDSearcher:如何用30B参数的小模型,在深度搜索上击败GPT-o3和Gemini?万里鹏程转瞬至1 天前
论文阅读 | SLA:sparse–linear attion视频生成95%稀疏度FLOPs降低20倍肾透侧视攻城狮1 天前
《模型保存加载避坑指南:解锁SavedModel、HDF5与自定义对象的正确姿势》