计算CNN卷积层和全连接层的参数量

计算CNN卷积层和全连接层的参数量

先前阅读

本文主旨意在搞明白2个问题:
第一个问题

一个卷积操作,他的参数,也就是我们要训练的参数,也就是我们说的权重,有多少个? 看到一个nn.Conv()函数,就能知道有多少个,它由那些因子决定的?

参数量是由以下3个因子决定的:

  • 卷积核大小(HxW)
  • 卷积核维度(D)
  • 卷积核有多少个

则卷积层的参数量为 卷积核大小(HxW) * 卷积核维度(D) * 卷积核有多少个

第二个问题

一个全连接操作,参数又有多少个?它由那些因子决定的?

  • 输入大小为 N
  • 输出大小为 M

则全连接层的参数量为 N×M

计算CNN卷积层的参数量

案例1

动态演示

看上图案例1的计算,输入图像为 5x5x1, 卷积核3x3x1, 输出3x3x1;

思考3个参数:

  • 卷积核大小(HxW) ==》3x3
  • 卷积核维度(D) ==》1
  • 卷积核有多少个 ==》1

参数量为 3x3x1x1 = 9个

案例2

看上图案例2的计算,输入图像为 H1xW1x3, 卷积核3x3x3, 输出H2xW2x1;

思考3个参数:

  • 卷积核大小(HxW) ==》3x3
  • 卷积核维度(D) ==》3
  • 卷积核有多少个 ==》1

参数量为 3x3x3x1 = 27个

从上面的两个案例可以看出, 参数量与输入图像的HxW没有关系, 参数量与输出图像的HxW也没有关系。

案例3

VGG-16为例,conv1-1,第一层

输入224x224x3, 输出是224x224x64,卷积核3x3

思考3个参数:

  • 卷积核大小(HxW) ==》3x3
  • 卷积核维度(D) ==》3
  • 有多少个卷积核 ==》64

卷积核的维度是多少? 是由输入图像的维度决定,这里是3

卷积核的个数是多少? 是由输出图像的维度决定,这里是64

所以参数量 = 3x3x卷积核维度x卷积核个数 = 3x3x3x64 = 27个

Pytorch代码辅助理解

代码

bash 复制代码
nn.Conv2d(in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True, padding_mode='zeros')

案例3中的卷积操作如下:

bash 复制代码
nn.Conv2d(in_channels=3, out_channels=64, kernel_size=3, stride=1, padding=0, dilation=1, groups=1, bias=True, padding_mode='zeros')

参数量计算: = kernel_size * kernel_size * in_channels(卷积核维度) * out_channels(卷积核个数) = 3 * 3 * 3 * 64

stride=1, padding=0, 这两个会影响到输出的HxW,上文已经提到和我们要计算的参数量无关。

最后,补上偏置参数,
每个卷积核都加个偏置 ,所以总得参数量:

参数量计算: = kernel_size * kernel_size * in_channels(卷积核维度) * out_channels(卷积核个数) + bias(=卷积核个数) = 3 * 3 * 3 * 64+64

计算FC全连接层的参数量

先看一段代码,这是我们经常看到的一段代码,先把x解析到1x9的维度,再做全连接操作

python 复制代码
self.fc = nn.Linear(9, 4)

x = x.view(-1, 9) # 把x,解析到1x9的维度,这一个操作是没有权重的
x = self.fc(x) # 做全连接操作

上面的代码对应的操作图,如下

图片来源 | Fully Connected Layer vs. Convolutional Layer: Explained

红色框的参数,就是我们要找的权重参数,有多少个?

思考问题?

  • 输入大小为 N = 9
  • 输出大小为 M =4

计算参数量 = 9x4 = 36个

再看对应的连接图

上图中的每一条连接线(橙色和蓝色的线),都有一个权重参数,共36条,所以有36个参数。

最后,补上偏置参数,

偏置参数数量: 每个输出节点有一个偏置项(bias),因此偏置参数的数量等于输出节点的数量,即 M=4

所以,总的参数数量为N×M+M = 40,即 M 为输出节点数量,N 为输入节点数量。

END


相关推荐
EasyCVR21 分钟前
GA/T1400视图库平台EasyCVR视频融合平台HLS视频协议是什么?
服务器·网络·人工智能·音视频
V搜xhliang024622 分钟前
基于深度学习的地物类型的提取
开发语言·人工智能·python·深度学习·神经网络·学习·conda
青椒大仙KI1136 分钟前
24/11/14 算法笔记<强化学习> 马尔可夫
人工智能·笔记·机器学习
GOTXX1 小时前
NAT、代理服务与内网穿透技术全解析
linux·网络·人工智能·计算机网络·智能路由器
进击的小小学生1 小时前
2024年第45周ETF周报
大数据·人工智能
TaoYuan__2 小时前
机器学习【激活函数】
人工智能·机器学习
TaoYuan__2 小时前
机器学习的常用算法
人工智能·算法·机器学习
正义的彬彬侠2 小时前
协方差矩阵及其计算方法
人工智能·机器学习·协方差·协方差矩阵
致Great2 小时前
Invar-RAG:基于不变性对齐的LLM检索方法提升生成质量
人工智能·大模型·rag
华奥系科技2 小时前
智慧安防丨以科技之力,筑起防范人贩的铜墙铁壁
人工智能·科技·安全·生活