《动手学深度学习 Pytorch版》 4.3 多层感知机的简洁实现

python 复制代码
import torch
from torch import nn
from d2l import torch as d2l

模型

python 复制代码
net = nn.Sequential(nn.Flatten(),
                    nn.Linear(784, 256),
                    nn.ReLU(),  # 与 3.7 节相比多了一层
                    nn.Linear(256, 10))

def init_weights(m):
    if type(m) == nn.Linear:  # 使用正态分布中的随机值初始化权重
        nn.init.normal_(m.weight, std=0.01)

net.apply(init_weights)
复制代码
Sequential(
  (0): Flatten(start_dim=1, end_dim=-1)
  (1): Linear(in_features=784, out_features=256, bias=True)
  (2): ReLU()
  (3): Linear(in_features=256, out_features=10, bias=True)
)
python 复制代码
batch_size, lr, num_epochs = 256, 0.1, 10
loss = nn.CrossEntropyLoss(reduction='none')
trainer = torch.optim.SGD(net.parameters(), lr=lr)

train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size)
d2l.train_ch3(net, train_iter, test_iter, loss, num_epochs, trainer)


练习

(1)尝试添加不同数量的隐藏层(也可以修改学习率),怎样设置效果最好?

python 复制代码
net2 = nn.Sequential(nn.Flatten(),
                    nn.Linear(784, 256),
                    nn.ReLU(),
                    nn.Linear(256, 128),
                    nn.ReLU(),
                    nn.Linear(128, 10))

def init_weights(m):
    if type(m) == nn.Linear:  # 使用正态分布中的随机值初始化权重
        nn.init.normal_(m.weight, std=0.01)

net2.apply(init_weights)

batch_size2, lr2, num_epochs2 = 256, 0.3, 10
loss2 = nn.CrossEntropyLoss(reduction='none')
trainer2 = torch.optim.SGD(net2.parameters(), lr=lr2)

train_iter2, test_iter2 = d2l.load_data_fashion_mnist(batch_size2)
d2l.train_ch3(net2, train_iter2, test_iter2, loss2, num_epochs2, trainer2)



(2)尝试不同的激活函数,哪个激活函数效果最好?

python 复制代码
net3 = nn.Sequential(nn.Flatten(),
                    nn.Linear(784, 256),
                    nn.Sigmoid(),
                    nn.Linear(256, 10))

net4 = nn.Sequential(nn.Flatten(),
                    nn.Linear(784, 256),
                    nn.Tanh(),
                    nn.Linear(256, 10))

def init_weights(m):
    if type(m) == nn.Linear:
        nn.init.normal_(m.weight, std=0.01)

net3.apply(init_weights)
net4.apply(init_weights)


train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size)
python 复制代码
batch_size, lr, num_epochs = 256, 0.1, 10
loss = nn.CrossEntropyLoss(reduction='none')
trainer = torch.optim.SGD(net3.parameters(), lr=lr)
d2l.train_ch3(net3, train_iter, test_iter, loss, num_epochs, trainer)
复制代码
---------------------------------------------------------------------------

AssertionError                            Traceback (most recent call last)

Cell In[5], line 4
      2 loss = nn.CrossEntropyLoss(reduction='none')
      3 trainer = torch.optim.SGD(net3.parameters(), lr=lr)
----> 4 d2l.train_ch3(net3, train_iter, test_iter, loss, num_epochs, trainer)


File c:\Software\Miniconda3\envs\d2l\lib\site-packages\d2l\torch.py:340, in train_ch3(net, train_iter, test_iter, loss, num_epochs, updater)
    338     animator.add(epoch + 1, train_metrics + (test_acc,))
    339 train_loss, train_acc = train_metrics
--> 340 assert train_loss < 0.5, train_loss
    341 assert train_acc <= 1 and train_acc > 0.7, train_acc
    342 assert test_acc <= 1 and test_acc > 0.7, test_acc


AssertionError: 0.5017133202234904
python 复制代码
batch_size, lr, num_epochs = 256, 0.1, 10
loss = nn.CrossEntropyLoss(reduction='none')
trainer = torch.optim.SGD(net4.parameters(), lr=lr)
d2l.train_ch3(net4, train_iter, test_iter, loss, num_epochs, trainer)


还是 ReLU 比较奈斯。


(3)尝试不同的方案来初始化权重,什么方案效果最好。

累了,不想试试了。略...

相关推荐
茶杯67510 分钟前
GraphRAG产品赋能企业智能升级:创邻科技知寰Hybrid RAG的四大核心应用场景深度解析
人工智能·科技·graphrag产品
少林and叔叔13 分钟前
基于yolov5.7.0的人工智能算法的下载、开发环境搭建(pycharm)与运行测试
人工智能·pytorch·python·yolo·目标检测·pycharm
kuan_li_lyg30 分钟前
笛卡尔坐标机器人控制的虚拟前向动力学模型
人工智能·stm32·机器人·机械臂·动力学·运动学·导纳控制
合作小小程序员小小店35 分钟前
旧版本附近停车场推荐系统demo,基于python+flask+协同推荐(基于用户信息推荐),开发语言python,数据库mysql,
人工智能·python·flask·sklearn·推荐算法
却道天凉_好个秋42 分钟前
OpenCV(十四):绘制直线
人工智能·opencv·计算机视觉
动能小子ohhh44 分钟前
Langchain从零开始到应用落地案例[AI智能助手]【3】---使用Paddle-OCR识别优化可识别图片进行解析回答
人工智能·python·pycharm·langchain·ocr·paddle·1024程序员节
IT_陈寒1 小时前
Vue 3.4性能优化实战:5个鲜为人知的Composition API技巧让打包体积减少40%
前端·人工智能·后端
数据与人工智能律师1 小时前
数据淘金时代的法治罗盘:合法收集、使用与变现数据的边界与智慧
大数据·网络·人工智能·云计算·区块链
王哈哈^_^1 小时前
【数据集】【YOLO】【目标检测】建筑垃圾数据集 4256 张,YOLO建筑垃圾识别算法实战训推教程。
人工智能·深度学习·算法·yolo·目标检测·计算机视觉·数据集
牛奶还是纯的好1 小时前
双目测距实战4-自标定
人工智能·3d视觉