[PyTorch][chapter 52][迁移学习]

前言:

迁移学习(Transfer Learning)是一种机器学习方法,它通过将一个领域中的知识和经验迁移到另一个相关领域中,来加速和改进新领域的学习和解决问题的能力。

这里面主要结合前面ResNet18 例子,详细讲解一下迁移学习的流程


一 简介

迁移学习可以通过以下几种方式实现:

1.1 基于预训练模型的迁移:

将已经在大规模数据集上预训练好的模型(如BERT、GPT等)作为一个通用的特征提取器,然后在新领域的任务上进行微调。

1.2 网络结构迁移:

将在一个领域中训练好的模型的网络结构应用到另一个领域中,并在此基础上进行微调。

1.3 特征迁移:

将在一个领域中训练好的某些特征应用到另一个领域中,并在此基础上进行微调。

word2vec

1.4 参数迁移:

将在一个领域中训练好的模型的参数应用到另一个领域中,并在此基础上进行微调。

本文主要例子用的是 参数迁移


二 Flatten

作用:

输入的向量x [batch, c, w, h]=>[batch, c*w*h]

复制代码
# -*- coding: utf-8 -*-
"""
Created on Wed Aug 16 15:11:35 2023

@author: chengxf2
"""

import torch
from torch import optim,nn

class Flatten(nn.Module):
    
    def __init__(self):
        
        super(Flatten,self).__init__()
        
    
    def forward(self, x):
        
        a = torch.tensor(x.shape[1:])
        #dim 中 input 张量的每一行的乘积。
        shape = torch.prod(a).item()
        #print("\n ---new shape--- ",shape)
        return x.view(-1,shape)

三 迁移学习

torchvision 已经提供好了一些分类器 resnet18,resnet152, 利用其训练好的参数,把最后的分类类型更改掉。

from torchvision.models import resnet152

from torchvision.models import resnet18

注意:

现有分类器分类的类型 > = 新分类器类型,再做transfer.

才能取得好的效果.

|-------|------------|
| 分类器 | 分类类型 |
| 已有分类器 | [猫,狗,鸡,鸭】 |
| 新分类器 | [猫,狗] |

复制代码
# -*- coding: utf-8 -*-
"""
Created on Wed Aug 16 14:56:35 2023

@author: chengxf2
"""

# -*- coding: utf-8 -*-
"""
Created on Tue Aug 15 15:38:18 2023

@author: chengxf2
"""

import torch
from torch import optim,nn
import visdom
from torch.utils.data import DataLoader
from PokeDataset import Pokemon
from torchvision.models import resnet152
from torchvision.models import resnet18

from util import Flatten

batchNum = 32
lr = 1e-3
epochs = 20
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
torch.manual_seed(1234)

root ='pokemon'
resize =224

csvfile ='data.csv'
train_db = Pokemon(root, resize, 'train',csvfile)
val_db = Pokemon(root, resize, 'val',csvfile)
test_db = Pokemon(root, resize, 'test',csvfile)

train_loader = DataLoader(train_db, batch_size =batchNum,shuffle= True,num_workers=4)
val_loader = DataLoader(val_db, batch_size =batchNum,shuffle= True,num_workers=2)
test_loader = DataLoader(test_db, batch_size =batchNum,shuffle= True,num_workers=2)
viz = visdom.Visdom()

def evalute(model, loader):
    
    total =len(loader.dataset)
    correct =0
    for x,y in loader:
        
        x = x.to(device)
        y = y.to(device)
        
        with torch.no_grad():
            
            logits = model(x)
            pred = logits.argmax(dim=1)
            correct += torch.eq(pred, y).sum().float().item()
    
    acc = correct/total
    
    return acc   
        
        

def main():
    
    trained_model = resnet152(pretrained=True)
    
    model = nn.Sequential(*list(trained_model.children())[:-1],
        Flatten(),
        nn.Linear(in_features=2048, out_features=5))
    
   
    
    optimizer = optim.Adam(model.parameters(),lr =lr) 
    criteon = nn.CrossEntropyLoss()
    
    best_epoch=0,
    best_acc=0
    viz.line([0],[-1],win='train_loss',opts =dict(title='train loss'))
    viz.line([0],[-1],win='val_loss',  opts =dict(title='val_acc'))
    global_step =0
    
    
  
    for epoch in range(epochs):
        print("\n --main---: ",epoch)
        for step, (x,y) in enumerate(train_loader):
            #x:[b,3,224,224] y:[b]

             x = x.to(device)
             y = y.to(device)
             #print("\n --x---: ",x.shape)
             
             logits =model(x)
             loss = criteon(logits, y)
             #print("\n --loss---: ",loss.shape)
             optimizer.zero_grad()
             loss.backward()
             optimizer.step()
             
             viz.line(Y=[loss.item()],X=[global_step],win='train_loss',update='append')
             global_step +=1
             
        if epoch %2 ==0:
            
             val_acc = evalute(model, val_loader)
             
             if val_acc>best_acc:
                 best_acc = val_acc
                 best_epoch =epoch
                 torch.save(model.state_dict(),'best.mdl')
             print("\n val_acc ",val_acc)
             viz.line([val_acc],[global_step],win='val_loss',update='append')
             
    print('\n best acc',best_acc, "best_epoch: ",best_epoch)
    
    model.load_state_dict(torch.load('best.mdl'))
    print('loaded from ckpt')
    
    test_acc = evalute(model, test_loader)
    print('\n test acc',test_acc)
                 

if __name__ == "__main__":
    
    main()

参考:

https://blog.csdn.net/qq_44089890/article/details/130460700

课时107 迁移学习实战_哔哩哔哩_bilibili

相关推荐
诸葛务农7 分钟前
光刻胶性能核心参数:迪尔参数(A、B、C)
人工智能·材料工程
大千AI助手10 分钟前
Householder变换:线性代数中的镜像反射器
人工智能·线性代数·算法·决策树·机器学习·qr分解·householder算法
许泽宇的技术分享19 分钟前
当 AI Agent 遇上 MCP:微软 Agent Framework 的“瑞士军刀“式扩展之道
人工智能·microsoft
沉迷单车的追风少年21 分钟前
Diffusion Model与视频超分(2):解读字节开源视频增强模型SeedVR2
人工智能·深度学习·aigc·音视频·强化学习·视频生成·视频超分
Victory_orsh40 分钟前
“自然搞懂”深度学习系列(基于Pytorch架构)——03渐入佳境
人工智能·pytorch·深度学习
Fuly10241 小时前
AI 大模型应用中的图像,视频,音频的处理
人工智能·音视频
掘金安东尼1 小时前
Cursor 2.0 转向多智能体 AI 编程,并发布 Composer 模型
人工智能
Small___ming1 小时前
【人工智能数学基础】如何理解方差与协方差?
人工智能·概率论
好家伙VCC1 小时前
**发散创新:AI绘画编程探索与实践**随着人工智能技术的飞速发展,AI绘
java·人工智能·python·ai作画
兔兔爱学习兔兔爱学习1 小时前
2025年语音识别(ASR)与语音合成(TTS)技术趋势分析对比
人工智能·语音识别