CLIP | 文图连接预训练模型

CLIP: Contrastive LanguageImage Pre-training



Abstract

  • 连接文本和图像的预训练模型

Contributions

  • zero-shot classification
  • ConVIRT
  • Contrastive learning

Methodology

Overview

(1) Contrastive pre-training

  • 模型架构分为两部分,图像编码器和文本编码器,图像编码器可以是比如 resnet50,然后文本编码器可以是 transformer。
  • 训练数据是网络社交媒体上搜集的图像文本对。在训练阶段,对于一个batch 的数据,首先通过文本编码器和图像编码器,得到文本和图像的特征,接着将所有的文本和图像特征分别计算内积,就能得到一个矩阵,然后从图像的角度看,行方向就是一个分类器,从文本角度看,列方向也是一个分类器。
  • 而由于我们已经知道一个batch中的文本和图像的匹配关系,所以目标函数就是最大化同一对图像和文本特征的内积,也就是矩阵对角线上的元素,而最小化与不相关特征的内积。文章的作者从社交媒体上搜集了有大约4亿对的数据。

(2) (3) Downstream inference

在下游任务测试时,有两种使用CLIP的方法。

​ 第一种,利用文本prompt进行预测,将预测的embedding同类别的embedding进行相似度匹配,实现分类任务;在测试阶段,可以直接将训练好的CLIP用于其他数据集而不需要finetune。和训练阶段类似,首先将需要分类的图像经过编码器得到特征,然后对于目标任务数据集的每一个标签,或者你自己定义的标签,都构造一段对应的文本,如上图中的 dog 会改造成 "A photo of a dog",以此类推 。然后经过编码器得到文本和图像特征,接着将文本特征与图像特征做内积,内积最大对应的标签就是图像的分类结果。这就完成了目标任务上的 zero-shot 分类。

python 复制代码
"Ref:https://github.com/openai/CLIP"
import os
import clip
import torch
from torchvision.datasets import CIFAR100

# Load the model
device = "cuda" if torch.cuda.is_available() else "cpu"
model, preprocess = clip.load('ViT-B/32', device)

# Download the dataset
cifar100 = CIFAR100(root=os.path.expanduser("~/.cache"), download=True, train=False)

# Prepare the inputs
image, class_id = cifar100[3637]
image_input = preprocess(image).unsqueeze(0).to(device)
text_inputs = torch.cat([clip.tokenize(f"a photo of a {c}") for c in cifar100.classes]).to(device)
#cifar每个类别,输入图片,检索匹配的类别

# Calculate features
with torch.no_grad():
    image_features = model.encode_image(image_input)
    text_features = model.encode_text(text_inputs)

# Pick the top 5 most similar labels for the image
image_features /= image_features.norm(dim=-1, keepdim=True)
text_features /= text_features.norm(dim=-1, keepdim=True)
similarity = (100.0 * image_features @ text_features.T).softmax(dim=-1)
values, indices = similarity[0].topk(5)

# Print the result
print("\nTop predictions:\n")
for value, index in zip(values, indices):
    print(f"{cifar100.classes[index]:>16s}: {100 * value.item():.2f}%")

"""
Top predictions:
           snake: 65.31%
          turtle: 12.29%
    sweet_pepper: 3.83%
          lizard: 1.88%
       crocodile: 1.75%
"""

​ 第二种,额外训练linear probe进行预测。通过CLIP的image_encoder得到视觉向量,结合标签做Logistic Regression

python 复制代码
"Ref:https://github.com/openai/CLIP"
import os
import clip
import torch

import numpy as np
from sklearn.linear_model import LogisticRegression
from torch.utils.data import DataLoader
from torchvision.datasets import CIFAR100
from tqdm import tqdm

# Load the model
device = "cuda" if torch.cuda.is_available() else "cpu"
model, preprocess = clip.load('ViT-B/32', device)

# Load the dataset
root = os.path.expanduser("~/.cache")
train = CIFAR100(root, download=True, train=True, transform=preprocess)
test = CIFAR100(root, download=True, train=False, transform=preprocess)

def get_features(dataset):
    all_features = []
    all_labels = []
    
    with torch.no_grad():
        for images, labels in tqdm(DataLoader(dataset, batch_size=100)):
            features = model.encode_image(images.to(device))

            all_features.append(features)
            all_labels.append(labels)

    return torch.cat(all_features).cpu().numpy(), torch.cat(all_labels).cpu().numpy()

# Calculate the image features
train_features, train_labels = get_features(train)
test_features, test_labels = get_features(test)

# Perform logistic regression
classifier = LogisticRegression(random_state=0, C=0.316, max_iter=1000, verbose=1) # c自定义
classifier.fit(train_features, train_labels)

# Evaluate using the logistic regression classifier
predictions = classifier.predict(test_features)
accuracy = np.mean((test_labels == predictions).astype(np.float)) * 100.
print(f"Accuracy = {accuracy:.3f}")

Experiments

Conclusions

  • CLIP 可以说是开辟了 CV+NLP 的多模态表征学习新时代。后面谷歌的ALIGN,微软的Florence,商汤 DeCLIP,快手 EfficientCLIP 都是研究相类似的任务。虽然 CLIP 在小部分任务上 zero-shot 精度一般,但是 CLIP 在多模态的 Encoders 能提供简单而又强大的视觉先验的表征能力。
  • CLIP和BERT、GPT、ViT的区别在于,CLIP是多模态的,包含图像处理以及文本处理两个方面内容,而BERT、GPT是单文本模态的,ViT是单图像模态的

Limitations

  • 不是和SOTA的比较:以上的数据分析,都是和a linear classifier on top of ResNet-50 features进行比较,大部分的数据集,都有对应的SOTA模型。为了达到SOTA,zero-shot CLIP估计要提高1000x的算力,当前情况不支持;
  • 在部分fine-grained分类上表现不佳: a. 前面实验分析发现,模型不能很好的区分cars,species of flowers, 以及variants of aircraft; b. abstract和systematic任务表现不好,比如统计图上object的数量; c. 在训练集中基本不会出现的比较novel的任务,表现欠佳,比如classifying the distance to the nearest car in a photo;
  • 训练集中没有出现的图片类型(out-of-distribution),表现不好,比如OCR识别数字效果可以,但是MNIST的准确率只有88%;

References

相关推荐
Captain823Jack1 小时前
nlp新词发现——浅析 TF·IDF
人工智能·python·深度学习·神经网络·算法·自然语言处理
Captain823Jack2 小时前
w04_nlp大模型训练·中文分词
人工智能·python·深度学习·神经网络·算法·自然语言处理·中文分词
是小胡嘛3 小时前
数据结构之旅:红黑树如何驱动 Set 和 Map
数据结构·算法
m0_748255023 小时前
前端常用算法集合
前端·算法
呆呆的猫3 小时前
【LeetCode】227、基本计算器 II
算法·leetcode·职场和发展
Tisfy3 小时前
LeetCode 1705.吃苹果的最大数目:贪心(优先队列) - 清晰题解
算法·leetcode·优先队列·贪心·
余额不足121383 小时前
C语言基础十六:枚举、c语言中文件的读写操作
linux·c语言·算法
火星机器人life6 小时前
基于ceres优化的3d激光雷达开源算法
算法·3d
虽千万人 吾往矣6 小时前
golang LeetCode 热题 100(动态规划)-更新中
算法·leetcode·动态规划
arnold667 小时前
华为OD E卷(100分)34-转盘寿司
算法·华为od