few shot learnning笔记

课程地址 https://youtu.be/hE7eGew4eeg?si=KBM0lY7eY_AdD8Wr
PPT地址 https://github.com/wangshusen/DeepLearning
第一节 Few-Shot Learning Basics
第二节 Siamese Network
第三节 Pretraining + Fine Tuning

(以图像识别举例)

基础

support set:数量很少的样本集合,为模型完成任务而提供更多的信息。

query:模型的输入样本,其并未出现在训练集中,其所属类别肯定包含在支持集中。

小样本学习(few shot learnning)是一种元学习(meta learning)方法。

元学习:自己学会学习(learn to learn)。

监督学习:测试样本之前并未见过(这个样本并不在训练集中),样本属于已知类(这个类别在训练集中有)。

小样本学习:查询样本之前并未见过,样本属于未知类。

k-way:支撑集有k类,k越大预测准确率越低。

n-shot:每个类别有n个样本,n越大预测准确率越高。

称呼支持集时,带上前缀k-way n-shot support set。

小样本学习的基本思想:学习一个相似度函数(similarity function)。

1、用一个很大的训练集学习一个相似度函数。

2、给定一个query样本,用相似度函数与support set判断出其所属类别(计算support set中的样本与query样本的相似度,取相似度最高的类别作为预测)

下面介绍两种方法:连体网络,预训练+微调。

siamese network(连体网络,孪生网络)

将训练集 D = { I 1 , . . . , I n } D=\{I_1, ..., I_n\} D={I1,...,In}重新构造为包含同样数量的正样本与负样本的新训练集。

正样本: ( I i , I j ) (I_i, I_j) (Ii,Ij),两张图像属于同一类,其标签为1。

负样本: ( I i , I j ) (I_i, I_j) (Ii,Ij),两张图像属于不同类,其标签为0。

给卷积神经网络 f ( ⋅ ) f(\cdot) f(⋅)输入正样本,其输出为输入的特征:
h i = f ( I i ) h_i = f(I_i) hi=f(Ii)
h j = f ( I j ) h_j = f(I_j) hj=f(Ij)
z i j = ∣ h i − h j ∣ z_{ij} = |h_i - h_j| zij=∣hi−hj∣

将 z i j z_{ij} zij输入至全连接层,得到一个标量,预测两个图片间的相似度,其中激活函数用sigmoid,输出区间为 [ 0 , 1 ] [0,1] [0,1]。

注: h , z h, z h,z为向量

计算样本的标签与预测值的损失函数,反向传播更新卷积网络与全连接网络的权重。

预测:

将支持集中的所有样本分别与query构造成一对新样本 ( q , s i ) (q,s_i) (q,si)输入给网络,预测两者相似度。

选出相似度最高的一对样本 ( q , s j ) (q,s_j) (q,sj),query便属于 s j s_j sj的类别。

triplet loss

这是另一种训练卷积神经网络的方法。

训练集中随机抽样作为一个anchor(锚点)。

从achor所属的类别中随机抽样一个样本,记为正样本。

再从除achor所属的类别外的剩余训练集中随机抽样一个样本,记为负样本。

三个样本同时输入卷积网络:
f ( x + ) , f ( x a ) , f ( x − ) f(x^+), f(x^a), f(x^-) f(x+),f(xa),f(x−)

计算特征间的距离:
d + = ∣ f ( x + − f ( x a ) ) ∣ 2 d^+ = |f(x^+ - f(x^a))|_2 d+=∣f(x+−f(xa))∣2
d − = ∣ f ( x − − f ( x a ) ) ∣ 2 d^- = |f(x^- - f(x^a))|_2 d−=∣f(x−−f(xa))∣2

注: d d d为标量

计算损失函数:

如果 d − ≥ d + + α d^- \geq d^+ + \alpha d−≥d++α,则 l o s s = 0 loss=0 loss=0;否则 l o s s = d + + α − d − loss=d^+ + \alpha - d^- loss=d++α−d−。

其中, α > 0 \alpha > 0 α>0。

注:上述就是max函数。

补充:卷积神经网络将图片映射为特征空间中的一点

通过损失函数,训练出卷积网络 f ( ⋅ ) f(\cdot) f(⋅)。

预测:

通过 f ( ⋅ ) f(\cdot) f(⋅)计算query与所有支持集样本的特征,然后计算出query与所有支持集样本的距离 d ( q , s i ) d(q,s_i) d(q,si)。

选出距离最近的一对样本 ( q , s j ) (q,s_j) (q,sj),query便属于 s j s_j sj的类别。

预训练与微调

两个向量的cosine similarity: c o s θ = x T w / ( ∣ x ∣ 2 ⋅ ∣ w ∣ 2 ) cos \theta = x^Tw/(|x|_2 \cdot |w|_2) cosθ=xTw/(∣x∣2⋅∣w∣2)。

如果 x , w x,w x,w为单位向量,则 c o s θ = x T w cos \theta = x^Tw cosθ=xTw,即相似度为两个单位向量内积。

单位向量内积较大表示,一个向量在另一个向量上的投影较大。

softmax函数:将向量 ϕ \phi ϕ映射为一个概率分布 p p p。
p i > 0 , Σ i p i = 1 p_i > 0, \Sigma_i p_i = 1 pi>0,Σipi=1.

预训练(pretraining)一个卷积网络用于特征提取。

注:特征提取也称embeding。

(以3-way 2-shot support set举例)

支撑集每个类别的特征向量为该类所有样本的特征向量求平均,再进行归一化,得到 μ 1 , μ 2 , μ 3 \mu_1, \mu_2, \mu_3 μ1,μ2,μ3。

同理,对query也做同样的步骤,得到 q q q

矩阵 M = [ μ 1 μ 2 μ 3 ] M = \begin{bmatrix} \mu_1 \\ \mu_2 \\ \mu_3\end{bmatrix} M= μ1μ2μ3

计算一个三维向量 p = s o f t m a x ( M q ) p = softmax(Mq) p=softmax(Mq)
p p p中最大元素为 p i p_i pi,则 q q q属于支持集中的 i i i类

微调

在卷积网络后面接一层全连接: p = s o f t m a x ( W ⋅ f ( x ) + b ) p = softmax(W \cdot f(x) + b) p=softmax(W⋅f(x)+b)

之前不进行微调,即 W = M , b = 0 W=M, b=0 W=M,b=0,而学习 W , b W,b W,b称为微调。

损失函数: l o s s = c o r s s E n t r o p y ( y , p ) loss = corssEntropy(y, p) loss=corssEntropy(y,p)

其中, y y y为真实标签,是一个one hot向量。

反向传播,训练出 W , b W,b W,b

介绍三个小技巧:

1、初始化权重 W = M , b = 0 W=M, b=0 W=M,b=0

2、为防止过拟合,加入正则化。(下面介绍熵正则化)

一个query预测结果为 p = s o f t m a x ( W ⋅ f ( x ) + b ) p = softmax(W \cdot f(x) + b) p=softmax(W⋅f(x)+b)

熵 H ( p ) = − Σ i p i l o g ( p i ) H(p)=-\Sigma_i p_i log(p_i) H(p)=−Σipilog(pi)

熵正则化为所有query的熵的平均。
我们希望熵正则化越小越好。

3、结合余弦相似度与softmax分类器,即先将 W W W的行向量与 f ( x ) f(x) f(x)作归一化,再计算矩阵向量乘积。
p = s o f t m a x ( W ⋅ f ( x ) + b ) p = softmax(W \cdot f(x) + b) p=softmax(W⋅f(x)+b)

相关推荐
LeonNo111 小时前
ElasticSearch学习了解笔记
笔记·学习·elasticsearch
小志biubiu2 小时前
【C++11】可变参数模板/新的类功能/lambda/包装器--C++
开发语言·c++·笔记·学习·c++11·c11
布说在见2 小时前
我的创作之路:机缘、收获、日常与未来的憧憬
经验分享·笔记
cnftv4 小时前
中国山水画在他这里暂停拐了个弯儿再重启
笔记·其他·生活·娱乐·媒体
maknul10 小时前
【学习笔记】AD智能PDF导出(装配文件)
笔记·学习·pdf
pq113_611 小时前
ftdi_sio应用学习笔记 4 - I2C
笔记·学习·linux驱动·ftdi_sio
快乐飒男12 小时前
Linux基础05
linux·笔记·学习
田梓燊13 小时前
湘潭大学软件工程算法设计与分析考试复习笔记(六)
笔记·算法·软件工程
网安墨雨13 小时前
网络安全笔记
网络·笔记·web安全
南东山人13 小时前
关于内核编程的一些笔记
linux·笔记