Conv2d中groups=2时手动计算及pytorch源码验证

文章目录

  • [1. excel 原理计算](#1. excel 原理计算)
  • [2. pytorch 源码](#2. pytorch 源码)

1. excel 原理计算

2. pytorch 源码

python 复制代码
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.set_printoptions(precision=3, sci_mode=False)

if __name__ == "__main__":
    run_code = 0
    batch_size = 2
    in_channels = 2
    out_channels = 2
    input_h = 3
    input_w = 3
    kernel_h = 2
    kernel_w = 2
    input_total = batch_size * in_channels * input_h * input_w
    input_shape = (batch_size, in_channels, input_w, input_h)
    ke_total = out_channels * in_channels * kernel_w * kernel_h
    ke_shape = (out_channels, in_channels, kernel_w, kernel_h)
    in_matrix = torch.arange(input_total).reshape(input_shape).to(torch.float)
    ke_matrix = torch.arange(ke_total).reshape(ke_shape).to(torch.float)
    print(f"in_matrix.shape=\n{in_matrix.shape}")
    print(f"ke_matrix.shape=\n{ke_matrix.shape}")
    print(f"in_matrix=\n{in_matrix}")
    print(f"ke_matrix=\n{ke_matrix}")
    my_conv2d = nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_h, groups=2,
                          bias=False)
    my_conv2d_weight = my_conv2d.weight
    test_weight = torch.arange(8).reshape((2, 1, 2, 2)).to(torch.float)
    my_conv2d.weight = nn.Parameter(test_weight)
   # print(f"my_conv2d_weight=\n{my_conv2d_weight}")
   # print(f"my_conv2d_weight.shape=\n{my_conv2d_weight.shape}")
    print(f"test_weight=\n{test_weight}")
    out_matrix = my_conv2d(in_matrix)
    print(f"out_matrix=\n{out_matrix}")
    print(f"out_matrix.shape={out_matrix.shape}")
  • 结果:
python 复制代码
in_matrix.shape=
torch.Size([2, 2, 3, 3])
ke_matrix.shape=
torch.Size([2, 2, 2, 2])
in_matrix=
tensor([[[[ 0.,  1.,  2.],
          [ 3.,  4.,  5.],
          [ 6.,  7.,  8.]],

         [[ 9., 10., 11.],
          [12., 13., 14.],
          [15., 16., 17.]]],


        [[[18., 19., 20.],
          [21., 22., 23.],
          [24., 25., 26.]],

         [[27., 28., 29.],
          [30., 31., 32.],
          [33., 34., 35.]]]])
ke_matrix=
tensor([[[[ 0.,  1.],
          [ 2.,  3.]],

         [[ 4.,  5.],
          [ 6.,  7.]]],


        [[[ 8.,  9.],
          [10., 11.]],

         [[12., 13.],
          [14., 15.]]]])
test_weight=
tensor([[[[0., 1.],
          [2., 3.]]],


        [[[4., 5.],
          [6., 7.]]]])
out_matrix=
tensor([[[[ 19.,  25.],
          [ 37.,  43.]],

         [[249., 271.],
          [315., 337.]]],


        [[[127., 133.],
          [145., 151.]],

         [[645., 667.],
          [711., 733.]]]], grad_fn=<ConvolutionBackward0>)
out_matrix.shape=torch.Size([2, 2, 2, 2])
相关推荐
数研小生几秒前
Python自然语言处理:NLTK与Gensim库
开发语言·python·自然语言处理
weixin_509138342 分钟前
智能体认知动力学理论和实践
人工智能·智能体·语义空间·认知动力学
玄同7652 分钟前
机器学习中的三大距离度量:欧式距离、曼哈顿距离、切比雪夫距离详解
人工智能·深度学习·神经网络·目标检测·机器学习·自然语言处理·数据挖掘
第七序章2 分钟前
【Linux学习笔记】初识Linux —— 理解gcc编译器
linux·运维·服务器·开发语言·人工智能·笔记·学习
格林威4 分钟前
Baumer相机水果表皮瘀伤识别:实现无损品质分级的 7 个核心方法,附 OpenCV+Halcon 实战代码!
人工智能·opencv·计算机视觉·视觉检测·工业相机·sdk开发·堡盟相机
rainbow7242444 分钟前
AI证书选型深度分析:如何根据职业目标评估其真正价值
人工智能·机器学习
AI科技星6 分钟前
从ZUFT光速螺旋运动求导推出自然常数e
服务器·人工智能·线性代数·算法·矩阵
love530love9 分钟前
Windows 下 GCC 编译器安装与排错实录
人工智能·windows·python·gcc·msys2·gtk·msys2 mingw 64
倔强的石头1069 分钟前
归纳偏好 —— 机器学习的 “择偶标准”
人工智能·机器学习
zhangshuang-peta10 分钟前
通过MCP实现安全的多渠道人工智能集成
人工智能·ai agent·mcp·peta