PyTorch 与 TensorFlow 中基于自定义层的 DNN 实现对比

深度学习双雄对决:PyTorch vs TensorFlow 自定义层大比拼


目录

    • [`深度学习双雄对决:PyTorch vs TensorFlow 自定义层大比拼`](#深度学习双雄对决:PyTorch vs TensorFlow 自定义层大比拼)
    • [一、TensorFlow 实现 DNN](#一、TensorFlow 实现 DNN)
      • [1. 核心逻辑](#1. 核心逻辑)
    • [二、PyTorch 实现自定义层](#二、PyTorch 实现自定义层)
      • [1. 核心逻辑](#1. 核心逻辑)
    • 三、关键差异对比
    • 四、总结

一、TensorFlow 实现 DNN

1. 核心逻辑

  • 直接继承 tf.keras.layers.Layer :无需中间类,直接在 build 中定义多层结构。
  • 动态参数管理 :通过 add_weight 注册每一层的权重和偏置。
python 复制代码
import tensorflow as tf

class CustomDNNLayer(tf.keras.layers.Layer):
    def __init__(self, hidden_units, output_dim, **kwargs):
        super(CustomDNNLayer, self).__init__(**kwargs)
        self.hidden_units = hidden_units
        self.output_dim = output_dim

    def build(self, input_shape):
        # 输入层到第一个隐藏层
        self.w1 = self.add_weight(
            name='w1', 
            shape=(input_shape[-1], self.hidden_units[0]),
            initializer='random_normal',
            trainable=True
        )
        self.b1 = self.add_weight(
            name='b1',
            shape=(self.hidden_units[0],),
            initializer='zeros',
            trainable=True
        )

        # 隐藏层之间
        self.ws = []
        self.bs = []
        for i in range(len(self.hidden_units) - 1):
            self.ws.append(self.add_weight(
                name=f'w{i+2}', 
                shape=(self.hidden_units[i], self.hidden_units[i+1]),
                initializer='random_normal',
                trainable=True
            ))
            self.bs.append(self.add_weight(
                name=f'b{i+2}',
                shape=(self.hidden_units[i+1],),
                initializer='zeros',
                trainable=True
            ))

        # 输出层
        self.wo = self.add_weight(
            name='wo',
            shape=(self.hidden_units[-1], self.output_dim),
            initializer='random_normal',
            trainable=True
        )
        self.bo = self.add_weight(
            name='bo',
            shape=(self.output_dim,),
            initializer='zeros',
            trainable=True
        )

    def call(self, inputs):
        x = tf.matmul(inputs, self.w1) + self.b1
        x = tf.nn.relu(x)

        for i in range(len(self.hidden_units) - 1):
            x = tf.matmul(x, self.ws[i]) + self.bs[i]
            x = tf.nn.relu(x)

        x = tf.matmul(x, self.wo) + self.bo
        return x

二、PyTorch 实现自定义层

1. 核心逻辑

  • 继承 nn.Module:自定义层本质是模块的组合。
  • 使用 nn.ModuleList :动态管理多个 nn.Linear 层。
python 复制代码
import torch
import torch.nn as nn

class CustomPyTorchDNN(nn.Module):
    def __init__(self, input_size, hidden_sizes, output_size):
        super(CustomPyTorchDNN, self).__init__()
        self.hidden_layers = nn.ModuleList()
        prev_size = input_size

        # 动态添加隐藏层
        for hidden_size in hidden_sizes:
            self.hidden_layers.append(nn.Linear(prev_size, hidden_size))
            prev_size = hidden_size

        # 输出层
        self.output_layer = nn.Linear(prev_size, output_size)

    def forward(self, x):
        for layer in self.hidden_layers:
            x = torch.relu(layer(x))
        x = self.output_layer(x)
        return x

三、关键差异对比

维度 TensorFlow 实现 PyTorch 实现
类继承方式 直接继承 tf.keras.layers.Layer,无中间类。 继承 nn.Module,通过 nn.ModuleList 管理子模块。
参数管理 build 中显式注册每层权重(add_weight)。 自动注册所有 nn.Linear 参数(无需手动操作)。
前向传播定义 通过 call 方法逐层计算,需手动处理每层的权重和激活函数。 通过 forward 方法逐层调用 nn.Linear,激活函数手动插入。
灵活性 更底层,适合完全自定义逻辑(如非线性变换、特殊参数初始化)。 更简洁,适合快速构建标准网络结构。
训练流程 需手动实现训练循环(反向传播 + 优化器)。 需手动实现训练循环(与 TensorFlow 类似)。

四、总结

  • TensorFlow :通过直接继承 tf.keras.layers.Layer,可实现完全自定义的 DNN,但需手动管理多层权重和激活逻辑,适合对模型细节有严格控制需求的场景。
  • PyTorch :通过直接继承 nn.Module,可实现完全自定义的 DNN;利用 nn.ModuleListnn.Linear 的组合,能高效构建标准 DNN 结构,代码简洁且易于扩展,适合快速原型开发和研究场景。

两种实现均满足用户对"直接继承核心类 + 使用基础组件"的要求,可根据具体任务选择框架。

相关推荐
C嘎嘎嵌入式开发11 小时前
(一) 机器学习之深度神经网络
人工智能·神经网络·dnn
盼小辉丶13 小时前
TensorFlow深度学习实战(39)——机器学习实践指南
深度学习·机器学习·tensorflow
蒋星熠13 小时前
反爬虫机制深度解析:从基础防御到高级对抗的完整技术实战
人工智能·pytorch·爬虫·python·深度学习·机器学习·计算机视觉
it技术16 小时前
Pytorch项目实战 :基于RNN的实现情感分析
pytorch·后端
java1234_小锋21 小时前
TensorFlow2 Python深度学习 - TensorFlow2框架入门 - 变量(Variable)的定义与操作
python·深度学习·tensorflow·tensorflow2
mooooon L1 天前
DAY 43 复习日-2025.10.7
人工智能·pytorch·python·深度学习·神经网络
ting_zh1 天前
PyTorch、TensorFlow、JAX 简介
人工智能·pytorch·tensorflow
java1234_小锋1 天前
TensorFlow2 Python深度学习 - 深度学习概述
python·深度学习·tensorflow·tensorflow2·python深度学习
通往曙光的路上2 天前
国庆回来的css
人工智能·python·tensorflow
wa的一声哭了2 天前
Stanford CS336 assignment1 | Transformer Language Model Architecture
人工智能·pytorch·python·深度学习·神经网络·语言模型·transformer