CLIP: Contrastive LanguageImage Pre-training
-
paper arxiv.org/abs/2103.00...
Abstract
- 连接文本和图像的预训练模型
Contributions
- zero-shot classification
Related Work
- ConVIRT
- Contrastive learning
Methodology
Overview
(1) Contrastive pre-training
- 模型架构分为两部分,图像编码器和文本编码器,图像编码器可以是比如 resnet50,然后文本编码器可以是 transformer。
- 训练数据是网络社交媒体上搜集的图像文本对。在训练阶段,对于一个batch 的数据,首先通过文本编码器和图像编码器,得到文本和图像的特征,接着将所有的文本和图像特征分别计算内积,就能得到一个矩阵,然后从图像的角度看,行方向就是一个分类器,从文本角度看,列方向也是一个分类器。
- 而由于我们已经知道一个batch中的文本和图像的匹配关系,所以目标函数就是最大化同一对图像和文本特征的内积,也就是矩阵对角线上的元素,而最小化与不相关特征的内积。文章的作者从社交媒体上搜集了有大约4亿对的数据。
(2) (3) Downstream inference
在下游任务测试时,有两种使用CLIP的方法。
第一种,利用文本prompt进行预测,将预测的embedding同类别的embedding进行相似度匹配,实现分类任务;在测试阶段,可以直接将训练好的CLIP用于其他数据集而不需要finetune。和训练阶段类似,首先将需要分类的图像经过编码器得到特征,然后对于目标任务数据集的每一个标签,或者你自己定义的标签,都构造一段对应的文本,如上图中的 dog 会改造成 "A photo of a dog",以此类推 。然后经过编码器得到文本和图像特征,接着将文本特征与图像特征做内积,内积最大对应的标签就是图像的分类结果。这就完成了目标任务上的 zero-shot 分类。
python
"Ref:https://github.com/openai/CLIP"
import os
import clip
import torch
from torchvision.datasets import CIFAR100
# Load the model
device = "cuda" if torch.cuda.is_available() else "cpu"
model, preprocess = clip.load('ViT-B/32', device)
# Download the dataset
cifar100 = CIFAR100(root=os.path.expanduser("~/.cache"), download=True, train=False)
# Prepare the inputs
image, class_id = cifar100[3637]
image_input = preprocess(image).unsqueeze(0).to(device)
text_inputs = torch.cat([clip.tokenize(f"a photo of a {c}") for c in cifar100.classes]).to(device)
#cifar每个类别,输入图片,检索匹配的类别
# Calculate features
with torch.no_grad():
image_features = model.encode_image(image_input)
text_features = model.encode_text(text_inputs)
# Pick the top 5 most similar labels for the image
image_features /= image_features.norm(dim=-1, keepdim=True)
text_features /= text_features.norm(dim=-1, keepdim=True)
similarity = (100.0 * image_features @ text_features.T).softmax(dim=-1)
values, indices = similarity[0].topk(5)
# Print the result
print("\nTop predictions:\n")
for value, index in zip(values, indices):
print(f"{cifar100.classes[index]:>16s}: {100 * value.item():.2f}%")
"""
Top predictions:
snake: 65.31%
turtle: 12.29%
sweet_pepper: 3.83%
lizard: 1.88%
crocodile: 1.75%
"""
第二种,额外训练linear probe进行预测。通过CLIP的image_encoder得到视觉向量,结合标签做Logistic Regression
python
"Ref:https://github.com/openai/CLIP"
import os
import clip
import torch
import numpy as np
from sklearn.linear_model import LogisticRegression
from torch.utils.data import DataLoader
from torchvision.datasets import CIFAR100
from tqdm import tqdm
# Load the model
device = "cuda" if torch.cuda.is_available() else "cpu"
model, preprocess = clip.load('ViT-B/32', device)
# Load the dataset
root = os.path.expanduser("~/.cache")
train = CIFAR100(root, download=True, train=True, transform=preprocess)
test = CIFAR100(root, download=True, train=False, transform=preprocess)
def get_features(dataset):
all_features = []
all_labels = []
with torch.no_grad():
for images, labels in tqdm(DataLoader(dataset, batch_size=100)):
features = model.encode_image(images.to(device))
all_features.append(features)
all_labels.append(labels)
return torch.cat(all_features).cpu().numpy(), torch.cat(all_labels).cpu().numpy()
# Calculate the image features
train_features, train_labels = get_features(train)
test_features, test_labels = get_features(test)
# Perform logistic regression
classifier = LogisticRegression(random_state=0, C=0.316, max_iter=1000, verbose=1) # c自定义
classifier.fit(train_features, train_labels)
# Evaluate using the logistic regression classifier
predictions = classifier.predict(test_features)
accuracy = np.mean((test_labels == predictions).astype(np.float)) * 100.
print(f"Accuracy = {accuracy:.3f}")
Experiments
Conclusions
- CLIP 可以说是开辟了 CV+NLP 的多模态表征学习新时代。后面谷歌的ALIGN,微软的Florence,商汤 DeCLIP,快手 EfficientCLIP 都是研究相类似的任务。虽然 CLIP 在小部分任务上 zero-shot 精度一般,但是 CLIP 在多模态的 Encoders 能提供简单而又强大的视觉先验的表征能力。
- CLIP和BERT、GPT、ViT的区别在于,CLIP是多模态的,包含图像处理以及文本处理两个方面内容,而BERT、GPT是单文本模态的,ViT是单图像模态的
Limitations
- 不是和SOTA的比较:以上的数据分析,都是和a linear classifier on top of ResNet-50 features进行比较,大部分的数据集,都有对应的SOTA模型。为了达到SOTA,zero-shot CLIP估计要提高1000x的算力,当前情况不支持;
- 在部分fine-grained分类上表现不佳: a. 前面实验分析发现,模型不能很好的区分cars,species of flowers, 以及variants of aircraft; b. abstract和systematic任务表现不好,比如统计图上object的数量; c. 在训练集中基本不会出现的比较novel的任务,表现欠佳,比如classifying the distance to the nearest car in a photo;
- 训练集中没有出现的图片类型(out-of-distribution),表现不好,比如OCR识别数字效果可以,但是MNIST的准确率只有88%;
References
-
Multimodality Helps Unimodality: Cross-Modal Few-Shot Learning with Multimodal Models