PyTorch 与 TensorFlow 中基于自定义层的 DNN 实现对比

深度学习双雄对决:PyTorch vs TensorFlow 自定义层大比拼


目录

    • [`深度学习双雄对决:PyTorch vs TensorFlow 自定义层大比拼`](#深度学习双雄对决:PyTorch vs TensorFlow 自定义层大比拼)
    • [一、TensorFlow 实现 DNN](#一、TensorFlow 实现 DNN)
      • [1. 核心逻辑](#1. 核心逻辑)
    • [二、PyTorch 实现自定义层](#二、PyTorch 实现自定义层)
      • [1. 核心逻辑](#1. 核心逻辑)
    • 三、关键差异对比
    • 四、总结

一、TensorFlow 实现 DNN

1. 核心逻辑

  • 直接继承 tf.keras.layers.Layer :无需中间类,直接在 build 中定义多层结构。
  • 动态参数管理 :通过 add_weight 注册每一层的权重和偏置。
python 复制代码
import tensorflow as tf

class CustomDNNLayer(tf.keras.layers.Layer):
    def __init__(self, hidden_units, output_dim, **kwargs):
        super(CustomDNNLayer, self).__init__(**kwargs)
        self.hidden_units = hidden_units
        self.output_dim = output_dim

    def build(self, input_shape):
        # 输入层到第一个隐藏层
        self.w1 = self.add_weight(
            name='w1', 
            shape=(input_shape[-1], self.hidden_units[0]),
            initializer='random_normal',
            trainable=True
        )
        self.b1 = self.add_weight(
            name='b1',
            shape=(self.hidden_units[0],),
            initializer='zeros',
            trainable=True
        )

        # 隐藏层之间
        self.ws = []
        self.bs = []
        for i in range(len(self.hidden_units) - 1):
            self.ws.append(self.add_weight(
                name=f'w{i+2}', 
                shape=(self.hidden_units[i], self.hidden_units[i+1]),
                initializer='random_normal',
                trainable=True
            ))
            self.bs.append(self.add_weight(
                name=f'b{i+2}',
                shape=(self.hidden_units[i+1],),
                initializer='zeros',
                trainable=True
            ))

        # 输出层
        self.wo = self.add_weight(
            name='wo',
            shape=(self.hidden_units[-1], self.output_dim),
            initializer='random_normal',
            trainable=True
        )
        self.bo = self.add_weight(
            name='bo',
            shape=(self.output_dim,),
            initializer='zeros',
            trainable=True
        )

    def call(self, inputs):
        x = tf.matmul(inputs, self.w1) + self.b1
        x = tf.nn.relu(x)

        for i in range(len(self.hidden_units) - 1):
            x = tf.matmul(x, self.ws[i]) + self.bs[i]
            x = tf.nn.relu(x)

        x = tf.matmul(x, self.wo) + self.bo
        return x

二、PyTorch 实现自定义层

1. 核心逻辑

  • 继承 nn.Module:自定义层本质是模块的组合。
  • 使用 nn.ModuleList :动态管理多个 nn.Linear 层。
python 复制代码
import torch
import torch.nn as nn

class CustomPyTorchDNN(nn.Module):
    def __init__(self, input_size, hidden_sizes, output_size):
        super(CustomPyTorchDNN, self).__init__()
        self.hidden_layers = nn.ModuleList()
        prev_size = input_size

        # 动态添加隐藏层
        for hidden_size in hidden_sizes:
            self.hidden_layers.append(nn.Linear(prev_size, hidden_size))
            prev_size = hidden_size

        # 输出层
        self.output_layer = nn.Linear(prev_size, output_size)

    def forward(self, x):
        for layer in self.hidden_layers:
            x = torch.relu(layer(x))
        x = self.output_layer(x)
        return x

三、关键差异对比

维度 TensorFlow 实现 PyTorch 实现
类继承方式 直接继承 tf.keras.layers.Layer,无中间类。 继承 nn.Module,通过 nn.ModuleList 管理子模块。
参数管理 build 中显式注册每层权重(add_weight)。 自动注册所有 nn.Linear 参数(无需手动操作)。
前向传播定义 通过 call 方法逐层计算,需手动处理每层的权重和激活函数。 通过 forward 方法逐层调用 nn.Linear,激活函数手动插入。
灵活性 更底层,适合完全自定义逻辑(如非线性变换、特殊参数初始化)。 更简洁,适合快速构建标准网络结构。
训练流程 需手动实现训练循环(反向传播 + 优化器)。 需手动实现训练循环(与 TensorFlow 类似)。

四、总结

  • TensorFlow :通过直接继承 tf.keras.layers.Layer,可实现完全自定义的 DNN,但需手动管理多层权重和激活逻辑,适合对模型细节有严格控制需求的场景。
  • PyTorch :通过直接继承 nn.Module,可实现完全自定义的 DNN;利用 nn.ModuleListnn.Linear 的组合,能高效构建标准 DNN 结构,代码简洁且易于扩展,适合快速原型开发和研究场景。

两种实现均满足用户对"直接继承核心类 + 使用基础组件"的要求,可根据具体任务选择框架。

相关推荐
weiwei228442 天前
神经网络模型导出及开放标准格式ONNX
pytorch·onnx
程序猿追11 天前
那个右下角的小数字怎么“卡”住我打字——我用 HarmonyOS 自己写了一个字数限制输入框
pytorch·华为·harmonyos
闵孚龙11 天前
《PyTorch 深度修炼》Dataset 和 DataLoader:数据如何喂给模型
人工智能·pytorch·python
bryant_meng11 天前
【VAE】From Pixels to Faces: Building a VAE from Scratch
pytorch·vae·log-sigma2·重参数
装不满的克莱因瓶11 天前
了解多标签图像分类方法——从Sigmoid输出到真实世界复杂视觉理解
人工智能·pytorch·python·深度学习·机器学习·分类·数据挖掘
冷小鱼11 天前
TensorFlow 2.21 进阶实战:从训练优化到生产部署的完整指南
人工智能·pytorch·python·tensorflow
冷小鱼11 天前
PyTorch 2.12 完全指南:从动态图到编译优化的深度学习框架演进
人工智能·pytorch·深度学习
IRevers11 天前
【大模型】Gemma4在ROCm和vLLM部署
人工智能·pytorch·深度学习·大模型·datawhale·vllm·amdev
盼小辉丶11 天前
PyTorch强化学习实战(14)——优先经验回放机制
pytorch·python·深度学习·强化学习
装不满的克莱因瓶11 天前
【工业领域】了解目标检测评估指标——从mAP到IoU的完整评价体系解析
人工智能·pytorch·python·深度学习·目标检测·计算机视觉·目标跟踪