Conv2d中groups=2时手动计算及pytorch源码验证

文章目录

  • [1. excel 原理计算](#1. excel 原理计算)
  • [2. pytorch 源码](#2. pytorch 源码)

1. excel 原理计算

2. pytorch 源码

python 复制代码
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.set_printoptions(precision=3, sci_mode=False)

if __name__ == "__main__":
    run_code = 0
    batch_size = 2
    in_channels = 2
    out_channels = 2
    input_h = 3
    input_w = 3
    kernel_h = 2
    kernel_w = 2
    input_total = batch_size * in_channels * input_h * input_w
    input_shape = (batch_size, in_channels, input_w, input_h)
    ke_total = out_channels * in_channels * kernel_w * kernel_h
    ke_shape = (out_channels, in_channels, kernel_w, kernel_h)
    in_matrix = torch.arange(input_total).reshape(input_shape).to(torch.float)
    ke_matrix = torch.arange(ke_total).reshape(ke_shape).to(torch.float)
    print(f"in_matrix.shape=\n{in_matrix.shape}")
    print(f"ke_matrix.shape=\n{ke_matrix.shape}")
    print(f"in_matrix=\n{in_matrix}")
    print(f"ke_matrix=\n{ke_matrix}")
    my_conv2d = nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_h, groups=2,
                          bias=False)
    my_conv2d_weight = my_conv2d.weight
    test_weight = torch.arange(8).reshape((2, 1, 2, 2)).to(torch.float)
    my_conv2d.weight = nn.Parameter(test_weight)
   # print(f"my_conv2d_weight=\n{my_conv2d_weight}")
   # print(f"my_conv2d_weight.shape=\n{my_conv2d_weight.shape}")
    print(f"test_weight=\n{test_weight}")
    out_matrix = my_conv2d(in_matrix)
    print(f"out_matrix=\n{out_matrix}")
    print(f"out_matrix.shape={out_matrix.shape}")
  • 结果:
python 复制代码
in_matrix.shape=
torch.Size([2, 2, 3, 3])
ke_matrix.shape=
torch.Size([2, 2, 2, 2])
in_matrix=
tensor([[[[ 0.,  1.,  2.],
          [ 3.,  4.,  5.],
          [ 6.,  7.,  8.]],

         [[ 9., 10., 11.],
          [12., 13., 14.],
          [15., 16., 17.]]],


        [[[18., 19., 20.],
          [21., 22., 23.],
          [24., 25., 26.]],

         [[27., 28., 29.],
          [30., 31., 32.],
          [33., 34., 35.]]]])
ke_matrix=
tensor([[[[ 0.,  1.],
          [ 2.,  3.]],

         [[ 4.,  5.],
          [ 6.,  7.]]],


        [[[ 8.,  9.],
          [10., 11.]],

         [[12., 13.],
          [14., 15.]]]])
test_weight=
tensor([[[[0., 1.],
          [2., 3.]]],


        [[[4., 5.],
          [6., 7.]]]])
out_matrix=
tensor([[[[ 19.,  25.],
          [ 37.,  43.]],

         [[249., 271.],
          [315., 337.]]],


        [[[127., 133.],
          [145., 151.]],

         [[645., 667.],
          [711., 733.]]]], grad_fn=<ConvolutionBackward0>)
out_matrix.shape=torch.Size([2, 2, 2, 2])
相关推荐
说私域34 分钟前
基于开源AI智能名片链动2+1模式S2B2C商城小程序的超级文化符号构建路径研究
人工智能·小程序·开源
永洪科技36 分钟前
永洪科技荣获商业智能品牌影响力奖,全力打造”AI+决策”引擎
大数据·人工智能·科技·数据分析·数据可视化·bi
shangyingying_11 小时前
关于小波降噪、小波增强、小波去雾的原理区分
人工智能·深度学习·计算机视觉
小赖同学啊1 小时前
物联网数据安全区块链服务
开发语言·python·区块链
码荼2 小时前
学习开发之hashmap
java·python·学习·哈希算法·个人开发·小白学开发·不花钱不花时间crud
书玮嘎2 小时前
【WIP】【VLA&VLM——InternVL系列】
人工智能·深度学习
猫头虎2 小时前
猫头虎 AI工具分享:一个网页抓取、结构化数据提取、网页爬取、浏览器自动化操作工具:Hyperbrowser MCP
运维·人工智能·gpt·开源·自动化·文心一言·ai编程
要努力啊啊啊2 小时前
YOLOv2 正负样本分配机制详解
人工智能·深度学习·yolo·计算机视觉·目标跟踪
CareyWYR2 小时前
大模型真的能做推荐系统吗?ARAG论文给了我一个颠覆性的答案
人工智能