如何通过神经网络看模型参数量?

我们经常听说某某大模型有多少亿参数,比如 DeepSeek R1 671B,那么这个参数如何通过神经网络图来看?

  • B(billion):10 亿
  • 例如:7B:就是 70 亿参数,671B 就是 6710 亿参数
  • 从神经网络神经元连接图上来看,这个参数量怎么来的?

以下面这个神经网络为例:

4 个输入(实际上 3 个变量输入一个偏置),2 个输出,三层网络,2 个隐藏层。

第1个隐藏层:使用Xavier正态分布初始化权重,激活函数使用Tanh

第2个隐藏层:使用He正态分布初始化权重,激活函数使用ReLU

输出层:按默认方式初始化,激活函数使用Softmax

先看如下代码,看一下输出结果:

复制代码
import torch
import torch.nn as nn

class Model(nn.Module):
    # 初始化
    def __init__(self):
        super(Model, self).__init__()  # 调用父类初始化
        self.linear1 = nn.Linear(3, 4)  # 第1个隐藏层,3个输入,4个输出
        nn.init.xavier_normal_(self.linear1.weight)  # 初始化权重参数
        self.linear2 = nn.Linear(4, 4)  # 第2个隐藏层,4个输入,4个输出
        nn.init.kaiming_normal_(self.linear2.weight)  # 初始化权重参数
        self.out = nn.Linear(4, 2)  # 输出层,4个输入,2个输出,默认使用He均匀分布初始化

    # 前向传播
    def forward(self, x):
        x = self.linear1(x)  # 经过第1个隐藏层
        x = torch.tanh(x)  # 激活函数
        x = self.linear2(x)  # 经过第2个隐藏层
        x = torch.relu(x)  # 激活函数
        x = self.out(x)  # 经过输出层
        x = torch.softmax(x, dim=1)  # 激活函数
        return x

model = Model()
output = model(torch.randn(10, 3))
print("输出:\n", output)
print()

# 使用named_parameters()查看各层参数
print("模型参数:")
for name, param in model.named_parameters():
    print(name, param)
    print()

# 使用state_dict()查看各层参数
print("模型参数:\n", model.state_dict())


from torchsummary import summary
# input_size:特征数,batch_size:样本数
summary(model, input_size=(3,), batch_size=10, device="cpu")

为什么输出的 param 是 16,20,10?

  • 第一层:3 * 4 + 4 = 16
  • 第二层:4 * 4 + 4 = 20
  • 第三层:4 * 2 + 2 = 10

总参数:16 + 20 + 10 = 46个参数量

* 前后的数实际上就是当前层的组合数

**+**后面的数就是偏置的数量

相关推荐
微爱帮监所写信寄信几秒前
微爱帮监狱寄信写信工具用户头像安全审核体系
人工智能
熬夜敲代码的小N1 分钟前
AI文本分类实战:从数据预处理到模型部署全流程解析
人工智能·分类·数据挖掘
沛沛老爹2 分钟前
Web开发者快速上手AI Agent:Dify本地化部署与提示词优化实战
前端·人工智能·rag·faq·文档细粒度
国科安芯2 分钟前
低轨卫星边缘计算节点的抗辐照MCU选型分析
人工智能·单片机·嵌入式硬件·架构·边缘计算·安全威胁分析·安全性测试
美团技术团队3 分钟前
2025 美团技术团队热门技术文章汇总
人工智能
GEO AI搜索优化助手3 分钟前
生成式AI搜索的跨行业革命与商业模式重构
大数据·人工智能·搜索引擎·重构·生成式引擎优化·ai优化·geo搜索优化
张拭心8 分钟前
"氛围编程"程序员被解雇了
android·前端·人工智能
我是人机不吃鸭梨14 分钟前
Flutter AI 集成革命(2025版):从 Gemini 模型到智能表单验证器的终极方案
开发语言·javascript·人工智能·flutter·microsoft·架构
编码小哥14 分钟前
OpenCV阈值分割技术:全局阈值与自适应阈值
人工智能·opencv·计算机视觉
计算机徐师兄14 分钟前
Python基于Django的网络入侵检测系统(附源码,文档说明)
python·django·网络入侵检测·网络入侵检测系统·python网络入侵检测系统·网络入侵·python网络入侵检测