few shot learnning笔记

复制代码
课程地址 https://youtu.be/hE7eGew4eeg?si=KBM0lY7eY_AdD8Wr
PPT地址 https://github.com/wangshusen/DeepLearning
第一节 Few-Shot Learning Basics
第二节 Siamese Network
第三节 Pretraining + Fine Tuning

(以图像识别举例)

基础

support set:数量很少的样本集合,为模型完成任务而提供更多的信息。

query:模型的输入样本,其并未出现在训练集中,其所属类别肯定包含在支持集中。

小样本学习(few shot learnning)是一种元学习(meta learning)方法。

元学习:自己学会学习(learn to learn)。

监督学习:测试样本之前并未见过(这个样本并不在训练集中),样本属于已知类(这个类别在训练集中有)。

小样本学习:查询样本之前并未见过,样本属于未知类。

k-way:支撑集有k类,k越大预测准确率越低。

n-shot:每个类别有n个样本,n越大预测准确率越高。

称呼支持集时,带上前缀k-way n-shot support set。

小样本学习的基本思想:学习一个相似度函数(similarity function)。

1、用一个很大的训练集学习一个相似度函数。

2、给定一个query样本,用相似度函数与support set判断出其所属类别(计算support set中的样本与query样本的相似度,取相似度最高的类别作为预测)

下面介绍两种方法:连体网络,预训练+微调。

siamese network(连体网络,孪生网络)

将训练集 D = { I 1 , . . . , I n } D=\{I_1, ..., I_n\} D={I1,...,In}重新构造为包含同样数量的正样本与负样本的新训练集。

正样本: ( I i , I j ) (I_i, I_j) (Ii,Ij),两张图像属于同一类,其标签为1。

负样本: ( I i , I j ) (I_i, I_j) (Ii,Ij),两张图像属于不同类,其标签为0。

给卷积神经网络 f ( ⋅ ) f(\cdot) f(⋅)输入正样本,其输出为输入的特征:
h i = f ( I i ) h_i = f(I_i) hi=f(Ii)
h j = f ( I j ) h_j = f(I_j) hj=f(Ij)
z i j = ∣ h i − h j ∣ z_{ij} = |h_i - h_j| zij=∣hi−hj∣

将 z i j z_{ij} zij输入至全连接层,得到一个标量,预测两个图片间的相似度,其中激活函数用sigmoid,输出区间为 [ 0 , 1 ] [0,1] [0,1]。

注: h , z h, z h,z为向量

计算样本的标签与预测值的损失函数,反向传播更新卷积网络与全连接网络的权重。

预测:

将支持集中的所有样本分别与query构造成一对新样本 ( q , s i ) (q,s_i) (q,si)输入给网络,预测两者相似度。

选出相似度最高的一对样本 ( q , s j ) (q,s_j) (q,sj),query便属于 s j s_j sj的类别。

triplet loss

这是另一种训练卷积神经网络的方法。

训练集中随机抽样作为一个anchor(锚点)。

从achor所属的类别中随机抽样一个样本,记为正样本。

再从除achor所属的类别外的剩余训练集中随机抽样一个样本,记为负样本。

三个样本同时输入卷积网络:
f ( x + ) , f ( x a ) , f ( x − ) f(x^+), f(x^a), f(x^-) f(x+),f(xa),f(x−)

计算特征间的距离:
d + = ∣ f ( x + − f ( x a ) ) ∣ 2 d^+ = |f(x^+ - f(x^a))|_2 d+=∣f(x+−f(xa))∣2
d − = ∣ f ( x − − f ( x a ) ) ∣ 2 d^- = |f(x^- - f(x^a))|_2 d−=∣f(x−−f(xa))∣2

注: d d d为标量

计算损失函数:

如果 d − ≥ d + + α d^- \geq d^+ + \alpha d−≥d++α,则 l o s s = 0 loss=0 loss=0;否则 l o s s = d + + α − d − loss=d^+ + \alpha - d^- loss=d++α−d−。

其中, α > 0 \alpha > 0 α>0。

注:上述就是max函数。

补充:卷积神经网络将图片映射为特征空间中的一点

通过损失函数,训练出卷积网络 f ( ⋅ ) f(\cdot) f(⋅)。

预测:

通过 f ( ⋅ ) f(\cdot) f(⋅)计算query与所有支持集样本的特征,然后计算出query与所有支持集样本的距离 d ( q , s i ) d(q,s_i) d(q,si)。

选出距离最近的一对样本 ( q , s j ) (q,s_j) (q,sj),query便属于 s j s_j sj的类别。

预训练与微调

两个向量的cosine similarity: c o s θ = x T w / ( ∣ x ∣ 2 ⋅ ∣ w ∣ 2 ) cos \theta = x^Tw/(|x|_2 \cdot |w|_2) cosθ=xTw/(∣x∣2⋅∣w∣2)。

如果 x , w x,w x,w为单位向量,则 c o s θ = x T w cos \theta = x^Tw cosθ=xTw,即相似度为两个单位向量内积。

单位向量内积较大表示,一个向量在另一个向量上的投影较大。

softmax函数:将向量 ϕ \phi ϕ映射为一个概率分布 p p p。
p i > 0 , Σ i p i = 1 p_i > 0, \Sigma_i p_i = 1 pi>0,Σipi=1.

预训练(pretraining)一个卷积网络用于特征提取。

注:特征提取也称embeding。

(以3-way 2-shot support set举例)

支撑集每个类别的特征向量为该类所有样本的特征向量求平均,再进行归一化,得到 μ 1 , μ 2 , μ 3 \mu_1, \mu_2, \mu_3 μ1,μ2,μ3。

同理,对query也做同样的步骤,得到 q q q

矩阵 M = [ μ 1 μ 2 μ 3 ] M = \begin{bmatrix} \mu_1 \\ \mu_2 \\ \mu_3\end{bmatrix} M= μ1μ2μ3

计算一个三维向量 p = s o f t m a x ( M q ) p = softmax(Mq) p=softmax(Mq)
p p p中最大元素为 p i p_i pi,则 q q q属于支持集中的 i i i类

微调

在卷积网络后面接一层全连接: p = s o f t m a x ( W ⋅ f ( x ) + b ) p = softmax(W \cdot f(x) + b) p=softmax(W⋅f(x)+b)

之前不进行微调,即 W = M , b = 0 W=M, b=0 W=M,b=0,而学习 W , b W,b W,b称为微调。

损失函数: l o s s = c o r s s E n t r o p y ( y , p ) loss = corssEntropy(y, p) loss=corssEntropy(y,p)

其中, y y y为真实标签,是一个one hot向量。

反向传播,训练出 W , b W,b W,b

介绍三个小技巧:

1、初始化权重 W = M , b = 0 W=M, b=0 W=M,b=0

2、为防止过拟合,加入正则化。(下面介绍熵正则化)

一个query预测结果为 p = s o f t m a x ( W ⋅ f ( x ) + b ) p = softmax(W \cdot f(x) + b) p=softmax(W⋅f(x)+b)

熵 H ( p ) = − Σ i p i l o g ( p i ) H(p)=-\Sigma_i p_i log(p_i) H(p)=−Σipilog(pi)

熵正则化为所有query的熵的平均。
我们希望熵正则化越小越好。

3、结合余弦相似度与softmax分类器,即先将 W W W的行向量与 f ( x ) f(x) f(x)作归一化,再计算矩阵向量乘积。
p = s o f t m a x ( W ⋅ f ( x ) + b ) p = softmax(W \cdot f(x) + b) p=softmax(W⋅f(x)+b)

相关推荐
兴趣使然_2 小时前
【笔记】使用 html 创建网址快捷方式
笔记·html·js
aramae3 小时前
C++ -- STL -- vector
开发语言·c++·笔记·后端·visual studio
fen_fen4 小时前
学习笔记(32):matplotlib绘制简单图表-数据分布图
笔记·学习·matplotlib
饕餮争锋8 小时前
设计模式笔记_创建型_建造者模式
笔记·设计模式·建造者模式
萝卜青今天也要开心8 小时前
2025年上半年软件设计师考后分享
笔记·学习
吃货界的硬件攻城狮8 小时前
【STM32 学习笔记】SPI通信协议
笔记·stm32·学习
蓝染yy9 小时前
Apache
笔记
lxiaoj11110 小时前
Python文件操作笔记
笔记·python
半导体守望者10 小时前
ADVANTEST R4131 SPECTRUM ANALYZER 光谱分析仪
经验分享·笔记·功能测试·自动化·制造
啊我不会诶11 小时前
倍增法和ST算法 个人学习笔记&代码
笔记·学习·算法