pytorch训练和使用resnet

pytorch训练和使用resnet

使用 CIFAR-10数据集

训练 resnet

resnet-train.py

python 复制代码
import torch
import torchvision
import torchvision.transforms as transforms
import torch.nn as nn
import torch.optim as optim

# 在CIFAR-10数据集中
# 训练集:包含50000张图像,用于训练模型。
# 测试集:包含10000张图像,用于评估模型的性能。
TRAIN_SIZE=50000
TEST_SIZE=10000

# 批量大小
BATCH_SIZE=128

# 数据预处理
transform_train = transforms.Compose([
    transforms.RandomHorizontalFlip(),
    transforms.RandomCrop(32, padding=4),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

# 加载CIFAR-10数据集
trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
                                        download=True, transform=transform_train)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=BATCH_SIZE,
                                          shuffle=True, num_workers=2)

testset = torchvision.datasets.CIFAR10(root='./data', train=False,
                                       download=True, transform=transform_test)
testloader = torch.utils.data.DataLoader(testset, batch_size=BATCH_SIZE,
                                         shuffle=False, num_workers=2)

classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

# 使用预训练的ResNet模型 , 不从默认url下载预训练的模型
model = torchvision.models.resnet18(weights=None)
# 从当前路径加载预训练权重
model_path = './model/resnet18-f37072fd.pth'
model.load_state_dict(torch.load(model_path))

# 修改最后一层以适应CIFAR-10的10个类别
num_ftrs = model.fc.in_features
model.fc = nn.Linear(num_ftrs, 10)

# 将模型移到GPU(如果有)
if torch.cuda.is_available() :
    print('Using GPU')
    device = torch.device("cuda:0")
else :
    print('Using CPU')
    device = torch.device("cpu")   

model = model.to(device)

# 定义损失函数和优化器
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9, weight_decay=5e-4)

# 学习率调度器
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=30, gamma=0.1)

# 训练网络
num_epochs = 50

print('start Training')

for epoch in range(num_epochs):
    model.train()
    running_loss = 0.0

    #总迭代次数 = 训练集大小 / 批量大小 =  向上取整(TRAIN_SIZE=50000 / BATCH_SIZE=128) = 391 次循环
    for i, data in enumerate(trainloader, 0):
        inputs, labels = data
        inputs, labels = inputs.to(device), labels.to(device)

        # 梯度清零
        optimizer.zero_grad()

        # 前向传播 + 向后传播 + 优化
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        # 打印统计信息
        running_loss += loss.item()
        if i % 100 == 99:    # 每100个小批量打印一次
            print(f'[Epoch {epoch + 1}, Batch {i + 1}] loss: {running_loss / 100:.3f}')
            running_loss = 0.0

    # 更新学习率
    scheduler.step()

print('Finished Training')

# 测试网络
model.eval()
correct = 0
total = 0
with torch.no_grad():
    # 总迭代次数 = 测试集 / 批量大小 向上取整(TEST_SIZE=10000/BATCH_SIZE=128) = 79 次循环
    for data in testloader:
        images, labels = data
        images, labels = images.to(device), labels.to(device)
        outputs = model(images)
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

accuracy_test = 100 * correct / total
print(f'Accuracy of the network on the 10000 test images: {accuracy_test:.2f}%')

# [Epoch 50, Batch 300] loss: 0.142
# Finished Training
# Accuracy of the network on the 10000 test images: 84.53%


# 准确率>0.8保存模型
if(accuracy_test > 0.8):
    print("Accuracy  > 0.8 ,save model")
    model_path = './model/trained_resnet18_cifar10.pth'
    torch.save(model.state_dict(), model_path)
    print(f'Model saved to {model_path}')

使用训练后的 resnet

评估数据

1.jpeg :

2.jpeg:

restnet-eval.py

python 复制代码
import torch
import torchvision
import torchvision.transforms as transforms
import torch.nn as nn
from PIL import Image

# 模型路径
model_path = './model/trained_resnet18_cifar10.pth'

# 类别标签
classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

# 数据预处理
transform = transforms.Compose([
    transforms.Resize((32, 32)),  # 调整图像大小为32x32
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))  # 归一化
])

# 加载预训练的ResNet模型
model = torchvision.models.resnet18(pretrained=False)
num_ftrs = model.fc.in_features
model.fc = nn.Linear(num_ftrs, 10)
model.load_state_dict(torch.load(model_path))
model.eval()  # 设置模型为评估模式

# 将模型移到GPU(如果有)
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model = model.to(device)

def predict_image(image_path):
    # 加载并预处理图像
    image = Image.open(image_path).convert('RGB')
    image = transform(image).unsqueeze(0)  # 添加批次维度
    image = image.to(device)

    # 进行预测
    with torch.no_grad():
        outputs = model(image)
        _, predicted = torch.max(outputs.data, 1)

    # 输出预测结果
    predicted_class = classes[predicted.item()]
    print(f'Predicted class: {predicted_class}')

# img is in classes
predict_image('./data/1.jpeg')

# img is not in classes
predict_image('./data/2.jpeg')
相关推荐
思绪漂移3 分钟前
CodeBuddy AI IDE :Skills 模式
ide·人工智能
居7然15 分钟前
详解监督微调(SFT):大模型指令遵循能力的核心构建方案
人工智能·分布式·架构·大模型·transformer
没有钱的钱仔15 分钟前
神经 网络
深度学习
KKKlucifer21 分钟前
技术漏洞被钻营!Agent 感知伪装借 ChatGPT Atlas 批量输出虚假数据,AI 安全防线面临新挑战
人工智能·安全·chatgpt
oil欧哟23 分钟前
AI 的环保账,训练一个模型要用多少电?
人工智能·chatgpt
执笔论英雄1 小时前
【大模型训练】roll 调用megatron 计算损失函数有,会用到partial
人工智能
小蜜蜂爱编程1 小时前
deep learning简介
人工智能·深度学习
IT_陈寒1 小时前
SpringBoot实战避坑指南:我在微服务项目中总结的12条高效开发经验
前端·人工智能·后端
AI优秘企业大脑1 小时前
需求洞察助力战略规划实现潜在市场机会
大数据·人工智能
Learn Beyond Limits1 小时前
Clustering vs Classification|聚类vs分类
人工智能·算法·机器学习·ai·分类·数据挖掘·聚类