PyTorch 与 TensorFlow 中基于自定义层的 DNN 实现对比

深度学习双雄对决:PyTorch vs TensorFlow 自定义层大比拼


目录

    • [`深度学习双雄对决:PyTorch vs TensorFlow 自定义层大比拼`](#深度学习双雄对决:PyTorch vs TensorFlow 自定义层大比拼)
    • [一、TensorFlow 实现 DNN](#一、TensorFlow 实现 DNN)
      • [1. 核心逻辑](#1. 核心逻辑)
    • [二、PyTorch 实现自定义层](#二、PyTorch 实现自定义层)
      • [1. 核心逻辑](#1. 核心逻辑)
    • 三、关键差异对比
    • 四、总结

一、TensorFlow 实现 DNN

1. 核心逻辑

  • 直接继承 tf.keras.layers.Layer :无需中间类,直接在 build 中定义多层结构。
  • 动态参数管理 :通过 add_weight 注册每一层的权重和偏置。
python 复制代码
import tensorflow as tf

class CustomDNNLayer(tf.keras.layers.Layer):
    def __init__(self, hidden_units, output_dim, **kwargs):
        super(CustomDNNLayer, self).__init__(**kwargs)
        self.hidden_units = hidden_units
        self.output_dim = output_dim

    def build(self, input_shape):
        # 输入层到第一个隐藏层
        self.w1 = self.add_weight(
            name='w1', 
            shape=(input_shape[-1], self.hidden_units[0]),
            initializer='random_normal',
            trainable=True
        )
        self.b1 = self.add_weight(
            name='b1',
            shape=(self.hidden_units[0],),
            initializer='zeros',
            trainable=True
        )

        # 隐藏层之间
        self.ws = []
        self.bs = []
        for i in range(len(self.hidden_units) - 1):
            self.ws.append(self.add_weight(
                name=f'w{i+2}', 
                shape=(self.hidden_units[i], self.hidden_units[i+1]),
                initializer='random_normal',
                trainable=True
            ))
            self.bs.append(self.add_weight(
                name=f'b{i+2}',
                shape=(self.hidden_units[i+1],),
                initializer='zeros',
                trainable=True
            ))

        # 输出层
        self.wo = self.add_weight(
            name='wo',
            shape=(self.hidden_units[-1], self.output_dim),
            initializer='random_normal',
            trainable=True
        )
        self.bo = self.add_weight(
            name='bo',
            shape=(self.output_dim,),
            initializer='zeros',
            trainable=True
        )

    def call(self, inputs):
        x = tf.matmul(inputs, self.w1) + self.b1
        x = tf.nn.relu(x)

        for i in range(len(self.hidden_units) - 1):
            x = tf.matmul(x, self.ws[i]) + self.bs[i]
            x = tf.nn.relu(x)

        x = tf.matmul(x, self.wo) + self.bo
        return x

二、PyTorch 实现自定义层

1. 核心逻辑

  • 继承 nn.Module:自定义层本质是模块的组合。
  • 使用 nn.ModuleList :动态管理多个 nn.Linear 层。
python 复制代码
import torch
import torch.nn as nn

class CustomPyTorchDNN(nn.Module):
    def __init__(self, input_size, hidden_sizes, output_size):
        super(CustomPyTorchDNN, self).__init__()
        self.hidden_layers = nn.ModuleList()
        prev_size = input_size

        # 动态添加隐藏层
        for hidden_size in hidden_sizes:
            self.hidden_layers.append(nn.Linear(prev_size, hidden_size))
            prev_size = hidden_size

        # 输出层
        self.output_layer = nn.Linear(prev_size, output_size)

    def forward(self, x):
        for layer in self.hidden_layers:
            x = torch.relu(layer(x))
        x = self.output_layer(x)
        return x

三、关键差异对比

维度 TensorFlow 实现 PyTorch 实现
类继承方式 直接继承 tf.keras.layers.Layer,无中间类。 继承 nn.Module,通过 nn.ModuleList 管理子模块。
参数管理 build 中显式注册每层权重(add_weight)。 自动注册所有 nn.Linear 参数(无需手动操作)。
前向传播定义 通过 call 方法逐层计算,需手动处理每层的权重和激活函数。 通过 forward 方法逐层调用 nn.Linear,激活函数手动插入。
灵活性 更底层,适合完全自定义逻辑(如非线性变换、特殊参数初始化)。 更简洁,适合快速构建标准网络结构。
训练流程 需手动实现训练循环(反向传播 + 优化器)。 需手动实现训练循环(与 TensorFlow 类似)。

四、总结

  • TensorFlow :通过直接继承 tf.keras.layers.Layer,可实现完全自定义的 DNN,但需手动管理多层权重和激活逻辑,适合对模型细节有严格控制需求的场景。
  • PyTorch :通过直接继承 nn.Module,可实现完全自定义的 DNN;利用 nn.ModuleListnn.Linear 的组合,能高效构建标准 DNN 结构,代码简洁且易于扩展,适合快速原型开发和研究场景。

两种实现均满足用户对"直接继承核心类 + 使用基础组件"的要求,可根据具体任务选择框架。

相关推荐
2501_915374352 小时前
深入理解 TensorFlow 的模型保存与加载机制(SavedModel vs H5)
人工智能·tensorflow
一点.点3 小时前
PyTorch常用命令(可快速上手PyTorch的核心功能,涵盖从数据预处理到模型训练的全流程)
人工智能·pytorch·深度学习
贝塔西塔3 小时前
时间序列数据集构建方案Pytorch
人工智能·pytorch·深度学习
试着15 小时前
【AI面试准备】TensorFlow与PyTorch构建缺陷预测模型
人工智能·pytorch·面试·tensorflow·测试
郜太素17 小时前
PyTorch 张量与自动微分操作
人工智能·pytorch·python·深度学习·学习方法·张量·自动微分
sheng_er_sheng17 小时前
【笔记】【B站课程 pytorch】梯度下降模型
人工智能·pytorch·笔记
DevangLic17 小时前
【CUDA pytorch】
人工智能·pytorch·python
QQ6765800817 小时前
PyTorch和torchvision为例,如何使用预训练的ResNet模型来训练水稻虫害分类数据集 14类 从数据准备到模型训练、评估全流程
人工智能·pytorch·分类
令狐少侠201119 小时前
PaddlePaddle 和PyTorch选择与对比互斥
人工智能·pytorch·paddlepaddle