“Hello 神经网络!”

神经网络搭建和参数计算

在pytorch中定义深度神经网络其实就是层堆叠的过程,继承自nn.Module,实现两个方法:

  • _init_方法中定义网络中的层结构,主要是全连接层,并进行初始化
  • forward方法,在实例化模型的时候,底层会自动调用该函数。该函数中为初始化定义的
    layer传入数据,进行前向传播等。

code

python 复制代码
'''
神经网络搭建
步骤:准备数据-搭建神经网络-模型训练-模型测试

搭建神经网络步骤:定义一个类,继承nn.Module
               在__init__方法中定义网络的层
               在forward方法中定义网络的前向传播
'''
from cgi import print_arguments
import torch
import torch.nn as nn
from torchsummary import summary
import sys
sys.stdout.reconfigure(encoding='utf-8')  # Python 3.7+ 支持
#todo:搭建神经网络
class ModelDemo(nn.Module):
    #在__init__方法中定义网络的层
    def __init__(self):
        super().__init__()
        #搭建神经网络:隐藏层+输出层
        #隐藏层
        self.linear1=nn.Linear(3,3)
        self.linear2=nn.Linear(3,2)
        #输出层
        self.output=nn.Linear(2,2)

        #对隐藏层进行初始化
        nn.init.xavier_normal_(self.linear1.weight)
        nn.init.zeros_(self.linear1.bias)
        nn.init.xavier_normal_(self.linear2.weight)
        nn.init.zeros_(self.linear2.bias)
    

    def forward(self,x):
        #隐藏层:加权求和+激活函数
        #分解版写法
        # x=slef.liner1(x)
        # x=torch.sigmoid(x)
        #合并版
        x=torch.sigmoid(self.linear1(x))
        x=torch.relu(self.linear2(x))
        #dim=-1表示按行计算 
        x=torch.softmax(self.output(x),dim=-1)

        return x

def train():
    #创建模型
    model=ModelDemo()
    # print('我的模型:',model)

    #创建数据集
    data=torch.randn(size=(5,3))
    print('我的数据集:',data)
    print('我的数据集的形状:',data.shape)
    print('我的数据集的是否自动微分:',data.requires_grad)

    #调用神经网络
    output=model(data)
    print('我的输出:',output)
    print('我的输出的形状:',output.shape)
    print('我的输出的是否自动微分:',output.requires_grad)

    #计算和查看模型参数
    print('==============计算模型参数数===============')
    summary(model,input_size=(5,3))

    print('==============查看模型参数数===============')
    for name,param in model.named_parameters():
        print('神经网络层级:',name)
        print('参数:',param)
        print('--------------------------')

if __name__ == '__main__':
    train()
   
相关推荐
王_teacher1 小时前
RNN 循环神经网络 计算过程(通俗+公式版+运行实例)
人工智能·rnn·nlp
玩转单片机与嵌入式1 小时前
一个成熟的嵌入式AI系统,是长什么样子的?
人工智能·单片机·嵌入式硬件·嵌入式ai
曦樂~4 小时前
【机器学习】概述
人工智能·机器学习
DeniuHe4 小时前
机器学习模型中的偏置项(bias / 截距项)到底有什么用?
人工智能·机器学习
小江的记录本4 小时前
【网络安全】《网络安全常见攻击与防御》(附:《六大攻击核心特性横向对比表》)
java·网络·人工智能·后端·python·安全·web安全
思绪无限4 小时前
YOLOv5至YOLOv12升级:植物叶片病害识别系统的设计与实现(完整代码+界面+数据集项目)
深度学习·yolo·目标检测·yolov12·yolo全家桶·植物叶片病害检测
白羊by4 小时前
YOLOv1~v11 全版本核心演进总览
深度学习·算法·yolo
深小乐5 小时前
AI 周刊【2026.04.13-04.19】:中美差距减小、Claude Opus 4.7发布、国产算力突围
人工智能
深小乐5 小时前
从 AI Skills 学实战技能(七):让 AI 自动操作浏览器
人工智能
workflower5 小时前
人机交互部分OOD
运维·人工智能·自动化·集成测试·人机交互·软件需求