如何使用PyTorch搭建一个基础的神经网络并进行训练?

神经网络作为人工智能的重要组成部分,在图像处理、自然语言处理、语音识别、机器翻译等领域具有广泛的应用。今天的分享就来详细的介绍如何搭建一个简单的神经网络框架并进行训练。

首先定义一个名为NeuralNetwork的类,它继承了PyTorch框架的nn.Module类,用于创建神经网络。接下来部分是类的构造函数__init__(),用于初始化神经网络的各个层。在这个类中,初始化了以下层:

  • flatten:一个将输入展平的层。
  • hidden1:第一个隐藏层,输入大小为28x28(图像大小),输出大小为128。
  • hidden2:第二个隐藏层,输入大小为128,输出大小为128。
  • hidden3:第三个隐藏层,输入大小为128,输出大小为64。
  • out:输出层,输入大小为64,输出大小为10(类别数)。

    定义一个名为NeuralNetwork的类,它继承了PyTorch框架的nn.Module类,用于创建神经网络。classNeuralNetwork(nn.Module):def__init__(self):# 继承父类nn.Module的方法和属性super(NeuralNetwork, self).init()# 数据进行展平操作self.flatten=nn.Flatten()# 定义一层线性神经网络self.hidden1=nn.Linear(28*28,128)self.hidden2=nn.Linear(128,128)self.hidden3=nn.Linear(128,64)self.out=nn.Linear(64,10)

另外想自学机器学习深度学习,不知道如何入门的同学,还为大家整理了一份入门路线图(更新迭代不下10次),包含基础、理论、代码、实战项目、必读论文等等,希望可以帮到大家,大家可以添加小助手无偿自取****

这部分定义了前向传播方法forward(),通过前向传播计算输入数据x的输出。首先输入的数据先通过flatten层展平,然后依次经过隐藏层和激活函数进行线性变换和非线性处理。最后经过输出层输出预测结果x。(注意这个函数也是定义在NeuralNetwork之中的。)

复制代码
defforward(self,x):        # 对输入数据进行展平操作x=self.flatten(x)        # 将数据传入第一层线性神经网络x=self.hidden1(x)        # 对神经网络的输出应用ReLU激活函数x=torch.relu(x)x=self.hidden2(x)        # 对神经网络的输出应用sigmoid激活函数x=torch.sigmoid(x)x=self.hidden3(x)x=torch.relu(x)        # 将数据传入输出层x=self.out(x)        # 最后经过输出层输出预测结果xreturnx

这行代码创建了一个model的神经网络模型实例,并将其移动到特定的设备(我这里使用的是GPU)上进行计算。

复制代码
model=NeuralNetwork().to(device)

这行代码创建了一个CrossEntropyLoss(交叉熵损失函数)的实例,用作损失函数。​​​​​​​

复制代码
#交叉熵损失函数loss_fn=nn.CrossEntropyLoss()这行代码创建了一个Adam优化器的实例,将模型参数和学习率作为参数传入。​​​​​​​

# model.parameters()用于获取模型的所有参数。这些参数包括权重和偏差等# 它们都是Tensor类型的,是神经网络的重要组成部分optimizer=torch.optim.Adam(model.parameters(),lr=0.005)

这行代码定义了一个名为train的函数,用于进行训练。函数接受训练数据、模型、损失函数和优化器作为输入参数。​​​​​​​

复制代码
'''定义训练函数'''deftrain(dataloader,model,loss_fn,optimizer):# 设置模型为训练模式    model.train()# 记录优化次数    num=1# 遍历数据加载器中的每一个数据批次。for X,y in dataloader:        X,y=X.to(device),y.to(device)# 自动初始化权值w        pred=model.forward(X)        loss=loss_fn(pred,y) # 计算损失值# 将优化器的梯度缓存清零        optimizer.zero_grad()# 执行反向传播计算梯度        loss.backward()# 并通过优化器更新模型参数        optimizer.step()# 将损失值转换为标量        loss_value=loss.item()#将此次损失值打印出来        print(f'loss:{loss_value},[numbes]:{num}')#增加计数器num        num+=1

这行代码调用train函数来开始模型的训练。它传入训练数据、模型、损失函数和优化器,并执行训练过程。(注意训练前要先获取dataloader!)​​​​​​​

复制代码
#调用函数train(),传入训练数据,神经网络模型,损失函数和优化算法train(train_dataloader,model,loss_fn,optimizer)

完整代码展示:​​​​​​​

复制代码
'''定义神经网络'''classNeuralNetwork(nn.Module):def__init__(self):        super(NeuralNetwork, self).__init__()        self.flatten=nn.Flatten()        self.hidden1=nn.Linear(28*28,128)        self.hidden2=nn.Linear(128,128)        self.hidden3=nn.Linear(128,64)        self.out=nn.Linear(64,10)defforward(self,x):        x=self.flatten(x)        x=self.hidden1(x)        x=torch.relu(x)        x=self.hidden2(x)        x=torch.sigmoid(x)        x=self.hidden3(x)        x=torch.relu(x)        x=self.out(x)return xmodel=NeuralNetwork().to(device)print(model)'''建立损失函数和优化算法'''#交叉熵损失函数loss_fn=nn.CrossEntropyLoss()# 优化算法为随机梯度算法/Adam优化算法optimizer=torch.optim.Adam(model.parameters(),lr=0.005)'''定义训练函数'''deftrain(dataloader,model,loss_fn,optimizer):    model.train()# 记录优化次数    num=1for X,y in dataloader:        X,y=X.to(device),y.to(device)# 自动初始化权值w        pred=model.forward(X)        loss=loss_fn(pred,y) # 计算损失值        optimizer.zero_grad()        loss.backward()        optimizer.step()        loss_value=loss.item()        print(f'loss:{loss_value},[numbes]:{num}')        num+=1train(train_dataloader,model,loss_fn,optimizer)
相关推荐
Kay_Liang12 小时前
大语言模型如何精准调用函数—— Function Calling 系统笔记
java·大数据·spring boot·笔记·ai·langchain·tools
飞哥数智坊12 小时前
Claude Skills 实测体验:不用翻墙,GLM-4.6 也能玩转
人工智能·claude·chatglm (智谱)
程序员爱钓鱼12 小时前
Python编程实战 · 基础入门篇 | 变量与命名规范
后端·python
007php00712 小时前
猿辅导Java面试真实经历与深度总结(二)
java·开发语言·python·计算机网络·面试·职场和发展·golang
惊鸿.Jh12 小时前
C++可变参数模板
开发语言·python
FreeBuf_12 小时前
微软数字防御报告:AI成为新型威胁,自动化漏洞利用技术颠覆传统
人工智能·microsoft·自动化
MoRanzhi120312 小时前
Pillow 基础图像操作与数据预处理
图像处理·python·深度学习·机器学习·numpy·pillow·数据预处理
素素.陈12 小时前
向RAGFlow中上传文档到对应的知识库
开发语言·python
IT_陈寒12 小时前
Vue3性能优化实战:这7个技巧让我的应用加载速度提升50%!
前端·人工智能·后端
小宁爱Python12 小时前
Django Web 开发系列(一):视图基础与 URL 路由配置全解析
后端·python·django