动手学深度学习:多层感知机

多层感知机

之前学习了线性模型,我们可以很容易想到,现实生活中很多的现象用线性是无法拟合的,所以研究人员就想到了在线性之后添加一些非线性的激活函数,使得整个网络获得非线性,以拟合更加复杂的情况。

常见的激活函数

数学公式自行查找

ReLU

python 复制代码
import torch
import matplotlib.pyplot as plt
x=torch.arange(-10.0,10.0,0.1,requires_grad=True)
x
python 复制代码
y=torch.relu(x)
y
可视化图像
python 复制代码
plt.plot(x.detach(),y.detach())
可视化梯度
python 复制代码
y.backward(torch.ones_like(x),retain_graph=True)
plt.clf()
plt.plot(x.detach(),x.grad)

sigmoid

python 复制代码
y=torch.sigmoid(x)
y
可视化图像
python 复制代码
plt.clf()
plt.plot(x.detach(),y.detach())
可视化梯度
python 复制代码
x.grad.data.zero_()
y.backward(torch.ones_like(x),retain_graph=True)
plt.clf()
plt.plot(x.detach(),x.grad)

tanh

python 复制代码
y=torch.tanh(x)
y
可视化图像
python 复制代码
plt.clf()
plt.plot(x.detach(),y.detach())
可视化梯度
python 复制代码
x.grad.data.zero_()
y.backward(torch.ones_like(x),retain_graph=True)
plt.clf()
plt.plot(x.detach(),x.grad)

读取数据

我上一篇博客写过,点击直接跳转

模型(ReLU激活+权重衰减+dropout)

上一篇博客train和test的acc中间就有1%的差距,似乎是有一点过拟合,在添加ReLU激活函数将线性模型变成多层感知机之后(整体准确率提高1%)会发现两者之间仍然有差距,所以我尝试了权重衰减和dropout。

权重衰减0.01会发现两者差距几乎没有,但是train的acc降低了1%,几乎跟线性效果差不多了,所以不可取。

dropout0.3会发现两者几乎重合,train下降很少,同时test也有所上升。

python 复制代码
from torch import nn
net=nn.Sequential(nn.Flatten(),
                  nn.Linear(28*28,256),
                  nn.ReLU(),
                  nn.Dropout(0.3),
                  nn.Linear(256,8))
def init_weight(m):
    if type(m)==nn.Linear:
        nn.init.normal_(m.weight,std=0.01)
net.apply(init_weight)

权重衰退就是在优化器传参的时候对Linear()的参数进行限制。

python 复制代码
wd=0.00 #权重衰减的值
loss_fn=nn.CrossEntropyLoss()
optimer=torch.optim.SGD([{"params": net[1].weight,"weight_decay": wd},{"params": net[1].bias},{"params": net[4].weight,"weight_decay": wd},{"params": net[4].bias}],lr=0.1)
python 复制代码
epochs_num=10
train_len=len(train_iter.dataset)
all_acc=[]
all_loss=[]
test_all_acc=[]
for epoch in range(epochs_num):
    acc=0
    loss=0
    for x,y in train_iter:
        hat_y=net(x)
        l=loss_fn(hat_y,y)
        loss+=l
        optimer.zero_grad()
        l.backward()
        optimer.step()
        acc+=(hat_y.argmax(1)==y).sum()
    all_acc.append(acc/train_len)
    all_loss.append(loss.detach().numpy())
    test_acc=0
    test_len=len(test_iter.dataset)
    with torch.no_grad():
        for x,y in test_iter:
            hat_y=net(x)
            test_acc+=(hat_y.argmax(1)==y).sum()
    test_all_acc.append(test_acc/test_len)
    print(f'{epoch}的test的acc{test_acc/test_len}')

可视化

python 复制代码
import matplotlib.pyplot as plt

损失函数可视化

python 复制代码
plt.plot(range(1,epochs_num+1),all_loss,'.-',label='train_loss')
plt.text(epochs_num, all_loss[-1], f'{all_loss[-1]:.4f}', fontsize=12, verticalalignment='bottom')

准确率可视化

python 复制代码
plt.plot(range(1,epochs_num+1),all_acc,'-',label='train_acc')
plt.text(epochs_num, all_acc[-1], f'{all_acc[-1]:.4f}', fontsize=12, verticalalignment='bottom')
plt.plot(range(1,epochs_num+1),test_all_acc,'-.',label='test_acc')
plt.legend()

预测结果

python 复制代码
with torch.no_grad():
    all_num=5
    index=1
    plt.figure(figsize=(12,5))
    for i,label in zip(test_data_path,test_labels):
        if index<=all_num:
            img=cv2.imread(i)
            input_img=cv2.cvtColor(img,cv2.COLOR_BGR2GRAY)
            img=cv2.cvtColor(input_img,cv2.COLOR_BGR2RGB)
            input_img=cv2.resize(input_img,size) 
            input_img=transforms.ToTensor()(input_img)
            result=net(input_img).argmax(1)
            plt.subplot(1,all_num,index)
            plt.imshow(img)
            plt.title(f'true{label},predict{result.detach().numpy()}')
            plt.axis("off")
            index+=1
相关推荐
终端域名几秒前
当今前沿科技:脑机共生界面(脑机接口)深度解析
人工智能·智能电视
汗流浃背了吧,老弟!几秒前
预训练语言模型(Pre-trained Language Model, PLM)介绍
深度学习·语言模型·自然语言处理
化作星辰44 分钟前
深度学习_神经网络激活函数
人工智能·深度学习·神经网络
陈天伟教授1 小时前
人工智能技术- 语音语言- 03 ChatGPT 对话、写诗、写小说
人工智能·chatgpt
llilian_161 小时前
智能数字式毫秒计在实际生活场景中的应用 数字式毫秒计 智能毫秒计
大数据·网络·人工智能
打码人的日常分享1 小时前
基于信创体系政务服务信息化建设方案(PPT)
大数据·服务器·人工智能·信息可视化·架构·政务
硬汉嵌入式2 小时前
专为 MATLAB 优化的 AI 助手MATLAB Copilot
人工智能·matlab·copilot
北京盛世宏博2 小时前
如何利用技术手段来甄选一套档案馆库房安全温湿度监控系统
服务器·网络·人工智能·选择·档案温湿度
搞科研的小刘选手2 小时前
【EI稳定】检索第六届大数据经济与信息化管理国际学术会议(BDEIM 2025)
大数据·人工智能·经济
半吊子全栈工匠2 小时前
软件产品的10个UI设计技巧及AI 辅助
人工智能·ui