CLIP | 文图连接预训练模型

CLIP: Contrastive LanguageImage Pre-training



Abstract

  • 连接文本和图像的预训练模型

Contributions

  • zero-shot classification
  • ConVIRT
  • Contrastive learning

Methodology

Overview

(1) Contrastive pre-training

  • 模型架构分为两部分,图像编码器和文本编码器,图像编码器可以是比如 resnet50,然后文本编码器可以是 transformer。
  • 训练数据是网络社交媒体上搜集的图像文本对。在训练阶段,对于一个batch 的数据,首先通过文本编码器和图像编码器,得到文本和图像的特征,接着将所有的文本和图像特征分别计算内积,就能得到一个矩阵,然后从图像的角度看,行方向就是一个分类器,从文本角度看,列方向也是一个分类器。
  • 而由于我们已经知道一个batch中的文本和图像的匹配关系,所以目标函数就是最大化同一对图像和文本特征的内积,也就是矩阵对角线上的元素,而最小化与不相关特征的内积。文章的作者从社交媒体上搜集了有大约4亿对的数据。

(2) (3) Downstream inference

在下游任务测试时,有两种使用CLIP的方法。

​ 第一种,利用文本prompt进行预测,将预测的embedding同类别的embedding进行相似度匹配,实现分类任务;在测试阶段,可以直接将训练好的CLIP用于其他数据集而不需要finetune。和训练阶段类似,首先将需要分类的图像经过编码器得到特征,然后对于目标任务数据集的每一个标签,或者你自己定义的标签,都构造一段对应的文本,如上图中的 dog 会改造成 "A photo of a dog",以此类推 。然后经过编码器得到文本和图像特征,接着将文本特征与图像特征做内积,内积最大对应的标签就是图像的分类结果。这就完成了目标任务上的 zero-shot 分类。

python 复制代码
"Ref:https://github.com/openai/CLIP"
import os
import clip
import torch
from torchvision.datasets import CIFAR100

# Load the model
device = "cuda" if torch.cuda.is_available() else "cpu"
model, preprocess = clip.load('ViT-B/32', device)

# Download the dataset
cifar100 = CIFAR100(root=os.path.expanduser("~/.cache"), download=True, train=False)

# Prepare the inputs
image, class_id = cifar100[3637]
image_input = preprocess(image).unsqueeze(0).to(device)
text_inputs = torch.cat([clip.tokenize(f"a photo of a {c}") for c in cifar100.classes]).to(device)
#cifar每个类别,输入图片,检索匹配的类别

# Calculate features
with torch.no_grad():
    image_features = model.encode_image(image_input)
    text_features = model.encode_text(text_inputs)

# Pick the top 5 most similar labels for the image
image_features /= image_features.norm(dim=-1, keepdim=True)
text_features /= text_features.norm(dim=-1, keepdim=True)
similarity = (100.0 * image_features @ text_features.T).softmax(dim=-1)
values, indices = similarity[0].topk(5)

# Print the result
print("\nTop predictions:\n")
for value, index in zip(values, indices):
    print(f"{cifar100.classes[index]:>16s}: {100 * value.item():.2f}%")

"""
Top predictions:
           snake: 65.31%
          turtle: 12.29%
    sweet_pepper: 3.83%
          lizard: 1.88%
       crocodile: 1.75%
"""

​ 第二种,额外训练linear probe进行预测。通过CLIP的image_encoder得到视觉向量,结合标签做Logistic Regression

python 复制代码
"Ref:https://github.com/openai/CLIP"
import os
import clip
import torch

import numpy as np
from sklearn.linear_model import LogisticRegression
from torch.utils.data import DataLoader
from torchvision.datasets import CIFAR100
from tqdm import tqdm

# Load the model
device = "cuda" if torch.cuda.is_available() else "cpu"
model, preprocess = clip.load('ViT-B/32', device)

# Load the dataset
root = os.path.expanduser("~/.cache")
train = CIFAR100(root, download=True, train=True, transform=preprocess)
test = CIFAR100(root, download=True, train=False, transform=preprocess)

def get_features(dataset):
    all_features = []
    all_labels = []
    
    with torch.no_grad():
        for images, labels in tqdm(DataLoader(dataset, batch_size=100)):
            features = model.encode_image(images.to(device))

            all_features.append(features)
            all_labels.append(labels)

    return torch.cat(all_features).cpu().numpy(), torch.cat(all_labels).cpu().numpy()

# Calculate the image features
train_features, train_labels = get_features(train)
test_features, test_labels = get_features(test)

# Perform logistic regression
classifier = LogisticRegression(random_state=0, C=0.316, max_iter=1000, verbose=1) # c自定义
classifier.fit(train_features, train_labels)

# Evaluate using the logistic regression classifier
predictions = classifier.predict(test_features)
accuracy = np.mean((test_labels == predictions).astype(np.float)) * 100.
print(f"Accuracy = {accuracy:.3f}")

Experiments

Conclusions

  • CLIP 可以说是开辟了 CV+NLP 的多模态表征学习新时代。后面谷歌的ALIGN,微软的Florence,商汤 DeCLIP,快手 EfficientCLIP 都是研究相类似的任务。虽然 CLIP 在小部分任务上 zero-shot 精度一般,但是 CLIP 在多模态的 Encoders 能提供简单而又强大的视觉先验的表征能力。
  • CLIP和BERT、GPT、ViT的区别在于,CLIP是多模态的,包含图像处理以及文本处理两个方面内容,而BERT、GPT是单文本模态的,ViT是单图像模态的

Limitations

  • 不是和SOTA的比较:以上的数据分析,都是和a linear classifier on top of ResNet-50 features进行比较,大部分的数据集,都有对应的SOTA模型。为了达到SOTA,zero-shot CLIP估计要提高1000x的算力,当前情况不支持;
  • 在部分fine-grained分类上表现不佳: a. 前面实验分析发现,模型不能很好的区分cars,species of flowers, 以及variants of aircraft; b. abstract和systematic任务表现不好,比如统计图上object的数量; c. 在训练集中基本不会出现的比较novel的任务,表现欠佳,比如classifying the distance to the nearest car in a photo;
  • 训练集中没有出现的图片类型(out-of-distribution),表现不好,比如OCR识别数字效果可以,但是MNIST的准确率只有88%;

References

相关推荐
yongui4783427 分钟前
MATLAB的指纹识别系统实现
算法
高山上有一只小老虎28 分钟前
翻之矩阵中的行
java·算法
jghhh0136 分钟前
RINEX文件进行卫星导航解算
算法
爱思德学术1 小时前
中国计算机学会(CCF)推荐学术会议-A(计算机科学理论):LICS 2026
算法·计算机理论·计算机逻辑
CVHub1 小时前
多模态图文训推一体化平台 X-AnyLabeling 3.0 版本正式发布!首次支持远程模型推理服务,并新增 Qwen3-VL 等多款主流模型及诸多功能特性,等
算法
hoiii1871 小时前
MATLAB实现Canny边缘检测算法
算法·计算机视觉·matlab
qq_430855881 小时前
线代第二章矩阵第四课:方阵的幂
算法·机器学习·矩阵
roman_日积跬步-终至千里2 小时前
【计算机设计与算法-习题2】动态规划应用:矩阵乘法与钢条切割问题
算法·矩阵·动态规划
kupeThinkPoem2 小时前
计算机算法导论第三版算法视频讲解
数据结构·算法
sali-tec2 小时前
C# 基于halcon的视觉工作流-章67 深度学习-分类
开发语言·图像处理·人工智能·深度学习·算法·计算机视觉·分类