6-2 pytorch中训练模型的3种方法

Pytorch通常需要用户编写自定义训练循环,训练循环的代码风格因人而异。(养成自己的习惯)

有3类典型的训练循环代码风格:脚本形式训练循环,函数形式训练循环,类形式训练循环。

下面以minist数据集的多分类模型的训练为例,演示这3种训练模型的风格。

其中类形式训练循环我们同时演示torchkeras.KerasModel和torchkeras.LightModel两种示范。

准备数据

python 复制代码
transform = transforms.Compose([transforms.ToTensor()])

ds_train = torchvision.datasets.MNIST(root="./data/mnist/",train=True,download=True,transform=transform)
ds_val = torchvision.datasets.MNIST(root="./data/mnist/",train=False,download=True,transform=transform)

dl_train =  torch.utils.data.DataLoader(ds_train, batch_size=128, shuffle=True, num_workers=4)
dl_val =  torch.utils.data.DataLoader(ds_val, batch_size=128, shuffle=False, num_workers=4)

print(len(ds_train))
print(len(ds_val))
python 复制代码
%matplotlib inline
%config InlineBackend.figure_format = 'svg'

#查看部分样本
from matplotlib import pyplot as plt 

plt.figure(figsize=(8,8)) 
for i in range(9):
    img,label = ds_train[i] 
    img = torch.squeeze(img) # 删除为1的维度
    ax=plt.subplot(3,3,i+1)
    ax.imshow(img.numpy())
    ax.set_title("label = %d"%label)
    ax.set_xticks([])
    ax.set_yticks([]) 
plt.show()

一、脚本风格

脚本风格的训练循环非常常见。

python 复制代码
net = nn.Sequential()
net.add_module("conv1",nn.Conv2d(in_channels=1,out_channels=32,kernel_size = 3))
net.add_module("pool1",nn.MaxPool2d(kernel_size = 2,stride = 2))
net.add_module("conv2",nn.Conv2d(in_channels=32,out_channels=64,kernel_size = 5))
net.add_module("pool2",nn.MaxPool2d(kernel_size = 2,stride = 2))
net.add_module("dropout",nn.Dropout2d(p = 0.1))
net.add_module("adaptive_pool",nn.AdaptiveMaxPool2d((1,1)))
net.add_module("flatten",nn.Flatten())
net.add_module("linear1",nn.Linear(64,32))
net.add_module("relu",nn.ReLU())
net.add_module("linear2",nn.Linear(32,10))

print(net)

代码量较多,可以查看最下方链接对应的notebook。

二、函数风格

该风格在脚本形式上做了进一步的函数封装。

python 复制代码
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.layers = nn.ModuleList([
            nn.Conv2d(in_channels=1,out_channels=32,kernel_size = 3),
            nn.MaxPool2d(kernel_size = 2,stride = 2),
            nn.Conv2d(in_channels=32,out_channels=64,kernel_size = 5),
            nn.MaxPool2d(kernel_size = 2,stride = 2),
            nn.Dropout2d(p = 0.1),
            nn.AdaptiveMaxPool2d((1,1)),
            nn.Flatten(),
            nn.Linear(64,32),
            nn.ReLU(),
            nn.Linear(32,10)]
        )
    def forward(self,x):
        for layer in self.layers:
            x = layer(x)
        return x
net = Net()
print(net)

代码量较多,可以查看最下方链接对应的notebook。

三、类风格

此处使用**torchkeras.KerasModel(其源码其实就是脚本风格中的代码)**高层次API接口中的fit方法训练模型。

使用该形式训练模型非常简洁明了。

先构建模型,同一二。

python 复制代码
from torchkeras import KerasModel 

class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.layers = nn.ModuleList([
            nn.Conv2d(in_channels=1,out_channels=32,kernel_size = 3),
            nn.MaxPool2d(kernel_size = 2,stride = 2),
            nn.Conv2d(in_channels=32,out_channels=64,kernel_size = 5),
            nn.MaxPool2d(kernel_size = 2,stride = 2),
            nn.Dropout2d(p = 0.1),
            nn.AdaptiveMaxPool2d((1,1)),
            nn.Flatten(),
            nn.Linear(64,32),
            nn.ReLU(),
            nn.Linear(32,10)]
        )
    def forward(self,x):
        for layer in self.layers:
            x = layer(x)
        return x
    
net = Net() 

print(net)

使用kerasModel:

python 复制代码
from torchmetrics import Accuracy

model = KerasModel(net,
                   loss_fn=nn.CrossEntropyLoss(),
                   metrics_dict = {"acc":Accuracy(task='multiclass',num_classes=10)},
                   optimizer = torch.optim.Adam(net.parameters(),lr = 0.01)  )

model.fit(
    train_data = dl_train,
    val_data= dl_val,
    epochs=10,
    patience=3,
    monitor="val_acc", 
    mode="max",
    plot=True,
    cpu=True
)

训练过程:

其实编码训练代码按照自己的习惯即可,不必要按照以上三种方式。

参考:https://github.com/lyhue1991/eat_pytorch_in_20_days

相关推荐
ai大师3 分钟前
(附代码及图示)Multi-Query 多查询策略详解
python·langchain·中转api·apikey·中转apikey·免费apikey·claude4
海盗儿12 分钟前
Attention Is All You Need (Transformer) 以及Transformer pytorch实现
pytorch·深度学习·transformer
GIS小天24 分钟前
AI+预测3D新模型百十个定位预测+胆码预测+去和尾2025年6月7日第101弹
人工智能·算法·机器学习·彩票
小小爬虾24 分钟前
关于datetime获取时间的问题
python
阿部多瑞 ABU33 分钟前
主流大语言模型安全性测试(三):阿拉伯语越狱提示词下的表现与分析
人工智能·安全·ai·语言模型·安全性测试
cnbestec40 分钟前
Xela矩阵三轴触觉传感器的工作原理解析与应用场景
人工智能·线性代数·触觉传感器
不爱写代码的玉子1 小时前
HALCON透视矩阵
人工智能·深度学习·线性代数·算法·计算机视觉·矩阵·c#
sbc-study1 小时前
PCDF (Progressive Continuous Discrimination Filter)模块构建
人工智能·深度学习·计算机视觉
EasonZzzzzzz1 小时前
计算机视觉——相机标定
人工智能·数码相机·计算机视觉
猿小猴子1 小时前
主流 AI IDE 之一的 Cursor 介绍
ide·人工智能·cursor