Distilling the Knowledge in a Neural Network学习笔记

1.主要内容是什么:

这篇论文介绍了一种有效的知识迁移方法------蒸馏,可以将大型模型中的知识转移到小型模型中,从而提高小型模型的性能。这种方法在实际应用中具有广泛的潜力,并且可以应用于各种不同的任务和领域。

论文中首先介绍了蒸馏的基本原理。大型模型通常通过softmax输出层产生类别概率,而蒸馏则通过提高softmax的温度来产生更软化的概率分布。在蒸馏过程中,使用大型模型生成的高温软目标分布来训练小型模型,以实现知识的迁移。

2.怎么实现的?

具体实现方式是,

3.硬标签和软目标?

硬标签和软目标是知识蒸馏方法中的两种不同的目标函数。

硬标签是指使用真实的标签作为目标进行训练。在传统的监督学习中,通常使用硬标签来训练模型,即将模型的输出与真实标签进行比较,通过最小化它们之间的差异来优化模型。

软目标是指使用大型模型生成的概率分布作为目标进行训练。

在知识蒸馏中,大型模型生成的概率分布被认为是一种"软"的目标,因为它们比硬标签更平滑,包含了更多的信息。小型模型通过最小化其输出与软目标之间的差异来训练。 在论文中,作者发现将硬标签和软目标结合起来训练蒸馏模型可以取得更好的效果。他们提出了一种加权平均的目标函数,其中第一个目标函数是使用软目标进行的交叉熵损失,第二个目标函数是使用硬标签进行的交叉熵损失。通过调整这两个目标函数的权重,可以在保留软目标的信息的同时,让模型更好地学习硬标签的知识。 在使用硬标签和软目标进行训练时,需要注意将软目标的梯度乘以温度的平方,以保持硬目标和软目标的相对贡献大致不变。这是因为软目标的梯度与温度的平方成反比,所以在使用硬目标和软目标时需要进行调整,以保持相对的平衡。

总的来说,硬标签和软目标是知识蒸馏方法中两种不同的目标函数,通过结合它们可以在训练蒸馏模型时获得更好的效果。

4.为什么不把软目标当做唯一loss?

软目标是大模型的输出概率分布传递给小模型来获得的,这些概率分布可能包含了大模型的对于小模型来说,额外的知识和不确定性。

而且,软目标并不是完全准确的标签,因此仅依赖软目标进行训练可能会导致模型过度拟合软目标的噪声。

所以,通常会将软目标与真实标签结合起来,使用两个不同的目标函数进行训练。

第一个目标函数是使用软目标计算的交叉熵损失

第二个目标函数是使用真实标签计算的交叉熵损失。

这样可以在保留软目标的知识的同时,确保模型也能够学习到正确的标签信息。

另外,软目标通常是通过使用较高的温度参数来生成的,这可以使概率分布更加平滑。

但在训练过程中,为了使得软目标和硬目标的相对贡献保持大致不变,需要将软目标的梯度乘以温度的平方。这样可以确保在尝试不同的温度参数时,硬目标和软目标的相对贡献保持一致。

5.软目标具体是用的大模型的哪一层呢?

软目标是使用大模型的softmax层输出的概率分布。

软目标是通过将大模型的logits(输入softmax的值)通过softmax函数计算得到的概率分布。这个概率分布可以被视为大模型对每个类别的预测概率。

6.在做蒸馏的时候,需要大模型和小模型的模型输出头保持结构一致吗

是的。大模型和小模型的模型输出头(output head)需要保持结构一致。

两个模型的输出头应该具有相同的类别数目,并且在相同的类别顺序上产生预测结果。这是因为在蒸馏过程中,我们希望小模型能够学习大模型的知识,包括类别之间的关系和概率分布。如果大模型和小模型的输出头结构不一致,那么在计算损失函数和梯度时会出现不匹配的情况,导致蒸馏效果不佳。因此,在进行蒸馏时,需要确保大模型和小模型的输出头具有相同的结构。

相关推荐
萌新小码农‍5 小时前
Spring框架学习day7--SpringWeb学习(概念与搭建配置)
学习·spring·状态模式
蓝婷儿5 小时前
6个月Python学习计划 Day 15 - 函数式编程、高阶函数、生成器/迭代器
开发语言·python·学习
行云流水剑5 小时前
【学习记录】深入解析 AI 交互中的五大核心概念:Prompt、Agent、MCP、Function Calling 与 Tools
人工智能·学习·交互
love530love5 小时前
【笔记】在 MSYS2(MINGW64)中正确安装 Rust
运维·开发语言·人工智能·windows·笔记·python·rust
一弓虽6 小时前
zookeeper 学习
分布式·学习·zookeeper
苗老大6 小时前
MMRL: Multi-Modal Representation Learning for Vision-Language Models(多模态表示学习)
人工智能·学习·语言模型
xhyu616 小时前
【学习笔记】On the Biology of a Large Language Model
笔记·学习·语言模型
小白杨树树6 小时前
【SSM】SpringMVC学习笔记7:前后端数据传输协议和异常处理
笔记·学习
海棠蚀omo7 小时前
C++笔记-C++11(一)
开发语言·c++·笔记