PyTorch 与 TensorFlow 中基于自定义层的 DNN 实现对比

深度学习双雄对决:PyTorch vs TensorFlow 自定义层大比拼


目录

    • [`深度学习双雄对决:PyTorch vs TensorFlow 自定义层大比拼`](#深度学习双雄对决:PyTorch vs TensorFlow 自定义层大比拼)
    • [一、TensorFlow 实现 DNN](#一、TensorFlow 实现 DNN)
      • [1. 核心逻辑](#1. 核心逻辑)
    • [二、PyTorch 实现自定义层](#二、PyTorch 实现自定义层)
      • [1. 核心逻辑](#1. 核心逻辑)
    • 三、关键差异对比
    • 四、总结

一、TensorFlow 实现 DNN

1. 核心逻辑

  • 直接继承 tf.keras.layers.Layer :无需中间类,直接在 build 中定义多层结构。
  • 动态参数管理 :通过 add_weight 注册每一层的权重和偏置。
python 复制代码
import tensorflow as tf

class CustomDNNLayer(tf.keras.layers.Layer):
    def __init__(self, hidden_units, output_dim, **kwargs):
        super(CustomDNNLayer, self).__init__(**kwargs)
        self.hidden_units = hidden_units
        self.output_dim = output_dim

    def build(self, input_shape):
        # 输入层到第一个隐藏层
        self.w1 = self.add_weight(
            name='w1', 
            shape=(input_shape[-1], self.hidden_units[0]),
            initializer='random_normal',
            trainable=True
        )
        self.b1 = self.add_weight(
            name='b1',
            shape=(self.hidden_units[0],),
            initializer='zeros',
            trainable=True
        )

        # 隐藏层之间
        self.ws = []
        self.bs = []
        for i in range(len(self.hidden_units) - 1):
            self.ws.append(self.add_weight(
                name=f'w{i+2}', 
                shape=(self.hidden_units[i], self.hidden_units[i+1]),
                initializer='random_normal',
                trainable=True
            ))
            self.bs.append(self.add_weight(
                name=f'b{i+2}',
                shape=(self.hidden_units[i+1],),
                initializer='zeros',
                trainable=True
            ))

        # 输出层
        self.wo = self.add_weight(
            name='wo',
            shape=(self.hidden_units[-1], self.output_dim),
            initializer='random_normal',
            trainable=True
        )
        self.bo = self.add_weight(
            name='bo',
            shape=(self.output_dim,),
            initializer='zeros',
            trainable=True
        )

    def call(self, inputs):
        x = tf.matmul(inputs, self.w1) + self.b1
        x = tf.nn.relu(x)

        for i in range(len(self.hidden_units) - 1):
            x = tf.matmul(x, self.ws[i]) + self.bs[i]
            x = tf.nn.relu(x)

        x = tf.matmul(x, self.wo) + self.bo
        return x

二、PyTorch 实现自定义层

1. 核心逻辑

  • 继承 nn.Module:自定义层本质是模块的组合。
  • 使用 nn.ModuleList :动态管理多个 nn.Linear 层。
python 复制代码
import torch
import torch.nn as nn

class CustomPyTorchDNN(nn.Module):
    def __init__(self, input_size, hidden_sizes, output_size):
        super(CustomPyTorchDNN, self).__init__()
        self.hidden_layers = nn.ModuleList()
        prev_size = input_size

        # 动态添加隐藏层
        for hidden_size in hidden_sizes:
            self.hidden_layers.append(nn.Linear(prev_size, hidden_size))
            prev_size = hidden_size

        # 输出层
        self.output_layer = nn.Linear(prev_size, output_size)

    def forward(self, x):
        for layer in self.hidden_layers:
            x = torch.relu(layer(x))
        x = self.output_layer(x)
        return x

三、关键差异对比

维度 TensorFlow 实现 PyTorch 实现
类继承方式 直接继承 tf.keras.layers.Layer,无中间类。 继承 nn.Module,通过 nn.ModuleList 管理子模块。
参数管理 build 中显式注册每层权重(add_weight)。 自动注册所有 nn.Linear 参数(无需手动操作)。
前向传播定义 通过 call 方法逐层计算,需手动处理每层的权重和激活函数。 通过 forward 方法逐层调用 nn.Linear,激活函数手动插入。
灵活性 更底层,适合完全自定义逻辑(如非线性变换、特殊参数初始化)。 更简洁,适合快速构建标准网络结构。
训练流程 需手动实现训练循环(反向传播 + 优化器)。 需手动实现训练循环(与 TensorFlow 类似)。

四、总结

  • TensorFlow :通过直接继承 tf.keras.layers.Layer,可实现完全自定义的 DNN,但需手动管理多层权重和激活逻辑,适合对模型细节有严格控制需求的场景。
  • PyTorch :通过直接继承 nn.Module,可实现完全自定义的 DNN;利用 nn.ModuleListnn.Linear 的组合,能高效构建标准 DNN 结构,代码简洁且易于扩展,适合快速原型开发和研究场景。

两种实现均满足用户对"直接继承核心类 + 使用基础组件"的要求,可根据具体任务选择框架。

相关推荐
盼小辉丶4 小时前
PyTorch实战——基于生成对抗网络生成服饰图像
pytorch·深度学习·生成对抗网络
西猫雷婶5 小时前
深度学习|pytorch基本运算-hadamard积、点积和矩阵乘法
pytorch·深度学习·矩阵
Steve lu13 小时前
回归任务损失函数对比曲线
人工智能·pytorch·深度学习·神经网络·算法·回归·原力计划
CC_IsMe18 小时前
Linux服务器 TensorFlow找不到GPU
linux·jupyter·ssh·conda·tensorflow
YYXZZ。。20 小时前
PyTorch ——torchvision数据集使用
人工智能·pytorch·python
IOsetting20 小时前
3D Gaussian splatting 05: 代码阅读-训练整体流程
pytorch·3d gaussian
KeepThinking!21 小时前
BLIP-2
人工智能·pytorch·深度学习·计算机视觉·blip2
永恒的溪流1 天前
无法运用pytorch环境、改环境路径、隔离环境
人工智能·pytorch·python·深度学习
西猫雷婶1 天前
深度学习|pytorch基本运算-广播失效
人工智能·pytorch·深度学习
Ronin-Lotus1 天前
深度学习篇---Pytorch框架下OC-SORT实现
人工智能·pytorch·python·深度学习·oc-sort