Pytorch学习第二周--Day 12-13: 构建你的第一个神经网络

Day 12-13: 构建你的第一个神经网络

在这两天里,我动手实践构建了我的第一个神经网络,目的是解决一个基本的分类问题。使用了两个主流的深度学习框架:PyTorch和TensorFlow,以对比和理解它们在神经网络构建方面的不同。

目标:构建一个全连接的神经网络来处理分类问题。

过程:

设计网络结构,包括输入层、若干隐藏层和输出层。

选择合适的激活函数,如ReLU。

定义损失函数和优化器,例如使用交叉熵损失和Adam优化器。

实现:

在PyTorch中,我定义了一个nn.Module类,通过定义forward方法来实现数据的前向传播。

在TensorFlow中,我使用Sequential API来构建模型,这是一种更简洁、更高级的方法。

以下是具体的实现代码:

PyTorch代码示例

import torch

import torch.nn as nn

import torch.optim as optim

定义一个全连接神经网络

class FullyConnectedNN(nn.Module):

def init (self):

super(FullyConnectedNN, self).init ()

self.fc1 = nn.Linear(784, 128) # 假设输入是28x28图像,展平后的大小为784

self.relu = nn.ReLU()

self.fc2 = nn.Linear(128, 10) # 假设有10个类别

复制代码
def forward(self, x):
    x = self.fc1(x)
    x = self.relu(x)
    x = self.fc2(x)
    return x

实例化模型

model = FullyConnectedNN()

定义损失函数和优化器

criterion = nn.CrossEntropyLoss()

optimizer = optim.Adam(model.parameters(), lr=0.001)

TensorFlow代码示例

import tensorflow as tf

from tensorflow.keras.models import Sequential

from tensorflow.keras.layers import Dense

定义一个全连接神经网络

model = Sequential([

Dense(128, activation='relu', input_shape=(784,)), # 假设输入是28x28图像,展平后的大小为784

Dense(10, activation='softmax') # 假设有10个类别

])

编译模型

model.compile(optimizer='adam',

loss='sparse_categorical_crossentropy',

metrics=['accuracy'])

在这两个实现中,我专注于构建一个相对简单的神经网络,适用于处理基本的分类问题。通过这个练习,我加深了对神经网络结构和深度学习框架的理解,并获得了实际操作的经验。

相关推荐
xiaotao13137 分钟前
02-机器学习基础: 监督学习——线性回归
学习·机器学习·线性回归
墨黎芜1 小时前
ArcGIS从入门到精通——地图符号、注记的初步使用
学习·arcgis·信息可视化
小李云雾2 小时前
FastAPI重要知识点---中间件(Middleware)
学习·程序人生·中间件·fastapi·middleware
小夏子_riotous2 小时前
Docker学习路径——3、常用命令
linux·运维·服务器·学习·docker·容器·centos
STLearner2 小时前
WSDM 2026 | 时间序列(Time Series)论文总结【预测,表示学习,因果】
大数据·论文阅读·人工智能·深度学习·学习·机器学习·数据挖掘
redaijufeng2 小时前
网络爬虫学习:应用selenium获取Edge浏览器版本号,自动下载对应版本msedgedriver,确保Edge浏览器顺利打开。
爬虫·学习·selenium
昵称小白3 小时前
从 ( y = wx + b ) 到神经网络:参数、loss、梯度到底怎么连起来(一)
人工智能·神经网络
腾科IT教育3 小时前
零基础快速上岸HCIP,高效学习思路分享
学习·华为认证·hcip·hcip考试·hcip认证
23471021273 小时前
4.14 学习笔记
笔记·python·学习
醇氧3 小时前
【学习】软件过程模型全解析:从瀑布到敏捷的演进之路
学习·log4j