机器学习(李宏毅)——Domain Adaptation

一、前言

本文章作为学习2023年《李宏毅机器学习课程》的笔记,感谢台湾大学李宏毅教授的课程,respect!!!

二、大纲

  1. 什么是Domain Adaptation?
  2. Domain Adaptation 的核心问题
  3. 重点介绍DAT

三、什么是Domain Adaptation?

Domain Adaptation的字面意思就是领域自适应,啥叫自适应呢,通俗讲就是在一个领域适应得好好的,在另外一个不同领域期待自动适应。举例而言:

  • MNIST辨识

上图手写数字辨识的例子在黑白风格训练数据上训练并在黑白风格测试数据上测试99.5%正确率,但是测试数据同样是数字只是风格变成彩色背景,分类器的分类精度直接变成57.5%。

明显的同样是数字分类问题,模型适应不了新的风格,换句话说领域不适应。这种情况称为Domain shift(领域偏移),即Training and testing data have different distributions。

四、Domain Adaptation 的核心问题

Domain Shift

Domain Shift(领域偏移)指的是在机器学习模型的训练数据(源域)和测试数据(目标域)之间存在统计分布的差异,导致模型在目标域上的性能下降。

Domain Shift(领域偏移)的三种类型:

  1. Covariate Shift(协变量偏移)
  2. Label Shift(标签偏移)
  3. Concept Shift(概念偏移)

如下图:

问了下chatgpt,我觉得通俗解释得很好,直接贴一下:


五、重点介绍DAT

课程里面重点介绍了Covariate Shift(协变量偏移)情况下的解决办法,即Domain Adversarial Training(对抗训练)。

对于目标领域的样本情况可以划分为以下4种:

  • 数据虽少但有标注
  • 数据很多但未标注
  • 数据很少且未标注
  • 未知

情况1:数据虽少但有标注

解决方法:在原始数据中进行模型训练,然后在目标数据中进行微调,但要小心过拟合问题。

情况2:数据很多但未标注

解决思路:训练两个Feature Extractor(network),一个用来抽取source domain图片特征,另一个用来抽取target domain图片特征。强迫让抽取的两个特征分布越接近越好,类似于只抽取共同特征,忽略了颜色这件事。

那怎么找出这样的这个Feature Extractor(network)?

使用Domain Adversarial Training(领域对抗训练)方法,如下图:

说明:

  • step1:把分类器拆解为两部分,即绿色的Feature Extractor + 蓝色的Label Predictor,Feature Extractor用来对输入图片抽取特征,Label Predictor用来对Feature Extractor的输出向量进行数字分类,哪几层是可以作为Feature Extractor,哪几层可以作为Label Predictor,这像超参数一样是需要调的;
  • step2:有带标注的source image通过Feature Extractor得到Feature向量,再经过Label Predictor,期待它与真值越接近越好;
  • step3:无标注的target image通过Feature Extractor得到Feature向量,但它是unlabeled,没有真值,所以不能继续走Label Predictor这条路;
  • step4:将步骤2和步骤3经过Feature Extractor的特征向量展开,记为blue points和red points,我们的目标就是要让red points的分布越接近blue points越好。

那如何让red points的分布越接近blue points越好呢?

答:听起来很像是分类器的相反,分类器是想让red points和blue points分得越开loss就越低。那我们想分不开就去惩罚这个Feature Extractor就好了。直白的讲,你就对Feature Extractor说,你今天要是抽出来的特征让那个Domain Classifier(领域分类)很容易分辨出来的话,我就惩罚你。

所以,Domain Classifier的loss越低,对Feature Extractor来说就是一件需要接受惩罚的事情。

因此,我们想得到Feature Extractor的话,有两个约束条件:

  • Feature Extractor要学会数字辨识这件事,也就是Feature Extractor + Label Predictor(绿色 + 蓝色)这条路线的Loss要低;
  • Feature Extractor要学会"存同弃异"这件事,也就是Feature Extractor + Domain Classifier(绿色 + 橙色)这条路线的 -Loss 要低;

其实就是GAN的思想,使用公式表达,如下图:

以上的原理是清楚了,下面就是一些小技巧:

举例而言,其实我们是希望unlabeled的data能够离边界越远越好:

  • 如果无标签数据靠近决策边界,说明模型在这个区域的分类不确定性高,容易出错。
  • 如果无标签数据远离决策边界,说明模型对这些样本的分类信心更高,泛化性能更好。

具体可以参考以下paper:

另外一种情况是:source domain的类别和target domain的类别不一样多,怎么解决?

参考以下链接:

情况3:数据很少且未标注

参考Test Time Training(TTT)链接:

情况4:未知

这种问题称为Domain Generalization。论文链接放下面:

相关推荐
若兰幽竹7 小时前
【大模型应用】抖音爆款视频深度分析系统:流水线式AI逆向拆解流量密码,精准预测播放量!
人工智能·python·音视频·抖音爆款分析
AI技术控7 小时前
NeuroH-TGL 论文解读:面向脑疾病诊断的神经异质性引导时序图学习方法
人工智能·语言模型·自然语言处理·langchain·nlp
fuquxiaoguang7 小时前
微软Maia 200的“算力经济学”:推理时代的专用芯片如何改写游戏规则
人工智能·microsoft
心中有国也有家7 小时前
pytorch-adapter:让 PyTorch 模型“无缝”跑在昇腾 NPU 上
人工智能·pytorch·笔记·python·学习
Sharewinfo_BJ7 小时前
从手工报表到实时BI:一个零售数据平台的踩坑与重构实战
大数据·人工智能·科技·数据分析·微软·powerbi
Elastic 中国社区官方博客7 小时前
在 Elasticsearch 中,存储向量查询速度最高提升 3 倍
大数据·人工智能·elasticsearch·搜索引擎·ai·全文检索
Cosolar8 小时前
从零搭建本地 RAG 系统:LangChain + LM Studio 完整实战指南
人工智能·后端·面试
weixin_436182428 小时前
一站式 ECAD 模型 AI 查询 专业设计辅助工具
人工智能
ting94520008 小时前
Fere AI 技术深度解析:面向加密货币与预测市场的自主交易智能体架构
人工智能·架构
生成论实验室8 小时前
通用人工智能完整技术方案:一个基于字序生命模型(WOLM)认知决策层实时、安全、可交互的数字生命体
人工智能·机器人·自动驾驶·agi·安全架构