吴恩达深度学习复盘(14)迁移学习|项目基本周期

迁移学习

迁移学习是一种机器学习技术,它允许我们将从一个任务中学习到的知识应用到另一个相关的任务中。其核心思想在于,很多情况下,从头开始训练一个模型需要大量的数据和计算资源,而迁移学习能够复用在已有数据上训练好的模型的部分或全部,从而减少新任务的训练成本,加快模型收敛速度,提升模型在新任务上的性能,尤其是当新任务的数据量有限时,迁移学习的优势更为明显。

迁移学习的常见场景和方法

  • 特征提取:使用预训练模型作为特征提取器,去除原模型的最后几层(通常是全连接层),保留前面的卷积层或特征提取层。将新数据输入到这些层中,提取特征,然后将这些特征输入到一个新的简单模型(如全连接层)中进行训练。
  • 微调:在特征提取的基础上,不仅使用预训练模型的特征,还对预训练模型的部分或全部参数进行微调。通常会冻结预训练模型的前几层(因为这些层学习到的是通用特征),只对后面的几层进行训练。
  • 多任务学习:同时在多个相关任务上训练模型,使得模型能够学习到不同任务之间的共同特征和模式。

简单代码示例

下面是一个使用 Python 和 Keras 库进行迁移学习的简单例子,使用预训练的 VGG16 模型(也可以从hf上找一个其它模型代替)对猫狗图像进行分类。

python 复制代码
import os
import numpy as np
from tensorflow.keras.applications.vgg16 import VGG16
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.layers import Dense, GlobalAveragePooling2D
from tensorflow.keras.models import Model
from tensorflow.keras.optimizers import Adam

# 数据路径
train_data_dir = 'path/to/train_data'
validation_data_dir = 'path/to/validation_data'

# 图像尺寸
img_width, img_height = 224, 224
batch_size = 32
epochs = 10

# 加载预训练的VGG16模型,不包含顶部的全连接层
base_model = VGG16(weights='imagenet', include_top=False, input_shape=(img_width, img_height, 3))

# 冻结预训练模型的所有层
for layer in base_model.layers:
    layer.trainable = False

# 添加自定义层
x = base_model.output
x = GlobalAveragePooling2D()(x)
x = Dense(128, activation='relu')(x)
predictions = Dense(1, activation='sigmoid')(x)

# 构建新的模型
model = Model(inputs=base_model.input, outputs=predictions)

# 编译模型
model.compile(optimizer=Adam(lr=0.001), loss='binary_crossentropy', metrics=['accuracy'])

# 数据增强
train_datagen = ImageDataGenerator(
    rescale=1. / 255,
    shear_range=0.2,
    zoom_range=0.2,
    horizontal_flip=True)

test_datagen = ImageDataGenerator(rescale=1. / 255)

# 加载训练数据和验证数据
train_generator = train_datagen.flow_from_directory(
    train_data_dir,
    target_size=(img_width, img_height),
    batch_size=batch_size,
    class_mode='binary')

validation_generator = test_datagen.flow_from_directory(
    validation_data_dir,
    target_size=(img_width, img_height),
    batch_size=batch_size,
    class_mode='binary')

# 训练模型
model.fit(
    train_generator,
    steps_per_epoch=train_generator.samples // batch_size,
    epochs=epochs,
    validation_data=validation_generator,
    validation_steps=validation_generator.samples // batch_size)
    

项目周期

到目前为止,已经讨论了很多关于如何训练模型以及如何为机器学习应用程序处理数据的问题。但在构建机器学习系统时,训练模型只是其中一部分。这一小节会总结机器学习项目的完整周期。

项目执行步骤

  • 项目范围界定:机器学习项目的第一步是确定项目内容,即决定要做什么。例如,决定研究用于语音搜索的语音识别,也就是通过对手机说话进行网络搜索,而非打字。
  • 数据收集:确定项目后,要决定需要哪些数据来训练机器学习系统,并着手获取音频数据以及相应标签的文本,这就是数据收集过程。
  • 模型训练与分析:收集初始数据后,开始训练模型,如训练语音识别系统并进行相关分析以改进模型。在训练模型并进行分析后,常常会发现可能需要返回去收集更多数据,可能是更多所有类型的数据,也可能是特定类型的数据,以提高学习算法的性能。比如,曾发现语音系统在有汽车噪音的背景下表现不佳,于是决定通过数据增强获取更多类似有车背景噪音的语音数据来提升算法性能。
  • 多次循环与部署:需要多次重复训练模型、进行分析以及收集更多数据这个循环过程,直到认为模型足够好,可以部署到生产环境中供用户使用。
  • 系统部署与维护:部署系统后,要持续监控系统性能,当性能变差时进行维护以恢复性能。有时部署后会发现模型效果不如预期,就需要回去重新训练模型或获取更多数据。而且,如果有权使用生产部署中的数据,这些数据可能会为进一步改进系统提供更多资源,有助于持续提高系统性能。

模型部署的

  • 以语音识别模型为例,常见的部署方式是将机器学习模型实现在服务器(推理服务器)上,其作用是调用训练好的模型进行预测。
  • 如果团队实现了一个移动应用程序,当用户与移动应用程序交谈时,移动应用程序可以通过 API 调用将录制的音频片段传递给推理服务器,推理服务器再将模型的预测结果,返回给移动应用程序。
  • 这种实现方式需要一定的软件工程来编写相关代码,根据应用程序服务用户数量的不同(从少数用户到数百万用户),所需的软件工作量和资源也会有很大差异。为大量用户服务可能需要使用特定技术来优化,管理计算成本和确保服务器可靠、高效地进行预测,同时通常需要记录输入数据 X 和预测结果,在用户隐私和同意的前提下,这些数据对系统监控非常有用。例如,曾建立的语音识别系统在遇到新名人或新政治家名字不在训练集里时表现不佳,通过监控系统发现数据变化和模型准确性下降,从而能够重新训练模型并进行更新。

笔者注:这篇其实讲的是机器学习的一个领域,叫做 MLOps,它涉及如何系统地构建、部署和维护机器学习系统,以确保模型可靠、性能良好、受到监控,并能适时进行更新以保持良好运行。当系统要部署给数百万人时,需要确保实现高度优化,以降低服务成本。软件要部署,归根到底还是必须考虑成本问题,这种优化是无止尽的。

相关推荐
DragonnAi32 分钟前
基于项目管理的轻量级目标检测自动标注系统【基于 YOLOV8】
人工智能·yolo·目标检测
AI绘画咪酱1 小时前
【CSDN首发】Stable Diffusion从零到精通学习路线分享
人工智能·学习·macos·ai作画·stable diffusion·aigc
DeepSeek+NAS1 小时前
耘想WinNAS:以聊天交互重构NAS生态,开启AI时代的存储革命
人工智能·重构·nas·winnas·安卓nas·windows nas
2201_754918411 小时前
OpenCv--换脸
人工智能·opencv·计算机视觉
ocr_sinosecu11 小时前
OCR进化史:从传统到深度学习,解锁文字识别新境界
人工智能·深度学习·ocr
Stara05111 小时前
YOLO11改进——融合BAM注意力机制增强图像分类与目标检测能力
人工智能·python·深度学习·目标检测·计算机视觉·yolov11
movigo7_dou1 小时前
关于深度学习局部视野与全局视野的一些思考
人工智能·深度学习
itwangyang5202 小时前
AIDD-人工智能药物设计-大语言模型在医学领域的革命性应用
人工智能·语言模型·自然语言处理
热心网友俣先生2 小时前
2025年泰迪杯数据挖掘竞赛B题论文首发+问题一二三四代码分享
人工智能·数据挖掘