PyTorch Lightning快速学习教程三:迁移学习

介绍:本期介绍Lightning的迁移学习

一、使用预训练的LightningModule

使用AutoEncoder作为特征提取器,同时其也作为模型的一部分

python 复制代码
class Encoder(torch.nn.Module):
    ...

class AutoEncoder(LightningModule):
    def __init__(self):
        self.encoder = Encoder()
        self.decoder = Decoder()

class CIFAR10Classifier(LightningModule):
    def __init__(self):
        # 初始化预训练权重
        self.feature_extractor = AutoEncoder.load_from_checkpoint(PATH)
        self.feature_extractor.freeze()

        # 输出是CIFAR10分类
        self.classifier = nn.Linear(100, 10)

    def forward(self, x):
        representations = self.feature_extractor(x)
        x = self.classifier(representations)
        ...

通过上述方法来实现迁移学习

栗子1:ImageNet(计算机视觉)

python 复制代码
import torchvision.models as models

class ImagenetTransferLearning(LightningModule):
    def __init__(self):
        super().__init__()

        # 初始化一个预训练好的resnet50
        backbone = models.resnet50(weights="DEFAULT")
        num_filters = backbone.fc.in_features
        layers = list(backbone.children())[:-1]
        self.feature_extractor = nn.Sequential(*layers)

        # 使用预训练模型对CIFAR10进行分类,用的是ImageNet的权重
        num_target_classes = 10
        self.classifier = nn.Linear(num_filters, num_target_classes)

    def forward(self, x):
        self.feature_extractor.eval()
        with torch.no_grad():
            representations = self.feature_extractor(x).flatten(1)
        x = self.classifier(representations)
        ...

Finetune(微调),进行训练

python 复制代码
model = ImagenetTransferLearning()
trainer = Trainer()
trainer.fit(model)

进行预测

python 复制代码
model = ImagenetTransferLearning.load_from_checkpoint(PATH)
model.freeze()

x = some_images_from_cifar10()
predictions = model(x)

imagenet的预训练模型,在CIFAR10上进行微调,以在CIFAR10上进行预测。在非学术领域,一般会对小数据集进行微调,并对数据集进行预测。一个意思。

栗子2:BERT(自然语言处理)

推荐一个transformer的git:hugging face

python 复制代码
class BertMNLIFinetuner(LightningModule):
    def __init__(self):
        super().__init__()

        self.bert = BertModel.from_pretrained("bert-base-cased", output_attentions=True)
        self.W = nn.Linear(bert.config.hidden_size, 3)
        self.num_classes = 3

    def forward(self, input_ids, attention_mask, token_type_ids):
        h, _, attn = self.bert(input_ids=input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids)

        h_cls = h[:, 0]
        logits = self.W(h_cls)
        return logits, attn
相关推荐
taoqick2 小时前
对PosWiseFFN的改进: MoE、PKM、UltraMem
人工智能·pytorch·深度学习
CSDN_PBB3 小时前
[STM32 - 野火] - - - 固件库学习笔记 - - - 十五.设置FLASH的读写保护及解除
笔记·stm32·学习
小喵要摸鱼6 小时前
【Pytorch 库】自定义数据集相关的类
pytorch·python
鸡啄米的时光机7 小时前
vscode的一些实用操作
vscode·学习
Kai HVZ7 小时前
《深度学习》——调整学习率和保存使用最优模型
人工智能·深度学习·学习
守护者1708 小时前
JAVA学习-练习试用Java实现“使用Apache Ignite对大数据进行内存计算和快速筛查”
java·学习
懒大王今天不写代码9 小时前
为什么Pytorch中实例化模型会直接调用forward方法?
人工智能·pytorch·python
weixin_5025398510 小时前
rust学习笔记2-rust的包管理工具Cargo使用
笔记·学习·rust
Francek Chen11 小时前
【现代深度学习技术】卷积神经网络 | 从全连接层到卷积
人工智能·pytorch·深度学习·神经网络·cnn
web_1553427465612 小时前
【合集】Java进阶——Java深入学习的笔记汇总 & 再论面向对象、数据结构和算法、JVM底层、多线程、类加载、
java·笔记·学习