中秋元素之文搜图

前言

临近佳节,也要努力学习哦。

大家都知道中秋有很多节日元素,像玉兔,月亮,蟾蜍,吃螃蟹,月饼等,为了能够准确地找出语义相关的图片,本文构建了一个简易的双塔深度学习模型,使用自然语言来搜索相关的图像。该模型地主要架构思路是联合训练视觉编码器和文本编码器,将图像及其文本内容的表示投影到相同的嵌入空间中,使得文本嵌入在所描述的图像的嵌入附近,最后通过计算向量相似度返回 topk 个图片即可。

基础

  • tensorflow
  • python
  • nlp
  • cv

数据

我们本次使用的是 MS-COCO 数据集,因为我手头只有这些国外的数据集,就凑活用吧,意思到位就行了,国内真正包含中秋元素的图文对数据集估计还没有。该数据集包含超过 82000 张图像,每张图像至少有 5 个不同的文本描述,经过整理一一对应,也就是说图文对达到了 40 多万,然后分好训练集和测试集即可。

视觉编码器

这里介绍视觉编码器模型 vision_encoder ,该模型主要用于将输入的图像数据编码为特征表示。我们这里选用了预训练大模型 Xception 作为基础图像特征提取器,冻结它的参数,不进行训练,然后通过投影函数的一系列操作进行转换,包括:全连接层转换、Dropout、LayerNormalization、残差连接等步骤,来得到最终的图像特征编码,方便后续的双塔模型可以使用这些特征进行训练。

ini 复制代码
def create_vision_encoder(num_projection_layers, projection_dims, dropout_rate, trainable=False):
    xception = keras.applications.Xception(include_top=False, weights='imagenet', pooling='avg')
    for layer in xception.layers:
        layer.trainable = trainable
    inputs = layers.Input(shape=(299,299,3), name='image_input')
    xception_input = preprocess_input(inputs)
    embeddings = xception(xception_input)
    outputs = project_embeddings(embeddings, num_projection_layers, projection_dims, dropout_rate)  # [B, 128]
    return keras.Model(inputs, outputs, name='vision_encoder')

文本编码器

这里介绍文本编码器模型 text_encoder ,该模型用于将文本数据编码为特征表示。我们首先选用 bert_en_uncased_preprocess 处理器来对我们的输入文本进行处理,bert_en_uncased_preprocess 是一个 TensorFlow Hub 内置的模块,用于对文本数据进行预处理,以便满足输入到 BERT 模型中的格式。这省去了我们做分词、截断、填充、掩码、控制长度、添加特殊标记等工作。当然我们还要选用预训练大模型 bert_en_uncased 来进行文本的编码特征提取,并且也冻结了 bert 的权重参数。这里后续和上面一样要进行相同投影函数的一系列操作进行转换,结果我们可以知道每个图像和每个文本的特征维度最终会相同,这是后续进行双塔模型训练的基础。

ini 复制代码
def create_text_encoder(num_projection_layers, projection_dims, dropout_rate, trainable=False):
    bert = hub.KerasLayer( "https://tfhub.dev/tensorflow/small_bert/bert_en_uncased_L-4_H-512_A-8/2", trainable=trainable )
    preprocess = hub.KerasLayer("https://tfhub.dev/tensorflow/bert_en_uncased_preprocess/3")
    inputs = layers.Input(shape=(), dtype=tf.string, name='text_input')
    bert_inputs = preprocess(inputs)
    embeddings = bert(bert_inputs)['pooled_output']
    outputs = project_embeddings(embeddings, num_projection_layers, projection_dims, dropout_rate)  # [B, 128]
    return keras.Model(inputs, outputs, name='text_encoder')

双塔编码器

这里介绍双塔编码器模型,该模型用于实现文本和图像编码器的训练和测试。结构很简单,就是将两个模型都融合到继承了 keras.Model 的 DualEncoder 类中,然后在前向过程中调用 call 函数,分别求出文本对应的 caption_embeddings 和图像对应的 image_embeddings ,然后使用这两个 embedding 来进行目标 loss 的计算和梯度的更新。此模型的主要目的在于训练过程中通过最小化损失函数来学习文本嵌入和图像嵌入,使它们在嵌入空间中彼此相似。

ini 复制代码
class DualEncoder(keras.Model):
   def __init__(self, text_endocer, image_encoder, temperature=1.0, **kwargs):
       ...

   def call(self, features, training=False):
       caption_embeddings = text_encoder(features['caption'], training=training)
       image_embeddings = vision_encoder(features['image'], training=training)
       return caption_embeddings, image_embeddings

   def compute_loss(self, caption_embeddings, image_embeddings):
       logits = ( tf.matmul(caption_embeddings, image_embeddings, transpose_b=True) / self.temperature)   # [B, B]
       images_similarity = tf.matmul(image_embeddings, image_embeddings, transpose_b=True)   # [B, B]
       caption_similarity = tf.matmul(caption_embeddings, caption_embeddings, transpose_b=True)  # [B, B]
       targets = keras.activations.softmax( (caption_similarity + images_similarity) / (2 * self.temperature))  # [B, B]
       caption_loss = keras.losses.categorical_crossentropy(y_true=targets, y_pred=logits, from_logits=True)  # [B,]
       images_loss = keras.losses.categorical_crossentropy(y_true=tf.transpose(targets), y_pred=tf.transpose(logits), from_logits=True)  # [B,]
       return (caption_loss + images_loss) / 2

   def train_step(self, features):
       ...

   def test_step(self, features):
       ...

编译和训练

整个模型的优化器使用的是 AdamW ,并且定义了回调函数 early_stopping ,保证模型在 3 个 epoch 后 val_loss 都没有下降便提前停止训练。整个模型的训练相当耗时,平均一个 epoch 一个小时左右。总共训练了 4 个 epoch ,日志如下:

shell 复制代码
# Epoch 1/100
# 2760/2760 [==============================] - 3521s 1s/step - loss: 10.6863 - val_loss: 3.3784 - lr: 3.0000e-04
# Epoch 2/100
# 2760/2760 [==============================] - 3503s 1s/step - loss: 7.5042 - val_loss: 15.8172 - lr: 3.0000e-04
# Epoch 3/100
# 2760/2760 [==============================] - 3502s 1s/step - loss: 12.3186 - val_loss: 4.3896 - lr: 3.0000e-04
# Epoch 4/100
# 2760/2760 [==============================] - 3505s 1s/step - loss: 30.6067 - val_loss: 10.6790 - lr: 3.0000e-04

效果展示

经过漫长的训练,我们就检查一下模型的效果吧。我们已经有训练好的文本编码器和图像编码器了,想要通过文字搜索图片,就需要先把所有的图像通过图像编码器都转换成编码 image_embeddings 保存好,然后我们输入文本,通过文本编码器将输入的文本转换为编码 query_embedding ,然后通过两个编码的矩阵相乘计算出相似度,image_embeddings 中和 query_embedding 最相近的若干张图像,下面大家展示一下效果。

我输入一些和我们中秋元素有关的文本描述,因为这个模型本身是个演示模型,结构很简单,而且洋人数据集中包含中秋元素的内容很少,所以训练出来的效果不是很好,很多返回的图片甚至风马牛不相及,不一定是语义最相近的图片,我努力在 topN 中挑选了一些有意思的图片,大家一起开心开心得嘞。

这贼眉鼠眼的是兔子?我感觉像乌龟。

看来我家的兔子喜欢玩手机,把手机捧在脑壳里。

你这兔爷站在厕所想干啥?偷看吗?

这么小的月亮,WTF...

这个月亮更小,我差点没发现,行不行啊细狗...

好歹有一个正经月亮了,你看这个月亮,它又大又圆(哎?旋律有点熟悉的感觉)...

别说,这个螃蟹大面包还挺可爱的,不过我更喜欢旁边的熊。

果然物以类聚,人、狗、螃蟹,正不正经我还真不好说...

你两偷偷摸摸在马桶旁边密谋什么呢,能不能跟我说说。

蛤蟆也要敲代码是吧!跟我抢饭碗是吧!你瞅啥!

这个大蛤蟆有一点蛤蟆仙人的感觉了,一看就懂修炼。

蛤蟆在哪呢?不会是那个边上趴着的玩意吧......

参考

相关推荐
湫ccc5 小时前
《Opencv》基础操作详解(3)
人工智能·opencv·计算机视觉
西西弗Sisyphus7 小时前
探索多模态大语言模型(MLLMs)的推理能力
人工智能·计算机视觉·语言模型·大模型
pk_xz12345610 小时前
OpenCV实现实时人脸检测和识别
人工智能·opencv·计算机视觉
是十一月末10 小时前
Opencv实现图片和视频的加噪、平滑处理
人工智能·python·opencv·计算机视觉·音视频
MUTA️11 小时前
RT-DETR学习笔记(2)
人工智能·笔记·深度学习·学习·机器学习·计算机视觉
游客52014 小时前
opencv中的常用的100个API
图像处理·人工智能·python·opencv·计算机视觉
yusaisai大鱼15 小时前
tensorflow_probability与tensorflow版本依赖关系
人工智能·python·tensorflow
18号房客15 小时前
一个简单的深度学习模型例程,使用Keras(基于TensorFlow)构建一个卷积神经网络(CNN)来分类MNIST手写数字数据集。
人工智能·深度学习·机器学习·生成对抗网络·语言模型·自然语言处理·tensorflow
吃个糖糖16 小时前
36 Opencv SURF 关键点检测
人工智能·opencv·计算机视觉
普密斯科技17 小时前
手机外观边框缺陷视觉检测智慧方案
人工智能·计算机视觉·智能手机·自动化·视觉检测·集成测试