如何使用.pth训练模型

一.使用.pth训练模型的步骤如下:

1.导入必要的库和模型

python 复制代码
import torch
import torchvision.models as models

# 加载预训练模型
model = models.resnet50(pretrained=True)

2.定义数据集和数据加载器

python 复制代码
# 定义数据集和数据加载器
dataset = MyDataset()
dataloader = torch.utils.data.DataLoader(dataset, batch_size=32, shuffle=True)

3.定义损失函数和优化器

python 复制代码
# 定义损失函数和优化器
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.001, momentum=0.9)

4.训练模型

python 复制代码
# 训练模型
for epoch in range(10):
    running_loss = 0.0
    for i, data in enumerate(dataloader, 0):
        inputs, labels = data

        optimizer.zero_grad()

        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        if i % 2000 == 1999:
            print('[%d, %5d] loss: %.3f' %
                  (epoch + 1, i + 1, running_loss / 2000))
            running_loss = 0.0

5.保存模型

python 复制代码
# 保存模型
torch.save(model.state_dict(), 'model.pth')

二,使用自己训练的.pth模型进行训练的步骤如下:

1.导入必要的库和模型

python 复制代码
import torch
import torchvision
from torchvision import transforms
from torch.utils.data import DataLoader
from my_dataset import MyDataset # 自定义数据集
from my_model import MyModel # 自定义模型

2.设置超参数和路径

python 复制代码
batch_size = 32 # 批大小
num_epochs = 10 # 训练轮数
learning_rate = 0.001 # 学习率
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') # 设置设备
train_data_path = 'train_data/' # 训练数据集路径
test_data_path = 'test_data/' # 测试数据集路径
model_path = 'my_model.pth' # 模型保存路径

3.加载数据集

python 复制代码
train_transforms = transforms.Compose([
    transforms.Resize((224, 224)), # 调整图像大小
    transforms.RandomHorizontalFlip(), # 随机水平翻转
    transforms.ToTensor(), # 转换为张量
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) # 标准化
])

test_transforms = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])

train_dataset = MyDataset(train_data_path, train_transforms) # 自定义数据集
test_dataset = MyDataset(test_data_path, test_transforms)

train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True) # 训练集加载器
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False) # 测试集加载器

4.加载模型

python 复制代码
model = MyModel() # 自定义模型
model.load_state_dict(torch.load(model_path)) # 加载.pth模型
model.to(device) # 将模型移动到设备上

5.定义损失函数和优化器

python 复制代码
criterion = torch.nn.CrossEntropyLoss() # 交叉熵损失函数
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate) # Adam优化器

6.训练模型

python 复制代码
total_step = len(train_loader)
for epoch in range(num_epochs):
    for i, (images, labels) in enumerate(train_loader):
        images = images.to(device)
        labels = labels.to(device)
        
        outputs = model(images)
        loss = criterion(outputs, labels)
        
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        if (i+1) % 100 == 0:
            print('Epoch [{}/{}], Step [{}/{}], Loss: {:.4f}'.format(epoch+1, num_epochs, i+1, total_step, loss.item()))

torch.save(model.state_dict(), 'fine_tuned_model.pth') # 保存.pth模型

7.测试模型

python 复制代码
model.eval() # 切换到评估模式
with torch.no_grad():
    correct = 0
    total = 0
    for images, labels in test_loader:
        images = images.to(device)
        labels = labels.to(device)
        
        outputs = model(images)
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()
    
    print('Accuracy of the network on the test images: {} %'.format(100 * correct / total))
相关推荐
喝过期的拉菲5 分钟前
使用 Pytorch Lightning 时追踪指标和可视化指标
pytorch·可视化·lightning·指标追踪
前端付豪10 分钟前
13、你还在 print 调试🧾?教你写出自己的日志系统
后端·python
摆烂工程师11 分钟前
Claude Code 落地实践的工作简易流程
人工智能·claude·敏捷开发
CoovallyAIHub12 分钟前
YOLOv13都来了,目标检测还卷得动吗?别急,还有这些新方向!
深度学习·算法·计算机视觉
亚马逊云开发者13 分钟前
得心应手:探索 MCP 与数据库结合的应用场景
人工智能
这里有鱼汤15 分钟前
hvPlot:用你熟悉的 Pandas,画出你没见过的炫图
后端·python
大明哥_17 分钟前
100 个 Coze 精品案例 - 小红书爆款图文,单篇点赞 20000+,用 Coze 智能体一键生成有声儿童绘本!
人工智能
聚客AI18 分钟前
🚀拒绝试错成本!企业接入MCP协议的避坑清单
人工智能·掘金·日新计划·mcp
源码站~26 分钟前
基于Flask+Vue的豆瓣音乐分析与推荐系统
vue.js·python·flask·毕业设计·毕设·校园·豆瓣音乐
MessiGo31 分钟前
Python 爬虫实战 | 国家医保
python