论文学习——基于Whisper迁移学习的阿尔兹海默症检测方法——音频特征和语义特征的结合

文章目录

引言

  • 目前来说基于音频检测AD的方法主要分为四类,分别是

    • 基于副语言特征,包括常见的IS09、IS10等特征集,直接提取,然后做分类。
    • 基于声学特征,这类方法一般是常见的mel频谱图之类的特征集,或者自定义的提取的特征,然后再做一个分类
    • 基于语音预训练模型,这一类模型实在很多种语音文本上进行了训练,然后提取出对应的特征进行分类。
    • 基于语言学特征,将音频转录为文本之后,使用预训练语言模型进行特征提取和分类。
  • 但是如果要做小样本的检测,甚至说零样本的检测的话,上述几种方法都比较困难,无论是那种方法都需要一定规模的数据集进行微调,而大语言模型再说明完任务信息后,能够实现上述效果。但是检测效果收到了ASR转录的准确性的限制!除此之外,监测效果还收到患者主观性的影响,所以需要音频特征作为补充。基于此,在搜集最新的AD检测的相关论文,能够做到小样本或者零样本最好!

  • 目前所使用的数据集都是基于图片描述的,所以需要理解患者对于图片的描述内容,或者语义特征,仅仅依靠音频特征是不可取的!

  • 这篇文章是基于whisper的迁移学习的方法,是音频特征和语义特征的一个结合!好好整理一下!

正文------基于Whisper迁移学习的阿尔茨海默病检测方法(使用转述文本作为提示,利用音频段落进行分类)

摘要

  • 阿尔茨海默病是一个退行性疾病,唯有尽早诊断才能尽可能有效治疗,通过音频检测是一个有效的方式。
  • 本文核心创新点 是:
    • 使用最新的跨语言语音识别和翻译模型------Whisper,进行迁移学习,解决阿尔茨海默病检测的可行性
    • Whisper的缺陷和解决办法
      • Whisper模型在微调时严格限制了音频的长度,导致了其所能够接受的数据长度有限
      • 在训练过程中,提供整段音频完整的转路文本作为Prompt,解决上述问题
  • 实验效果
    • 在ADReSSo数据集上准确率是84.51%,比Baseline要高,还超过了一众传统的级联模型,这说明这种方法的有效性。

1 Introduction介绍

  • AD是一个隐匿发病的退行性神经疾病。随着老年人口数量的增加,AD患者数量急剧增加,全球AD患者已经达到了五千万左右(22年的统计结果是九千万)。AD患者大脑一旦产生病理性变化,会出现认知衰退以及表达恶化。医学研究者已经证明,唯有尽早诊断才能减缓AD的恶化。因此,对于任何一个有罹患该病症可能的人,进行AD检测和认知检测是十分必要的。
  • 最近,已经有很多基于语音检测AD的研究在开展 ,有望成为促进大范围AD检测基础。目前最常见的方法是,从患者的自发语音中提取患者语义特征或者音频特征,然后使用不同的分类器或者深度神经网络进行学习,尤其是使用预训练模型,能够显著提升模型的检测效果。
    • 这里提到了三篇使用预训练语言模型的文章,一个是使用PLM进行指令微调, 一个是使用GPT-3进行特征提取,还有一个是使用wav2vec提取特征,前两个基本上都看过了,我最新发的文章也是前两个的结合。
  • 但是,在本文中,我们使用的是不同的方法
    • 我们仅仅基于语音识别模型Whisper,并使用迁移学习实现AD患者的检测
    • 除此之外,我们创新性的提出了全文本提示的方法,也就是FTP,这能是模型获取到全局信息,而不仅仅是局限在某一个部分。
  • 实验结果证明了迁移学习和Whisper结合的实用性,以及FTP的有效性。
ADReSSo 数据集
  • 这个数据集是老常客了,就简单提一下,是InterSpeech2021比赛的官方数据集。比赛官网链接
  • 本文是将所有的音频数据将采样变成了16 kHz
Whisper模型
  • Whisper模型是一个通用跨语言的语音识别和翻译模型,是基于Transformer模型架构的。主要是由五个不同尺寸的模型构成,具体参数如下
  • 这个模型已经在不同的数据集上训练过了,总共是680,000个小时的监督训练,使其能够很好地执行多种任务,包括多语言语音识别、语音翻译、语言识别以及声音事件检测等。Whsiper已经在很多公开的数据集上,在上述任务中取得了很好的成绩。除此之外,研究发现,Whisper模型生成的音频表示具有很强的鲁棒性,能够很好地适应音频分类任务。进一步说,相关研究已经能够证明通过迁移学习,whisper能够有效应用在语音理解功能上。这使得使用whisper的长处去解决声音领域的不同问题提供了可能性!

  • 在训练和解码过程中,whisper需要音频被填充或者缩减到30s的长度。虽然对于长文本音频,提供了利用先前片段提供的上下文进行连续转录的接口。然而,当微调whisper进行迁移学习的时候,音频长度还是限制在30以内,没有办法进行解决!

2 Methods方法

Audio Processing音频预处理
  • 在ADReSSo数据集上,大部分的音频记录都是超过了30s。因此,预处理是必须的,需要将音频裁剪到合适的长度才行。我们利用Whisper-Large模型将长音频进行翻译转录,在这个过程中会将音频根据停顿或者分词将需要转录的音频划分成不同的段落 ,称之为解码段落,我们会记录每一段落在整体音频中的结束时间,记录为T1,T2,T3等。
  • 为了最大化迁移学习过程中每一段的信息总量,我们将所有的解码段落 合并成一个更大的段落,记为Merged Segment合并段落。我们使用贪心的方式,在序列【0,T1,T2,T3,...,Tn】中合并尽可能长的段落,但是总的时长是小于30s 的。如果最后的长度小于30s,直接丢弃。
  • 最后,我们使用这些合并的段落进行Whisper迁移学习中的训练和测试。

这里画了一个图,大概的思路如下,本质上还是组装切片

Transfer Learning(TL)迁移学习
  • 迁移学习能够将在一个领域学习到的知识,迁移应用到另外一个领域中,并且能够实现将预训练模型应用到不同的任务中。因为AD检测数据集都比较小,所以迁移学习在其他大规模数据集上训练过的模型,进行微调解决AD检测问题十分有价值。例如,之前有研究就是使用无监督的预训练语言模型wav2vec提取特征,然后将之输入到下游模型中,完成AD检测任务。但是,这些方法一般在预训练模型后需要加上一个分类器,来解决两个领域之间的差异问题 。相反,当我们使用无监督的大规模模型比如说Whisper进行迁移学习的时候,可以直接将标签转换为文本,然后将之作为微调的目标,无需额外的层和分类器!
    • 这里解释一下
      • Whisper模型在训练的过程中的目标,将输入的音频直接转换为相应的文本,在AD检测任务中,标签也可以直接被转换为文本格式,作为微调模型的目标。
      • 相当于就是换了一个输出,调整他的映射关系
  • 在这个文章中,我们使用从训练集中获得合并序列的数据集,微调Whisper模型。为了尽可能保留原始模型的知识,在微调过程中,我们选择冻结Whisper模型的编码层 。每一个段落的目标是一个的文本序列:固定是以ADReSSo Classification作为开头,后续跟着一个Alzheimer或者Normal的标签,每一个标签都表示一个不同的状态。在测试每一个合并的序列(merged segment)时,我们也会使用"ADReSSo Classification"作为解码器的前缀,然后要求的Whisper模型预测下一个序列的标签。
  • 一段音频的所有的merger segment的预测结果都会他投票聚合到一起,决定当前音频样例的最终检测结果。我们称这个方法为Whisper-TL,这个方法在训练和测试过程中,并没有使用任何ASR转录文本。

具体训练过程见下图

Use Full Transcripts as Prompt(FTP)
  • 虽然仅仅使用音频对Whisper模型进行迁移学习是可行的,但是Whisper模型只能接受30s的数据,使得该模型所能够接受的信息有限。为了解决这个限制,我们提出了一个新的方法,就是在微调每一段音频的时候,提供当前音频对应的转路文本作为提示 。这使得模型能够利用全局文本信息

  • 正式定义如下

    • 将音频定义为为 x x x
    • 当前音频对应的转录文本信息定义为 t t t
  • 音频 x x x在进行数据预处理之后,将其划分为 k k k段,定义为 x i x_i xi,其中 i < = k i <= k i<=k

  • 在我们的框架中,模型对于每一段的预测过程,可以使用如下的公示进行表达

    • y i y_i yi表示解码器生成字符
    • r i r_i ri表示当前样本生成的预测结果
      P ( y i ∣ t , x i ) → r i P(y_i | t,x_i) \rightarrow r_i P(yi∣t,xi)→ri
  • 为了获取最终的检测结果 r r r ,正如在上一章中做的一样,我们将每一个片段的结果进行聚合,通过投票决定最终的结果是什么,具体公示如下

r = V o t e ( r 1 , r 2 , r 3 , . . . , r k ) r = Vote(r_1,r_2,r_3,...,r_k) r=Vote(r1,r2,r3,...,rk)

  • 训练过程中prompt提示的作用

    • 在训练过程中,提示Prompt应该是作为引用或者说参考,而不是作为预测的目标,所以不能将prompt作为需要预测或者生成的内容,而是将其作为帮助模型理解上下文。
  • 提示Prompt在解码器众的位置

    • 因为Prompt提示是作为提示和参考,不是生成的目标或者提示,所以prompt位置应该在解码器中保持不变,有固定的位置,在实验中是以 < ∣ s t a r t o f p r e v ∣ > <|startofprev|> <∣startofprev∣>作为开始。
  • 提示prompt的具体作用

    • 模型在遇到<|startoftranscript|>标签之前,模型输出的和输入的内容一模一样,并不会进行预测或者生成。再遇到这个<|startoftranscript|>后,说明后续的内容是需要的学习生成的语音编码。
  • prompt在实验中的设置

    • 无论是训练集还是测试集,都需要提供对应测试样本的完整的转述文本作为提示,本文中是采用whisper-large对所有音频进行ASR转录的。
    • 除此之外,还需要提供对应的 < ∣ p r e f i x ∣ > <|prefix|> <∣prefix∣>前缀作为引导,也就是之前的< Alzhemier Classification >这个标签。
  • 上述使用完整ASR转录文本作为提示的方法称之为Whisper-TL-FTP

3 Experiment Setup实验设置

Training Details训练细节
  • 尝试的Whisper模型

    • Whisper-base、Whisper-small、Whisper-medium、Whisper-large
  • 可学习的参数

    • 冻结了编码器
    • 仅仅对解码器进行微调
  • 模型参数

    • 交叉熵损失函数
    • AdamW优化器
    • 训练 15 个 epoch
    • 批量大小为 16
    • 学习率为 0.0001
    • 权重衰减为 0.01
    • Adam 的 epsilon 值为 1.0e-8
  • 带有Prompt的训练细节

    • 文本长度限制
      • 由于长度限制,提示信息被设置为经过分词器编码后的完整转录的最后 335 个 token
    • 测试模型的选择
      • 在训练 15 个 epoch 后,模型达到了收敛,并选择了最终的检查点进行测试
    • 随机实验的设置
      • 使用不同的随机种子可能会导致不同的结果。为了提高结果的可靠性,采用10个随机种子进行10次实验。
  • 在本文中,我们的方法和以下五种方法做对比,主要分为两组。

第一组,不提供ASR转述文本进行提示,仅仅是和Whisper-TL进行比较

  • 1、ADReSSo 的BaseLine模型

    • ADReSSo系统中使用的Baseline是使用eGeMAPS特征集,后面加上一个SVM支持向量机作为下游分类器
  • 2、Wav2vec

    • 使用Wav2vec预训练模型提取出特征,然后后面使用**tree bagger(TB)**决策树打包作为下游分类器的模型

第二组,提供ASR转录文本进行提示,用来和Whisper-TL-FTP进行对比

  • 3、BaseLine中增加语义特征

    • 使用Google Cloud-based Speech Recogniser提取出语义特征,后面加上对应的SVM支持向量机进行分类,输出基于音频检测结果。
    • 使用后期融合Late Fusion ,将音频的检测结果和语义的检测结果进行融合,通过类似投票或者加权平均的方式生成最终结果。
  • 4、使用WavBERT方法

    • 使用Wav2vec2.0 作为ASR转录模型,提取出对应音频的转路文本,在的转录文本中增加对应的句子层级的停顿信息
    • 使用Bert作为下游模型,接受转录文本,进行分类。
    • 首先通过Wav2vec2.0模型从语音中提取特征并生成转录文本,随后将这些文本输入到BERT模型中,BERT负责理解文本的内容和语言特征,最终输出用于分类的特征或预测结果。
  • 5、单纯使用RoBerta进行转录文本的微调------这个就是之前新国立的Yi Wang写的那个文章所用的方法

    • 我们是用Whisper获取音频对应的转述文本,然后将之作为输入对预训练语言模型RoBerta进行微调。
    • Fairseq对模型进行微调
      • Fairseq是一个开源的序列建模工具包,广泛用于自然语言处理任务,支持多种预训练语言模型的微调和训练。在这个实验中,他们使用fairseq框架对RoBERTa进行微调。
    • 模型训练的超参数如下
      • 更新次数(update number):2000,即模型训练时更新权重的次数为2000次。
      • 批量大小(batch size):4,表示每次模型训练时同时处理4个样本。
      • 学习率(learning rate):1.0e-5,控制每次权重更新时步长的大小。
      • 权重衰减(weight decay):0.1,帮助正则化,防止模型过拟合。
      • Adam epsilon:1.0e-6,Adam优化器的epsilon值,用于防止除零错误。
      • dropout:0.1,表示训练时随机丢弃一部分神经元来防止过拟合。
    • 使用最终收敛的模型进行测试,通过投票的方式集成结果!

4 Results结果

  • 表格二总结了我们提出的方法和其他用于对比的方法,在ADReSSo数据集上的实验结果。
  • 无转述文本的实验对照组
    • 如红框标注的,在没有ASR转录文本提示的情况下,我们提出的方法的检测效果最高,能够打到77.46%,超过了其他所有的方法。
  • 有转述文本的实验对照组
    • 我们提出的方法准确率达到了84.51%,比其他很多级联模型要好很多。在不添加任何其他额外模块的情况下,我们提出的方法能够有效利用音频信息和语义信息。
    • 当使用来自Whisper的相同ASR转录文本时,其表现相较于预训练语言模型RoBERTa提升了5.3%的相对幅度,展示了Whisper-TL-FTP方法的有效性。(这里就有失偏颇了,下面是这种方法的原始论文,准确率还是很高的,最高能够达到87%)
  • 增加停顿信息的对照组
    • 此外,我们还将我们的方法与添加停顿信息的方法 进行了比较。添加停顿信息的方法涉及将音频的停顿特征编码到文本中,使下游语言模型能够学习某些语音特征,从而获得更好的结果 [4, 5]。
    • 与这种方法不同的是,我们的Whisper-TL-FTP方法在单一模型中同时学习语音和文本信息,并且还能提取有用的判断信息,取得了更好的结果
关于模型尺寸的对比
  • 在ADReSSo数据集上使用了不同尺寸的whisper模型进行迁移学习,具体实验结果可以看到表格3,见下图
  • 增加FTP对于Whisper模型的效果

    • 使用whisper-base模型的时候,增加额外的转述文本,模型的检测效果反而在下降。但是,对于其他模型而言,增加了对应的转述文本后,检测效果反而更好了。
    • 结论
      • 这表明在具有更多参数的Whisper模型中使用FTP能够帮助模型更有效地学习提示文本,从而更好地检测阿尔茨海默病。
  • 横向对比------不同参数规模的检测效果

    • 在不同模型尺寸的迁移学习效果比较中,中型模型在准确率和F1分数上取得了最佳结果。
    • 这表明,在使用Whisper模型进行迁移学习时,模型的规模应足够大以充分利用音频信息和提示文本,但也不能过于庞大,以避免显著的过拟合
    • 合适的参数规模能够帮助模型更好地平衡性能,提升检测效果。

5 Conclusion总结

  • 在这篇文章中,我们尝试了对Whisper模型进行迁移学习,解决AD检测分类问题。我们的方法,不像其他任何的方法,我们并没有添加其他额外的层作为分类器。
  • 我们提出的两个方法分别是Whisper-TL和Whisper-TL-FTP,都取得了很好的效果,甚至到达84.51%的准确率,超过了其他很多的方法。
  • 除此之外,FTP的方法,无论是在小尺寸的还是大尺寸的Whisper模型上进行迁移学习,都有9%到12%的提高和改良,这证明了我们所提出的方法的有效性。
  • 通过这个方法,我们希望能够为开发有效的AD检测方式提供有效地见解。

总结

  • 这个文章是提出了一种新的AD检测方法,是一种多模态的方式,将语义特征和音频特征进行结合,是一种多模态的检测方法。但是并没有完全解决目前的问题症结。
  • 目前的主要问题是,针对AD检测的数据集太少了,如何做到跨语言检测,或者小样本检测。目前很多比赛也表现出了解决这个问题的倾向,从2020年的InterSpeech到2024年ICASSP连续四年的比赛都是解决AD检测,已经由单语言向跨语言转变,今年的连续两届会议都是解决跨语言检测的问题。
相关推荐
初学者7.16 分钟前
Webpack学习笔记(2)
笔记·学习·webpack
创意锦囊2 小时前
随时随地编码,高效算法学习工具—E时代IDE
ide·学习·算法
尘觉3 小时前
算法的学习笔记—扑克牌顺子(牛客JZ61)
数据结构·笔记·学习·算法
1 9 J3 小时前
Java 上机实践11(组件及事件处理)
java·开发语言·学习·算法
Blankspace学3 小时前
Wireshark软件下载安装及基础
网络·学习·测试工具·网络安全·wireshark
南宫生3 小时前
力扣-图论-70【算法学习day.70】
java·学习·算法·leetcode·图论
bohu834 小时前
sentinel学习笔记1-为什么需要服务降级
笔记·学习·sentinel·滑动窗口
HE10294 小时前
威尔克斯(Wilks)分布
学习
初学者7.5 小时前
Webpack学习笔记(3)
笔记·学习·webpack
璞~6 小时前
MQTT 课程概览 (学习笔记)02
笔记·学习