前馈神经网络多分类任务

pytorch深度学习的套路都差不多,多看多想多写多测试,自然就会了。主要的技术还是在于背后的数学思想和数学逻辑。

废话不多说,上代码自己看。

python 复制代码
import torch
import numpy as np
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms

class Network(nn.Module):
    def __init__(self ,input_dim ,hidden_dim ,out_dim):
        super().__init__()
        self.layer1 = nn.Sequential(  # 全连接层     [1, 28, 28]
            nn.Linear(784, 400),       # 输入维度,输出维度
            nn.BatchNorm1d(400),  # 批标准化,加快收敛,可不需要
            nn.ReLU()  				 # 激活函数
        )

        self.layer2 = nn.Sequential(
            nn.Linear(400, 200),
            nn.BatchNorm1d(200),
            nn.ReLU()
        )

        self.layer3 = nn.Sequential(   # 全连接层
            nn.Linear(200, 100),
            nn.BatchNorm1d(100),
            nn.ReLU()
        )

        self.layer4 = nn.Sequential(   # 最后一层为实际输出,不需要激活函数,因为有 10 个数字,所以输出维度为 10,表示10 类
            nn.Linear(100, 10),
        )

    def forward(self ,x):
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        output = self.layer4(x)
        return output
def get_num_correct(preds, labels):
    return (preds.argmax(dim=1) == labels).sum().item()

def dropout(x, keep_prob = 0.5):
    '''
    np.random.binomial 当输入二维数组时,按行按列(每个维度)都是按照给定概率生成1的个数,
比如 输入 10 * 6的矩阵,按照0.5的概率生成1 那么每列都大概会有5个1,每行大概会有3个1,
其实就不用考虑按行drop或者按列drop,相当于每行生成的mask都是不一样的,那么矩阵中每行的元素(代表一层中的神经元)都是按照不同的mask失活的
当矩阵形状改变行列代表的意义不一样时,由于每行每列(各个维度)的1的个数都是按照prob留存的,因此对结果没有影响。
    '''
    mask = torch.from_numpy(np.random.binomial(1,keep_prob,x.shape))
    return x * mask / keep_prob
    

if __name__ == "__main__":
    train_set = torchvision.datasets.MNIST(
        root='./data'
        , train=True
        , download=False
        , transform=transforms.Compose([
            transforms.ToTensor()
        ])
    )
    test_set = torchvision.datasets.MNIST(
        root='./data',
        train=False,
        download=False,
        transform=transforms.Compose([
            transforms.ToTensor()])
    )

    train_loader = torch.utils.data.DataLoader(train_set
                                               , batch_size=512
                                               , shuffle=True
                                               )
    test_loader = torch.utils.data.DataLoader(test_set
                                              , batch_size=512
                                              , shuffle=True)

    net = Network(28 * 28, 256, 10)
    optimizer = torch.optim.SGD(net.parameters(), lr=0.01)
    criterion = nn.CrossEntropyLoss()
    epoch = 10

    for i in range(epoch):
        train_accur = 0.0
        train_loss = 0.0
        for batch in train_loader:
            images, labels = batch
            #images, labels = images.to(device), labels.to(device)
            images = images.squeeze(1).reshape(images.shape[0], -1)
            preds = net(images)
            optimizer.zero_grad()
            loss = criterion(preds, labels)
            loss.backward()
            optimizer.step()
            train_loss += loss.item()
            train_accur += get_num_correct(preds, labels)
        print("loss :" + str(train_loss) + "train accur:" + str(train_accur * 1.0 / 60000))

    global correct
    with torch.no_grad():
        correct = 0
        for batch in test_loader:
            images, labels = batch
            #images, labels = images.to(device), labels.to(device)
            images = images.squeeze(1).reshape(-1, 784)
            preds = net(images)

            preds = preds.argmax(dim=1)
            correct += (preds == labels).sum()
            print(correct)
    print(correct.item() * 1.0 / len(test_set))
相关推荐
AL.千灯学长17 分钟前
DeepSeek接入Siri(已升级支持苹果手表)完整版硅基流动DeepSeek-R1部署
人工智能·gpt·ios·ai·苹果vision pro
LCG元1 小时前
大模型驱动的围术期质控系统全面解析与应用探索
人工智能
lihuayong1 小时前
计算机视觉:主流数据集整理
人工智能·计算机视觉·mnist数据集·coco数据集·图像数据集·cifar-10数据集·imagenet数据集
政安晨1 小时前
政安晨【零基础玩转各类开源AI项目】DeepSeek 多模态大模型Janus-Pro-7B,本地部署!支持图像识别和图像生成
人工智能·大模型·多模态·deepseek·janus-pro-7b
一ge科研小菜鸡1 小时前
DeepSeek 与后端开发:AI 赋能云端架构与智能化服务
人工智能·云原生
冰 河1 小时前
‌最新版DeepSeek保姆级安装教程:本地部署+避坑指南
人工智能·程序员·openai·deepseek·冰河大模型
维维180-3121-14551 小时前
AI赋能生态学暨“ChatGPT+”多技术融合在生态系统服务中的实践技术应用与论文撰写
人工智能·chatgpt
終不似少年遊*1 小时前
词向量与词嵌入
人工智能·深度学习·nlp·机器翻译·词嵌入
杜大哥2 小时前
如何在WPS打开的word、excel文件中,使用AI?
人工智能·word·excel·wps
Leiditech__2 小时前
人工智能时代电子机器人静电问题及电路设计防范措施
人工智能·嵌入式硬件·机器人·硬件工程