计算CNN卷积层和全连接层的参数量

计算CNN卷积层和全连接层的参数量

先前阅读

本文主旨意在搞明白2个问题:
第一个问题

一个卷积操作,他的参数,也就是我们要训练的参数,也就是我们说的权重,有多少个? 看到一个nn.Conv()函数,就能知道有多少个,它由那些因子决定的?

参数量是由以下3个因子决定的:

  • 卷积核大小(HxW)
  • 卷积核维度(D)
  • 卷积核有多少个

则卷积层的参数量为 卷积核大小(HxW) * 卷积核维度(D) * 卷积核有多少个

第二个问题

一个全连接操作,参数又有多少个?它由那些因子决定的?

  • 输入大小为 N
  • 输出大小为 M

则全连接层的参数量为 N×M

计算CNN卷积层的参数量

案例1

动态演示

看上图案例1的计算,输入图像为 5x5x1, 卷积核3x3x1, 输出3x3x1;

思考3个参数:

  • 卷积核大小(HxW) ==》3x3
  • 卷积核维度(D) ==》1
  • 卷积核有多少个 ==》1

参数量为 3x3x1x1 = 9个

案例2

看上图案例2的计算,输入图像为 H1xW1x3, 卷积核3x3x3, 输出H2xW2x1;

思考3个参数:

  • 卷积核大小(HxW) ==》3x3
  • 卷积核维度(D) ==》3
  • 卷积核有多少个 ==》1

参数量为 3x3x3x1 = 27个

从上面的两个案例可以看出, 参数量与输入图像的HxW没有关系, 参数量与输出图像的HxW也没有关系。

案例3

VGG-16为例,conv1-1,第一层

输入224x224x3, 输出是224x224x64,卷积核3x3

思考3个参数:

  • 卷积核大小(HxW) ==》3x3
  • 卷积核维度(D) ==》3
  • 有多少个卷积核 ==》64

卷积核的维度是多少? 是由输入图像的维度决定,这里是3

卷积核的个数是多少? 是由输出图像的维度决定,这里是64

所以参数量 = 3x3x卷积核维度x卷积核个数 = 3x3x3x64 = 27个

Pytorch代码辅助理解

代码

bash 复制代码
nn.Conv2d(in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True, padding_mode='zeros')

案例3中的卷积操作如下:

bash 复制代码
nn.Conv2d(in_channels=3, out_channels=64, kernel_size=3, stride=1, padding=0, dilation=1, groups=1, bias=True, padding_mode='zeros')

参数量计算: = kernel_size * kernel_size * in_channels(卷积核维度) * out_channels(卷积核个数) = 3 * 3 * 3 * 64

stride=1, padding=0, 这两个会影响到输出的HxW,上文已经提到和我们要计算的参数量无关。

最后,补上偏置参数,
每个卷积核都加个偏置 ,所以总得参数量:

参数量计算: = kernel_size * kernel_size * in_channels(卷积核维度) * out_channels(卷积核个数) + bias(=卷积核个数) = 3 * 3 * 3 * 64+64

计算FC全连接层的参数量

先看一段代码,这是我们经常看到的一段代码,先把x解析到1x9的维度,再做全连接操作

python 复制代码
self.fc = nn.Linear(9, 4)

x = x.view(-1, 9) # 把x,解析到1x9的维度,这一个操作是没有权重的
x = self.fc(x) # 做全连接操作

上面的代码对应的操作图,如下

图片来源 | Fully Connected Layer vs. Convolutional Layer: Explained

红色框的参数,就是我们要找的权重参数,有多少个?

思考问题?

  • 输入大小为 N = 9
  • 输出大小为 M =4

计算参数量 = 9x4 = 36个

再看对应的连接图

上图中的每一条连接线(橙色和蓝色的线),都有一个权重参数,共36条,所以有36个参数。

最后,补上偏置参数,

偏置参数数量: 每个输出节点有一个偏置项(bias),因此偏置参数的数量等于输出节点的数量,即 M=4

所以,总的参数数量为N×M+M = 40,即 M 为输出节点数量,N 为输入节点数量。

END


相关推荐
杭州泽沃电子科技有限公司1 小时前
为电气风险定价:如何利用监测数据评估工厂的“电气安全风险指数”?
人工智能·安全
Godspeed Zhao3 小时前
自动驾驶中的传感器技术24.3——Camera(18)
人工智能·机器学习·自动驾驶
顾北124 小时前
MCP协议实战|Spring AI + 高德地图工具集成教程
人工智能
wfeqhfxz25887824 小时前
毒蝇伞品种识别与分类_Centernet模型优化实战
人工智能·分类·数据挖掘
中杯可乐多加冰5 小时前
RAG 深度实践系列(七):从“能用”到“好用”——RAG 系统优化与效果评估
人工智能·大模型·llm·大语言模型·rag·检索增强生成
珠海西格电力科技5 小时前
微电网系统架构设计:并网/孤岛双模式运行与控制策略
网络·人工智能·物联网·系统架构·云计算·智慧城市
FreeBuf_5 小时前
AI扩大攻击面,大国博弈引发安全新挑战
人工智能·安全·chatgpt
weisian1516 小时前
进阶篇-8-数学篇-7--特征值与特征向量:AI特征提取的核心逻辑
人工智能·pca·特征值·特征向量·降维
Java程序员 拥抱ai6 小时前
撰写「从0到1构建下一代游戏AI客服」系列技术博客的初衷
人工智能
186******205316 小时前
AI重构项目开发全流程:效率革命与实践指南
人工智能·重构