[PyTorch][chapter 52][迁移学习]

前言:

迁移学习(Transfer Learning)是一种机器学习方法,它通过将一个领域中的知识和经验迁移到另一个相关领域中,来加速和改进新领域的学习和解决问题的能力。

这里面主要结合前面ResNet18 例子,详细讲解一下迁移学习的流程


一 简介

迁移学习可以通过以下几种方式实现:

1.1 基于预训练模型的迁移:

将已经在大规模数据集上预训练好的模型(如BERT、GPT等)作为一个通用的特征提取器,然后在新领域的任务上进行微调。

1.2 网络结构迁移:

将在一个领域中训练好的模型的网络结构应用到另一个领域中,并在此基础上进行微调。

1.3 特征迁移:

将在一个领域中训练好的某些特征应用到另一个领域中,并在此基础上进行微调。

word2vec

1.4 参数迁移:

将在一个领域中训练好的模型的参数应用到另一个领域中,并在此基础上进行微调。

本文主要例子用的是 参数迁移


二 Flatten

作用:

输入的向量x [batch, c, w, h]=>[batch, c*w*h]

# -*- coding: utf-8 -*-
"""
Created on Wed Aug 16 15:11:35 2023

@author: chengxf2
"""

import torch
from torch import optim,nn

class Flatten(nn.Module):
    
    def __init__(self):
        
        super(Flatten,self).__init__()
        
    
    def forward(self, x):
        
        a = torch.tensor(x.shape[1:])
        #dim 中 input 张量的每一行的乘积。
        shape = torch.prod(a).item()
        #print("\n ---new shape--- ",shape)
        return x.view(-1,shape)

三 迁移学习

torchvision 已经提供好了一些分类器 resnet18,resnet152, 利用其训练好的参数,把最后的分类类型更改掉。

from torchvision.models import resnet152

from torchvision.models import resnet18

注意:

现有分类器分类的类型 > = 新分类器类型,再做transfer.

才能取得好的效果.

|-------|------------|
| 分类器 | 分类类型 |
| 已有分类器 | [猫,狗,鸡,鸭】 |
| 新分类器 | [猫,狗] |

# -*- coding: utf-8 -*-
"""
Created on Wed Aug 16 14:56:35 2023

@author: chengxf2
"""

# -*- coding: utf-8 -*-
"""
Created on Tue Aug 15 15:38:18 2023

@author: chengxf2
"""

import torch
from torch import optim,nn
import visdom
from torch.utils.data import DataLoader
from PokeDataset import Pokemon
from torchvision.models import resnet152
from torchvision.models import resnet18

from util import Flatten

batchNum = 32
lr = 1e-3
epochs = 20
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
torch.manual_seed(1234)

root ='pokemon'
resize =224

csvfile ='data.csv'
train_db = Pokemon(root, resize, 'train',csvfile)
val_db = Pokemon(root, resize, 'val',csvfile)
test_db = Pokemon(root, resize, 'test',csvfile)

train_loader = DataLoader(train_db, batch_size =batchNum,shuffle= True,num_workers=4)
val_loader = DataLoader(val_db, batch_size =batchNum,shuffle= True,num_workers=2)
test_loader = DataLoader(test_db, batch_size =batchNum,shuffle= True,num_workers=2)
viz = visdom.Visdom()

def evalute(model, loader):
    
    total =len(loader.dataset)
    correct =0
    for x,y in loader:
        
        x = x.to(device)
        y = y.to(device)
        
        with torch.no_grad():
            
            logits = model(x)
            pred = logits.argmax(dim=1)
            correct += torch.eq(pred, y).sum().float().item()
    
    acc = correct/total
    
    return acc   
        
        

def main():
    
    trained_model = resnet152(pretrained=True)
    
    model = nn.Sequential(*list(trained_model.children())[:-1],
        Flatten(),
        nn.Linear(in_features=2048, out_features=5))
    
   
    
    optimizer = optim.Adam(model.parameters(),lr =lr) 
    criteon = nn.CrossEntropyLoss()
    
    best_epoch=0,
    best_acc=0
    viz.line([0],[-1],win='train_loss',opts =dict(title='train loss'))
    viz.line([0],[-1],win='val_loss',  opts =dict(title='val_acc'))
    global_step =0
    
    
  
    for epoch in range(epochs):
        print("\n --main---: ",epoch)
        for step, (x,y) in enumerate(train_loader):
            #x:[b,3,224,224] y:[b]

             x = x.to(device)
             y = y.to(device)
             #print("\n --x---: ",x.shape)
             
             logits =model(x)
             loss = criteon(logits, y)
             #print("\n --loss---: ",loss.shape)
             optimizer.zero_grad()
             loss.backward()
             optimizer.step()
             
             viz.line(Y=[loss.item()],X=[global_step],win='train_loss',update='append')
             global_step +=1
             
        if epoch %2 ==0:
            
             val_acc = evalute(model, val_loader)
             
             if val_acc>best_acc:
                 best_acc = val_acc
                 best_epoch =epoch
                 torch.save(model.state_dict(),'best.mdl')
             print("\n val_acc ",val_acc)
             viz.line([val_acc],[global_step],win='val_loss',update='append')
             
    print('\n best acc',best_acc, "best_epoch: ",best_epoch)
    
    model.load_state_dict(torch.load('best.mdl'))
    print('loaded from ckpt')
    
    test_acc = evalute(model, test_loader)
    print('\n test acc',test_acc)
                 

if __name__ == "__main__":
    
    main()

参考:

https://blog.csdn.net/qq_44089890/article/details/130460700

课时107 迁移学习实战_哔哩哔哩_bilibili

相关推荐
车载诊断技术11 分钟前
基于新一代电子电器架构的SOA服务设计方法
人工智能·架构·汽车·计算机外设·ecu故障诊断指南
Luzem031913 分钟前
使用朴素贝叶斯对自定义数据集进行分类
人工智能·机器学习
小菜鸟博士14 分钟前
手撕Vision Transformer -- Day1 -- 基础原理
人工智能·深度学习·学习·算法·面试
找方案28 分钟前
智慧城市(城市大脑)建设方案
人工智能·智慧城市·城市大脑
老艾的AI世界34 分钟前
AI定制祝福视频,广州塔、动态彩灯、LED表白,直播互动新玩法(附下载链接)
图像处理·人工智能·深度学习·神经网络·目标检测·机器学习·ai·ai视频·ai视频生成·ai视频制作
灰灰老师1 小时前
数据分析系列--[11] RapidMiner,K-Means聚类分析(含数据集)
人工智能·算法·机器学习·数据挖掘·数据分析·kmeans·rapidminer
kyle~1 小时前
机器学习--概览
人工智能·机器学习
追求源于热爱!2 小时前
记4(可训练对象+自动求导机制+波士顿房价回归预测
图像处理·人工智能·算法·机器学习·回归
前端达人2 小时前
「AI学习笔记」深度学习进化史:从神经网络到“黑箱技术”(三)
人工智能·笔记·深度学习·神经网络·学习
AIGC大时代2 小时前
对比DeepSeek、ChatGPT和Kimi的学术写作撰写引言能力
数据库·论文阅读·人工智能·chatgpt·数据分析·prompt