介绍:本期介绍Lightning的迁移学习
一、使用预训练的LightningModule
使用AutoEncoder作为特征提取器,同时其也作为模型的一部分
python
class Encoder(torch.nn.Module):
...
class AutoEncoder(LightningModule):
def __init__(self):
self.encoder = Encoder()
self.decoder = Decoder()
class CIFAR10Classifier(LightningModule):
def __init__(self):
# 初始化预训练权重
self.feature_extractor = AutoEncoder.load_from_checkpoint(PATH)
self.feature_extractor.freeze()
# 输出是CIFAR10分类
self.classifier = nn.Linear(100, 10)
def forward(self, x):
representations = self.feature_extractor(x)
x = self.classifier(representations)
...
通过上述方法来实现迁移学习
栗子1:ImageNet(计算机视觉)
python
import torchvision.models as models
class ImagenetTransferLearning(LightningModule):
def __init__(self):
super().__init__()
# 初始化一个预训练好的resnet50
backbone = models.resnet50(weights="DEFAULT")
num_filters = backbone.fc.in_features
layers = list(backbone.children())[:-1]
self.feature_extractor = nn.Sequential(*layers)
# 使用预训练模型对CIFAR10进行分类,用的是ImageNet的权重
num_target_classes = 10
self.classifier = nn.Linear(num_filters, num_target_classes)
def forward(self, x):
self.feature_extractor.eval()
with torch.no_grad():
representations = self.feature_extractor(x).flatten(1)
x = self.classifier(representations)
...
Finetune(微调),进行训练
python
model = ImagenetTransferLearning()
trainer = Trainer()
trainer.fit(model)
进行预测
python
model = ImagenetTransferLearning.load_from_checkpoint(PATH)
model.freeze()
x = some_images_from_cifar10()
predictions = model(x)
imagenet的预训练模型,在CIFAR10上进行微调,以在CIFAR10上进行预测。在非学术领域,一般会对小数据集进行微调,并对数据集进行预测。一个意思。
栗子2:BERT(自然语言处理)
推荐一个transformer的git:hugging face
python
class BertMNLIFinetuner(LightningModule):
def __init__(self):
super().__init__()
self.bert = BertModel.from_pretrained("bert-base-cased", output_attentions=True)
self.W = nn.Linear(bert.config.hidden_size, 3)
self.num_classes = 3
def forward(self, input_ids, attention_mask, token_type_ids):
h, _, attn = self.bert(input_ids=input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids)
h_cls = h[:, 0]
logits = self.W(h_cls)
return logits, attn