Pytorch:问题整理

目录

一、神经网络输出维度如何设计

二、数据集对象的batch_size设置多少合适

三、Linear的作用

四、为什么要梯度清零

五、为什么要将神经元输出压缩到一定范围


一、神经网络输出维度如何设计

维度从大→小逐步收缩,先提取特征再压缩降维,最后对齐任务输出;中间层靠经验 + 试错,遵循「金字塔结构」

1、各层通用原则

  • 输入层:维度 = 原始特征数

    • 表格数据:样本特征数(比如 10 个特征→输入维度 10)
  • 输出层:维度严格由任务决定

    • 二分类:1 维(sigmoid)
    • 多分类 N 类:N 维(softmax)
    • 回归任务:1 维(预测一个值)
    • 多输出回归:要预测几个值就几维
  • 隐藏层通用规则

    • 整体结构:递减金字塔 输入维度 ≥ 第一层隐藏维 > 下一层 > ... > 输出维
    • 不要忽大忽小、不要中间突然暴增
    • 维度不能过小:太小会欠拟合、特征丢失
    • 维度不能过大:太大会过拟合、参数量爆炸、训练慢

2、如何快速调优

  • ✅先搭 baseline:按「逐层减半」设初始维度
  • ✅看训练效果:
    • 训练集、测试集都差 → 加大每层维度、加层数
    • 训练集好、测试集差(过拟合)→ 减小维度、减层数、加正则 / Dropout
  • ✅优先调宽度(维度) 再调深度(层数)

二、数据集对象的batch_size设置多少合适

1、batch_size的功能

BatchSize :每次迭代一次性喂给模型的样本数

  • 越小:梯度越抖动、泛化更好、显存占用低、训练慢
  • 越大:梯度越平稳、训练更快、显存占用高、容易泛化变差

2、万能设置参考

  • 先看显卡显存:从 32 开始往上试,不爆显存为止
  • 小数据集固定选:8 / 16
  • 中等数据集:32 / 64
  • 大数据集:64 / 128
  • 调完 batch 后,同步按比例微调学习率
  • 观察:
    • 训练 loss 震荡大 → 适当加大 batch
    • 训练好、验证差(过拟合)→ 减小 batch、降低学习率

小数据小 batch,大数据大 batch;显存卡死上限,batch 翻倍 lr 加倍;BN 别用太小,小 batch 泛化更好。

三、Linear的作用

改变维度、融合特征、做线性映射、提供可学习参数;必须配合激活函数才有非线性拟合能力。

四、为什么要梯度清零

因为 PyTorch 会默认把每一步的梯度 累加(叠加)起来,而不是替换。如果你不清零,梯度会越堆越多,模型直接训练崩掉。

五、为什么要将神经元输出压缩到一定范围

把神经元输出限制在一定区间(如 Sigmoid→0~1、Tanh→-1~1、ReLU→0~+∞),

根本原因就 5 点:

  1. 防止数值爆炸、权重失控

如果不用激活函数,多层线性叠加后,数值会无限变大 / 变小 ,权重更新时梯度直接溢出、NaN 报错,模型完全没法训练。压缩值域能把每一层输出锁在可控范围,数值稳定

  1. 引入非线性,解决线性模型局限

没有激活函数,无论多少层网络都等价单层线性回归 ,只能拟合直线。压缩值域的同时做非线性映射,让网络能拟合复杂曲线、语义、图像等非线性关系。

  1. 控制梯度大小,缓解梯度消失 / 爆炸
  • 把输出压在有限区间,导数(梯度)也被限制在小范围;

  • 避免梯度过大爆炸、或过小趋近于 0,利于深层网络反向传播

  1. 归一化特征,加速收敛

把每层输出统一到相近值域(如 0~1、-1~1),梯度下降时步长更稳定、参数更新更平滑,训练更快、更容易收敛

  1. 符合生物学神经元特性

人脑神经元兴奋度本身就有饱和阈值 :刺激弱不响应、中等线性响应、过强达到饱和不再增加,激活函数压缩值域就是模拟这种饱和激活特性

相关推荐
花酒锄作田1 小时前
[python]argparse 包在聊天机器人中的应用
python
久违 °4 小时前
【AI-Agent】TagMatrix 数据标注工具开发
人工智能·数据分析·go·agent·数据隐私
NiceCloud喜云4 小时前
Opus 4.8 的 Effort Control 怎么选:Low 到 Max 五档策略
android·java·大数据·前端·c++·python·spring
AI360labs_atyun4 小时前
腾讯推出电子牛马Marvis,好用吗?
人工智能·科技·ai
Dfreedom.4 小时前
Windows、虚拟机、开发板组网通信原理及调试通联步骤
人工智能·windows·部署·边缘计算·开发板·模型加速
3DVisionary4 小时前
蓝光三维扫描:医疗制造的精度焦虑怎么解
人工智能·算法·制造·蓝光三维扫描·医疗制造·三维检测·义齿检测
Are_You_Okkk_4 小时前
基于MonkeyCode解析AI研发新模式,根治开发低效痛点
大数据·人工智能·开源·ai编程
AI玫瑰助手4 小时前
Python函数:默认参数的定义与注意事项
开发语言·python·信息可视化
好评笔记4 小时前
机器学习面试八股——常用损失函数
人工智能·深度学习·算法·机器学习·校招
weixin_468466854 小时前
全局与局部注意力机制新手实战指南
人工智能·python·深度学习·算法·自然语言处理·transformer·注意力机制