[PyTorch][chapter 52][迁移学习]

前言:

迁移学习(Transfer Learning)是一种机器学习方法,它通过将一个领域中的知识和经验迁移到另一个相关领域中,来加速和改进新领域的学习和解决问题的能力。

这里面主要结合前面ResNet18 例子,详细讲解一下迁移学习的流程


一 简介

迁移学习可以通过以下几种方式实现:

1.1 基于预训练模型的迁移:

将已经在大规模数据集上预训练好的模型(如BERT、GPT等)作为一个通用的特征提取器,然后在新领域的任务上进行微调。

1.2 网络结构迁移:

将在一个领域中训练好的模型的网络结构应用到另一个领域中,并在此基础上进行微调。

1.3 特征迁移:

将在一个领域中训练好的某些特征应用到另一个领域中,并在此基础上进行微调。

word2vec

1.4 参数迁移:

将在一个领域中训练好的模型的参数应用到另一个领域中,并在此基础上进行微调。

本文主要例子用的是 参数迁移


二 Flatten

作用:

输入的向量x [batch, c, w, h]=>[batch, c*w*h]

复制代码
# -*- coding: utf-8 -*-
"""
Created on Wed Aug 16 15:11:35 2023

@author: chengxf2
"""

import torch
from torch import optim,nn

class Flatten(nn.Module):
    
    def __init__(self):
        
        super(Flatten,self).__init__()
        
    
    def forward(self, x):
        
        a = torch.tensor(x.shape[1:])
        #dim 中 input 张量的每一行的乘积。
        shape = torch.prod(a).item()
        #print("\n ---new shape--- ",shape)
        return x.view(-1,shape)

三 迁移学习

torchvision 已经提供好了一些分类器 resnet18,resnet152, 利用其训练好的参数,把最后的分类类型更改掉。

from torchvision.models import resnet152

from torchvision.models import resnet18

注意:

现有分类器分类的类型 > = 新分类器类型,再做transfer.

才能取得好的效果.

|-------|------------|
| 分类器 | 分类类型 |
| 已有分类器 | [猫,狗,鸡,鸭】 |
| 新分类器 | [猫,狗] |

复制代码
# -*- coding: utf-8 -*-
"""
Created on Wed Aug 16 14:56:35 2023

@author: chengxf2
"""

# -*- coding: utf-8 -*-
"""
Created on Tue Aug 15 15:38:18 2023

@author: chengxf2
"""

import torch
from torch import optim,nn
import visdom
from torch.utils.data import DataLoader
from PokeDataset import Pokemon
from torchvision.models import resnet152
from torchvision.models import resnet18

from util import Flatten

batchNum = 32
lr = 1e-3
epochs = 20
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
torch.manual_seed(1234)

root ='pokemon'
resize =224

csvfile ='data.csv'
train_db = Pokemon(root, resize, 'train',csvfile)
val_db = Pokemon(root, resize, 'val',csvfile)
test_db = Pokemon(root, resize, 'test',csvfile)

train_loader = DataLoader(train_db, batch_size =batchNum,shuffle= True,num_workers=4)
val_loader = DataLoader(val_db, batch_size =batchNum,shuffle= True,num_workers=2)
test_loader = DataLoader(test_db, batch_size =batchNum,shuffle= True,num_workers=2)
viz = visdom.Visdom()

def evalute(model, loader):
    
    total =len(loader.dataset)
    correct =0
    for x,y in loader:
        
        x = x.to(device)
        y = y.to(device)
        
        with torch.no_grad():
            
            logits = model(x)
            pred = logits.argmax(dim=1)
            correct += torch.eq(pred, y).sum().float().item()
    
    acc = correct/total
    
    return acc   
        
        

def main():
    
    trained_model = resnet152(pretrained=True)
    
    model = nn.Sequential(*list(trained_model.children())[:-1],
        Flatten(),
        nn.Linear(in_features=2048, out_features=5))
    
   
    
    optimizer = optim.Adam(model.parameters(),lr =lr) 
    criteon = nn.CrossEntropyLoss()
    
    best_epoch=0,
    best_acc=0
    viz.line([0],[-1],win='train_loss',opts =dict(title='train loss'))
    viz.line([0],[-1],win='val_loss',  opts =dict(title='val_acc'))
    global_step =0
    
    
  
    for epoch in range(epochs):
        print("\n --main---: ",epoch)
        for step, (x,y) in enumerate(train_loader):
            #x:[b,3,224,224] y:[b]

             x = x.to(device)
             y = y.to(device)
             #print("\n --x---: ",x.shape)
             
             logits =model(x)
             loss = criteon(logits, y)
             #print("\n --loss---: ",loss.shape)
             optimizer.zero_grad()
             loss.backward()
             optimizer.step()
             
             viz.line(Y=[loss.item()],X=[global_step],win='train_loss',update='append')
             global_step +=1
             
        if epoch %2 ==0:
            
             val_acc = evalute(model, val_loader)
             
             if val_acc>best_acc:
                 best_acc = val_acc
                 best_epoch =epoch
                 torch.save(model.state_dict(),'best.mdl')
             print("\n val_acc ",val_acc)
             viz.line([val_acc],[global_step],win='val_loss',update='append')
             
    print('\n best acc',best_acc, "best_epoch: ",best_epoch)
    
    model.load_state_dict(torch.load('best.mdl'))
    print('loaded from ckpt')
    
    test_acc = evalute(model, test_loader)
    print('\n test acc',test_acc)
                 

if __name__ == "__main__":
    
    main()

参考:

https://blog.csdn.net/qq_44089890/article/details/130460700

课时107 迁移学习实战_哔哩哔哩_bilibili

相关推荐
Echo``几秒前
1:OpenCV—图像基础
c++·图像处理·人工智能·opencv·算法·计算机视觉·视觉检测
FL171713142 分钟前
MATLAB机器人系统工具箱中的loadrobot和importrobot
人工智能·matlab·机器人
夏天是冰红茶19 分钟前
图像处理:预览并绘制图像细节
图像处理·人工智能·opencv
点云SLAM35 分钟前
Python中in和is关键字详解和使用
开发语言·人工智能·python·python学习·in和is关键字·python中for循环
后知后觉39 分钟前
深度学习-最简单的Demo-直接运行
人工智能·深度学习
说私域43 分钟前
基于开源链动2+1模式AI智能名片S2B2C商城小程序的低集中度市场运营策略研究
人工智能·小程序·开源·零售
COOCC144 分钟前
激活函数全解析:定义、分类与 17 种常用函数详解
人工智能·深度学习·神经网络·算法·机器学习·计算机视觉·自然语言处理
武子康1 小时前
大语言模型 09 - 从0开始训练GPT 0.25B参数量 补充知识之数据集 Pretrain SFT RLHF
人工智能·gpt·ai·语言模型·自然语言处理
davysiao1 小时前
AG-UI 协议:重构多模态交互,开启智能应用新纪元
人工智能
沃洛德.辛肯1 小时前
PyTorch 的 F.scaled_dot_product_attention 返回Nan
人工智能·pytorch·python