pytorch训练和使用resnet

pytorch训练和使用resnet

使用 CIFAR-10数据集

训练 resnet

resnet-train.py

python 复制代码
import torch
import torchvision
import torchvision.transforms as transforms
import torch.nn as nn
import torch.optim as optim

# 在CIFAR-10数据集中
# 训练集:包含50000张图像,用于训练模型。
# 测试集:包含10000张图像,用于评估模型的性能。
TRAIN_SIZE=50000
TEST_SIZE=10000

# 批量大小
BATCH_SIZE=128

# 数据预处理
transform_train = transforms.Compose([
    transforms.RandomHorizontalFlip(),
    transforms.RandomCrop(32, padding=4),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

# 加载CIFAR-10数据集
trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
                                        download=True, transform=transform_train)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=BATCH_SIZE,
                                          shuffle=True, num_workers=2)

testset = torchvision.datasets.CIFAR10(root='./data', train=False,
                                       download=True, transform=transform_test)
testloader = torch.utils.data.DataLoader(testset, batch_size=BATCH_SIZE,
                                         shuffle=False, num_workers=2)

classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

# 使用预训练的ResNet模型 , 不从默认url下载预训练的模型
model = torchvision.models.resnet18(weights=None)
# 从当前路径加载预训练权重
model_path = './model/resnet18-f37072fd.pth'
model.load_state_dict(torch.load(model_path))

# 修改最后一层以适应CIFAR-10的10个类别
num_ftrs = model.fc.in_features
model.fc = nn.Linear(num_ftrs, 10)

# 将模型移到GPU(如果有)
if torch.cuda.is_available() :
    print('Using GPU')
    device = torch.device("cuda:0")
else :
    print('Using CPU')
    device = torch.device("cpu")   

model = model.to(device)

# 定义损失函数和优化器
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9, weight_decay=5e-4)

# 学习率调度器
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=30, gamma=0.1)

# 训练网络
num_epochs = 50

print('start Training')

for epoch in range(num_epochs):
    model.train()
    running_loss = 0.0

    #总迭代次数 = 训练集大小 / 批量大小 =  向上取整(TRAIN_SIZE=50000 / BATCH_SIZE=128) = 391 次循环
    for i, data in enumerate(trainloader, 0):
        inputs, labels = data
        inputs, labels = inputs.to(device), labels.to(device)

        # 梯度清零
        optimizer.zero_grad()

        # 前向传播 + 向后传播 + 优化
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        # 打印统计信息
        running_loss += loss.item()
        if i % 100 == 99:    # 每100个小批量打印一次
            print(f'[Epoch {epoch + 1}, Batch {i + 1}] loss: {running_loss / 100:.3f}')
            running_loss = 0.0

    # 更新学习率
    scheduler.step()

print('Finished Training')

# 测试网络
model.eval()
correct = 0
total = 0
with torch.no_grad():
    # 总迭代次数 = 测试集 / 批量大小 向上取整(TEST_SIZE=10000/BATCH_SIZE=128) = 79 次循环
    for data in testloader:
        images, labels = data
        images, labels = images.to(device), labels.to(device)
        outputs = model(images)
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

accuracy_test = 100 * correct / total
print(f'Accuracy of the network on the 10000 test images: {accuracy_test:.2f}%')

# [Epoch 50, Batch 300] loss: 0.142
# Finished Training
# Accuracy of the network on the 10000 test images: 84.53%


# 准确率>0.8保存模型
if(accuracy_test > 0.8):
    print("Accuracy  > 0.8 ,save model")
    model_path = './model/trained_resnet18_cifar10.pth'
    torch.save(model.state_dict(), model_path)
    print(f'Model saved to {model_path}')

使用训练后的 resnet

评估数据

1.jpeg :

2.jpeg:

restnet-eval.py

python 复制代码
import torch
import torchvision
import torchvision.transforms as transforms
import torch.nn as nn
from PIL import Image

# 模型路径
model_path = './model/trained_resnet18_cifar10.pth'

# 类别标签
classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

# 数据预处理
transform = transforms.Compose([
    transforms.Resize((32, 32)),  # 调整图像大小为32x32
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))  # 归一化
])

# 加载预训练的ResNet模型
model = torchvision.models.resnet18(pretrained=False)
num_ftrs = model.fc.in_features
model.fc = nn.Linear(num_ftrs, 10)
model.load_state_dict(torch.load(model_path))
model.eval()  # 设置模型为评估模式

# 将模型移到GPU(如果有)
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model = model.to(device)

def predict_image(image_path):
    # 加载并预处理图像
    image = Image.open(image_path).convert('RGB')
    image = transform(image).unsqueeze(0)  # 添加批次维度
    image = image.to(device)

    # 进行预测
    with torch.no_grad():
        outputs = model(image)
        _, predicted = torch.max(outputs.data, 1)

    # 输出预测结果
    predicted_class = classes[predicted.item()]
    print(f'Predicted class: {predicted_class}')

# img is in classes
predict_image('./data/1.jpeg')

# img is not in classes
predict_image('./data/2.jpeg')
相关推荐
Codebee2 小时前
能力中心 (Agent SkillCenter):开启AI技能管理新时代
人工智能
聆风吟º2 小时前
CANN runtime 全链路拆解:AI 异构计算运行时的任务管理与功能适配技术路径
人工智能·深度学习·神经网络·cann
uesowys2 小时前
Apache Spark算法开发指导-One-vs-Rest classifier
人工智能·算法·spark
AI_56782 小时前
AWS EC2新手入门:6步带你从零启动实例
大数据·数据库·人工智能·机器学习·aws
User_芊芊君子2 小时前
CANN大模型推理加速引擎ascend-transformer-boost深度解析:毫秒级响应的Transformer优化方案
人工智能·深度学习·transformer
智驱力人工智能3 小时前
小区高空抛物AI实时预警方案 筑牢社区头顶安全的实践 高空抛物检测 高空抛物监控安装教程 高空抛物误报率优化方案 高空抛物监控案例分享
人工智能·深度学习·opencv·算法·安全·yolo·边缘计算
qq_160144873 小时前
亲测!2026年零基础学AI的入门干货,新手照做就能上手
人工智能
Howie Zphile3 小时前
全面预算管理难以落地的核心真相:“完美模型幻觉”的认知误区
人工智能·全面预算
人工不智能5773 小时前
拆解 BERT:Output 中的 Hidden States 到底藏了什么秘密?
人工智能·深度学习·bert
盟接之桥3 小时前
盟接之桥说制造:引流品 × 利润品,全球电商平台高效产品组合策略(供讨论)
大数据·linux·服务器·网络·人工智能·制造