PyTorch Lightning快速学习教程三:迁移学习

介绍:本期介绍Lightning的迁移学习

一、使用预训练的LightningModule

使用AutoEncoder作为特征提取器,同时其也作为模型的一部分

python 复制代码
class Encoder(torch.nn.Module):
    ...

class AutoEncoder(LightningModule):
    def __init__(self):
        self.encoder = Encoder()
        self.decoder = Decoder()

class CIFAR10Classifier(LightningModule):
    def __init__(self):
        # 初始化预训练权重
        self.feature_extractor = AutoEncoder.load_from_checkpoint(PATH)
        self.feature_extractor.freeze()

        # 输出是CIFAR10分类
        self.classifier = nn.Linear(100, 10)

    def forward(self, x):
        representations = self.feature_extractor(x)
        x = self.classifier(representations)
        ...

通过上述方法来实现迁移学习

栗子1:ImageNet(计算机视觉)

python 复制代码
import torchvision.models as models

class ImagenetTransferLearning(LightningModule):
    def __init__(self):
        super().__init__()

        # 初始化一个预训练好的resnet50
        backbone = models.resnet50(weights="DEFAULT")
        num_filters = backbone.fc.in_features
        layers = list(backbone.children())[:-1]
        self.feature_extractor = nn.Sequential(*layers)

        # 使用预训练模型对CIFAR10进行分类,用的是ImageNet的权重
        num_target_classes = 10
        self.classifier = nn.Linear(num_filters, num_target_classes)

    def forward(self, x):
        self.feature_extractor.eval()
        with torch.no_grad():
            representations = self.feature_extractor(x).flatten(1)
        x = self.classifier(representations)
        ...

Finetune(微调),进行训练

python 复制代码
model = ImagenetTransferLearning()
trainer = Trainer()
trainer.fit(model)

进行预测

python 复制代码
model = ImagenetTransferLearning.load_from_checkpoint(PATH)
model.freeze()

x = some_images_from_cifar10()
predictions = model(x)

imagenet的预训练模型,在CIFAR10上进行微调,以在CIFAR10上进行预测。在非学术领域,一般会对小数据集进行微调,并对数据集进行预测。一个意思。

栗子2:BERT(自然语言处理)

推荐一个transformer的git:hugging face

python 复制代码
class BertMNLIFinetuner(LightningModule):
    def __init__(self):
        super().__init__()

        self.bert = BertModel.from_pretrained("bert-base-cased", output_attentions=True)
        self.W = nn.Linear(bert.config.hidden_size, 3)
        self.num_classes = 3

    def forward(self, input_ids, attention_mask, token_type_ids):
        h, _, attn = self.bert(input_ids=input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids)

        h_cls = h[:, 0]
        logits = self.W(h_cls)
        return logits, attn
相关推荐
吴佳浩6 分钟前
什么是算力?
人工智能·pytorch·llm
数据智能老司机2 天前
PyTorch 深度学习——使用神经网络来拟合数据
pytorch·深度学习
数据智能老司机2 天前
PyTorch 深度学习——用于图像的扩散模型
pytorch·深度学习
数据智能老司机2 天前
PyTorch 深度学习——Transformer 是如何工作的
pytorch·深度学习
数据智能老司机3 天前
PyTorch 深度学习——使用张量表示真实世界数据
pytorch·深度学习
数据智能老司机3 天前
PyTorch 深度学习——它始于一个张量
pytorch·深度学习
Narrastory5 天前
明日香 - Pytorch 快速入门保姆级教程(三)
pytorch·深度学习
Narrastory8 天前
明日香 - Pytorch 快速入门保姆级教程(一)
人工智能·pytorch·深度学习
Narrastory8 天前
明日香 - Pytorch 快速入门保姆级教程(二)
人工智能·pytorch·深度学习
西岸行者13 天前
学习笔记:SKILLS 能帮助更好的vibe coding
笔记·学习