如何使用PyTorch搭建一个基础的神经网络并进行训练?

神经网络作为人工智能的重要组成部分,在图像处理、自然语言处理、语音识别、机器翻译等领域具有广泛的应用。今天的分享就来详细的介绍如何搭建一个简单的神经网络框架并进行训练。

首先定义一个名为NeuralNetwork的类,它继承了PyTorch框架的nn.Module类,用于创建神经网络。接下来部分是类的构造函数__init__(),用于初始化神经网络的各个层。在这个类中,初始化了以下层:

  • flatten:一个将输入展平的层。
  • hidden1:第一个隐藏层,输入大小为28x28(图像大小),输出大小为128。
  • hidden2:第二个隐藏层,输入大小为128,输出大小为128。
  • hidden3:第三个隐藏层,输入大小为128,输出大小为64。
  • out:输出层,输入大小为64,输出大小为10(类别数)。

    定义一个名为NeuralNetwork的类,它继承了PyTorch框架的nn.Module类,用于创建神经网络。classNeuralNetwork(nn.Module):def__init__(self):# 继承父类nn.Module的方法和属性super(NeuralNetwork, self).init()# 数据进行展平操作self.flatten=nn.Flatten()# 定义一层线性神经网络self.hidden1=nn.Linear(28*28,128)self.hidden2=nn.Linear(128,128)self.hidden3=nn.Linear(128,64)self.out=nn.Linear(64,10)

另外想自学机器学习深度学习,不知道如何入门的同学,还为大家整理了一份入门路线图(更新迭代不下10次),包含基础、理论、代码、实战项目、必读论文等等,希望可以帮到大家,大家可以添加小助手无偿自取****

这部分定义了前向传播方法forward(),通过前向传播计算输入数据x的输出。首先输入的数据先通过flatten层展平,然后依次经过隐藏层和激活函数进行线性变换和非线性处理。最后经过输出层输出预测结果x。(注意这个函数也是定义在NeuralNetwork之中的。)

复制代码
defforward(self,x):        # 对输入数据进行展平操作x=self.flatten(x)        # 将数据传入第一层线性神经网络x=self.hidden1(x)        # 对神经网络的输出应用ReLU激活函数x=torch.relu(x)x=self.hidden2(x)        # 对神经网络的输出应用sigmoid激活函数x=torch.sigmoid(x)x=self.hidden3(x)x=torch.relu(x)        # 将数据传入输出层x=self.out(x)        # 最后经过输出层输出预测结果xreturnx

这行代码创建了一个model的神经网络模型实例,并将其移动到特定的设备(我这里使用的是GPU)上进行计算。

复制代码
model=NeuralNetwork().to(device)

这行代码创建了一个CrossEntropyLoss(交叉熵损失函数)的实例,用作损失函数。​​​​​​​

复制代码
#交叉熵损失函数loss_fn=nn.CrossEntropyLoss()这行代码创建了一个Adam优化器的实例,将模型参数和学习率作为参数传入。​​​​​​​

# model.parameters()用于获取模型的所有参数。这些参数包括权重和偏差等# 它们都是Tensor类型的,是神经网络的重要组成部分optimizer=torch.optim.Adam(model.parameters(),lr=0.005)

这行代码定义了一个名为train的函数,用于进行训练。函数接受训练数据、模型、损失函数和优化器作为输入参数。​​​​​​​

复制代码
'''定义训练函数'''deftrain(dataloader,model,loss_fn,optimizer):# 设置模型为训练模式    model.train()# 记录优化次数    num=1# 遍历数据加载器中的每一个数据批次。for X,y in dataloader:        X,y=X.to(device),y.to(device)# 自动初始化权值w        pred=model.forward(X)        loss=loss_fn(pred,y) # 计算损失值# 将优化器的梯度缓存清零        optimizer.zero_grad()# 执行反向传播计算梯度        loss.backward()# 并通过优化器更新模型参数        optimizer.step()# 将损失值转换为标量        loss_value=loss.item()#将此次损失值打印出来        print(f'loss:{loss_value},[numbes]:{num}')#增加计数器num        num+=1

这行代码调用train函数来开始模型的训练。它传入训练数据、模型、损失函数和优化器,并执行训练过程。(注意训练前要先获取dataloader!)​​​​​​​

复制代码
#调用函数train(),传入训练数据,神经网络模型,损失函数和优化算法train(train_dataloader,model,loss_fn,optimizer)

完整代码展示:​​​​​​​

复制代码
'''定义神经网络'''classNeuralNetwork(nn.Module):def__init__(self):        super(NeuralNetwork, self).__init__()        self.flatten=nn.Flatten()        self.hidden1=nn.Linear(28*28,128)        self.hidden2=nn.Linear(128,128)        self.hidden3=nn.Linear(128,64)        self.out=nn.Linear(64,10)defforward(self,x):        x=self.flatten(x)        x=self.hidden1(x)        x=torch.relu(x)        x=self.hidden2(x)        x=torch.sigmoid(x)        x=self.hidden3(x)        x=torch.relu(x)        x=self.out(x)return xmodel=NeuralNetwork().to(device)print(model)'''建立损失函数和优化算法'''#交叉熵损失函数loss_fn=nn.CrossEntropyLoss()# 优化算法为随机梯度算法/Adam优化算法optimizer=torch.optim.Adam(model.parameters(),lr=0.005)'''定义训练函数'''deftrain(dataloader,model,loss_fn,optimizer):    model.train()# 记录优化次数    num=1for X,y in dataloader:        X,y=X.to(device),y.to(device)# 自动初始化权值w        pred=model.forward(X)        loss=loss_fn(pred,y) # 计算损失值        optimizer.zero_grad()        loss.backward()        optimizer.step()        loss_value=loss.item()        print(f'loss:{loss_value},[numbes]:{num}')        num+=1train(train_dataloader,model,loss_fn,optimizer)
相关推荐
2401_879693874 分钟前
用Python批量处理Excel和CSV文件
jvm·数据库·python
I'm Jie7 分钟前
Swagger UI 本地化部署,解决 FastAPI Swagger UI 依赖外部 CDN 加载失败问题
python·ui·fastapi·swagger·swagger ui
koo36410 分钟前
pytorch深度学习笔记23
pytorch·笔记·深度学习
风流 少年11 分钟前
Superpowers与 OpenSpec、Spec Kit 对比
ai
LaughingZhu21 分钟前
Product Hunt 每日热榜 | 2026-03-21
人工智能·经验分享·深度学习·神经网络·产品运营
qzhqbb22 分钟前
差分隐私与大模型+差分隐私在相关领域应用的论文总结
人工智能·算法
一招定胜负25 分钟前
基于通义千问 API 的课堂话语智能分类分析工具实现
人工智能·分类·数据挖掘
2401_8463416527 分钟前
Python Lambda(匿名函数):简洁之道
jvm·数据库·python
2401_8796938730 分钟前
进阶技巧与底层原理
jvm·数据库·python
阿_旭31 分钟前
基于YOLO26深度学习的【桃子成熟度检测与分割系统】【python源码+Pyqt5界面+数据集+训练代码】图像分割、人工智能
人工智能·python·深度学习·桃子成熟度检测