tensorflow Keras Sequential 模型

深度解析 Keras Sequential 模型

你提供的内容是 Keras 中 Sequential 模型的核心说明,下面从核心概念、代码解析、适用/不适用场景、扩展用法 四个维度展开讲解,帮你彻底理解这个基础且常用的模型结构。


1. 核心概念:Sequential 模型的本质

Sequential(序贯模型)是 Keras 中最基础的模型类型,本质是线性堆叠的层容器 ------层与层之间严格按顺序连接,每层只有一个输入张量一个输出张量,整个模型的数据流是单向的"一条线"。

这也是它和 Functional API(函数式 API)的核心区别:Sequential 只支持"线性拓扑",而 Functional API 支持任意拓扑(多输入/输出、残差连接等)。


2. 示例代码逐行解析

(1)基础导入

python 复制代码
import tensorflow as tf
import keras
from keras import layers
  • TensorFlow 2.x 中 Keras 已集成到 tf.keras,也可以直接写 from tensorflow import keras
  • layers 模块包含所有 Keras 层(如 Dense 全连接层、Conv2D 卷积层等)。

(2)Sequential 模型定义 & 等价手动调用

python 复制代码
# 方式1:Sequential 封装
model = keras.Sequential(
    [
        layers.Dense(2, activation="relu", name="layer1"),  # 全连接层,输出维度2,ReLU激活
        layers.Dense(3, activation="relu", name="layer2"),  # 输出维度3,ReLU激活
        layers.Dense(4, name="layer3"),                      # 输出维度4,无激活(默认线性激活)
    ]
)

# 测试输入:3个样本,每个样本3个特征(形状 (3,3))
x = tf.ones((3, 3))
y = model(x)  # 前向传播

# 方式2:手动逐层调用(和Sequential等价)
layer1 = layers.Dense(2, activation="relu", name="layer1")
layer2 = layers.Dense(3, activation="relu", name="layer2")
layer3 = layers.Dense(4, name="layer3")

y_manual = layer3(layer2(layer1(x)))  # 逐层传递输入
关键细节:
  • 输入 x = tf.ones((3,3))(样本数, 特征数),这里 3 个样本,每个样本 3 维特征;
  • 层的维度匹配:
    • layer1 输入维度=3 → 输出维度=2(形状 (3,2));
    • layer2 输入维度=2 → 输出维度=3(形状 (3,3));
    • layer3 输入维度=3 → 输出维度=4(形状 (3,4));
  • 最终 yy_manual 的形状都是 (3,4),且值完全一致(因为 Sequential 只是"封装了逐层调用");
  • name 参数:给层命名,方便后续查看模型结构、调试或提取层输出。

3. Sequential 模型的适用/不适用场景

✅ 适用场景

  • 简单的"层堆叠"任务:比如 MNIST 手写数字分类(输入→卷积→池化→全连接→输出)、简单的回归任务(如房价预测);
  • 快速原型开发:代码简洁,无需手动定义输入输出的连接关系。

❌ 不适用场景(必须用 Functional API)

(1)多输入/多输出模型

比如:同时输入"文本+图片"预测一个标签,或一个模型同时输出"分类结果+置信度"。

(2)层有多个输入/输出

比如:融合两个不同分支的特征(如 ResNet 的残差连接,需要把"主分支"和"捷径分支"的输出拼接)。

(3)层共享

比如:多个输入共享同一个 Embedding 层(如问答系统中,问题和答案共享词嵌入层)。

(4)非线性拓扑

比如:残差连接(ResNet)、多分支模型(Inception)、循环依赖等。


4. Sequential 模型的扩展用法

(1)用 add() 动态添加层

除了初始化时传入层列表,也可以空初始化后逐个添加:

python 复制代码
model = keras.Sequential()
model.add(layers.Dense(2, activation="relu", name="layer1", input_shape=(3,)))  # 显式指定输入维度
model.add(layers.Dense(3, activation="relu", name="layer2"))
model.add(layers.Dense(4, name="layer3"))
  • input_shape=(3,):显式指定输入维度(样本维度省略),适合提前查看模型结构(否则需先调用模型才能生成 summary)。

(2)查看模型结构

python 复制代码
model.summary()

输出如下(清晰展示每层的名称、输出形状、参数数量):

复制代码
Model: "sequential"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
=================================================================
 layer1 (Dense)              (None, 2)                 8         # 3*2 + 2(偏置) = 8
                                                                 
 layer2 (Dense)              (None, 3)                 9         # 2*3 + 3 = 9
                                                                 
 layer3 (Dense)              (None, 4)                 16        # 3*4 + 4 = 16
                                                                 
=================================================================
Total params: 33
Trainable params: 33
Non-trainable params: 0
_________________________________________________________________

(3)编译 & 训练(完整流程)

Sequential 模型的编译/训练和其他 Keras 模型一致:

python 复制代码
# 编译:指定优化器、损失函数、评估指标
model.compile(
    optimizer=keras.optimizers.Adam(learning_rate=0.001),
    loss=keras.losses.MeanSquaredError(),  # 回归任务用MSE
    metrics=["mae"]  # 平均绝对误差
)

# 构造模拟数据
X_train = tf.random.normal((1000, 3))  # 1000个样本,3维特征
y_train = tf.random.normal((1000, 4))  # 1000个标签,4维输出

# 训练
history = model.fit(
    X_train, y_train,
    batch_size=32,
    epochs=10,
    validation_split=0.2  # 20%数据作为验证集
)

5. 总结

  • Sequential 是 Keras 中最简单的模型,适合线性堆叠的层结构
  • 核心优势:代码简洁、易上手,适合入门和简单任务;
  • 核心局限:不支持复杂拓扑,遇到多输入/输出、残差连接等场景,需切换到 Functional API(函数式 API);
  • 本质:Sequential 只是"语法糖",等价于手动逐层调用层,没有额外的功能扩展。

如果需要进一步了解 Functional API 对比 Sequential 的示例,可以告诉我,我会补充相关代码!

相关推荐
king王一帅31 分钟前
Incremark Solid 版本上线:Vue/React/Svelte/Solid 四大框架,统一体验
前端·javascript·人工智能
泰迪智能科技3 小时前
分享|职业技术培训|数字技术应用工程师快问快答
人工智能
Dxy12393102165 小时前
如何给AI提问:让机器高效理解你的需求
人工智能
少林码僧5 小时前
2.31 机器学习神器项目实战:如何在真实项目中应用XGBoost等算法
人工智能·python·算法·机器学习·ai·数据挖掘
钱彬 (Qian Bin)5 小时前
项目实践15—全球证件智能识别系统(切换为Qwen3-VL-8B-Instruct图文多模态大模型)
人工智能·算法·机器学习·多模态·全球证件识别
没学上了5 小时前
CNNMNIST
人工智能·深度学习
宝贝儿好5 小时前
【强化学习】第六章:无模型控制:在轨MC控制、在轨时序差分学习(Sarsa)、离轨学习(Q-learning)
人工智能·python·深度学习·学习·机器学习·机器人
智驱力人工智能6 小时前
守护流动的规则 基于视觉分析的穿越导流线区检测技术工程实践 交通路口导流区穿越实时预警技术 智慧交通部署指南
人工智能·opencv·安全·目标检测·计算机视觉·cnn·边缘计算
AI产品备案6 小时前
生成式人工智能大模型备案制度与发展要求
人工智能·深度学习·大模型备案·算法备案·大模型登记
AC赳赳老秦6 小时前
DeepSeek 私有化部署避坑指南:敏感数据本地化处理与合规性检测详解
大数据·开发语言·数据库·人工智能·自动化·php·deepseek