Conv2d中groups=2时手动计算及pytorch源码验证

文章目录

  • [1. excel 原理计算](#1. excel 原理计算)
  • [2. pytorch 源码](#2. pytorch 源码)

1. excel 原理计算

2. pytorch 源码

python 复制代码
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.set_printoptions(precision=3, sci_mode=False)

if __name__ == "__main__":
    run_code = 0
    batch_size = 2
    in_channels = 2
    out_channels = 2
    input_h = 3
    input_w = 3
    kernel_h = 2
    kernel_w = 2
    input_total = batch_size * in_channels * input_h * input_w
    input_shape = (batch_size, in_channels, input_w, input_h)
    ke_total = out_channels * in_channels * kernel_w * kernel_h
    ke_shape = (out_channels, in_channels, kernel_w, kernel_h)
    in_matrix = torch.arange(input_total).reshape(input_shape).to(torch.float)
    ke_matrix = torch.arange(ke_total).reshape(ke_shape).to(torch.float)
    print(f"in_matrix.shape=\n{in_matrix.shape}")
    print(f"ke_matrix.shape=\n{ke_matrix.shape}")
    print(f"in_matrix=\n{in_matrix}")
    print(f"ke_matrix=\n{ke_matrix}")
    my_conv2d = nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_h, groups=2,
                          bias=False)
    my_conv2d_weight = my_conv2d.weight
    test_weight = torch.arange(8).reshape((2, 1, 2, 2)).to(torch.float)
    my_conv2d.weight = nn.Parameter(test_weight)
   # print(f"my_conv2d_weight=\n{my_conv2d_weight}")
   # print(f"my_conv2d_weight.shape=\n{my_conv2d_weight.shape}")
    print(f"test_weight=\n{test_weight}")
    out_matrix = my_conv2d(in_matrix)
    print(f"out_matrix=\n{out_matrix}")
    print(f"out_matrix.shape={out_matrix.shape}")
  • 结果:
python 复制代码
in_matrix.shape=
torch.Size([2, 2, 3, 3])
ke_matrix.shape=
torch.Size([2, 2, 2, 2])
in_matrix=
tensor([[[[ 0.,  1.,  2.],
          [ 3.,  4.,  5.],
          [ 6.,  7.,  8.]],

         [[ 9., 10., 11.],
          [12., 13., 14.],
          [15., 16., 17.]]],


        [[[18., 19., 20.],
          [21., 22., 23.],
          [24., 25., 26.]],

         [[27., 28., 29.],
          [30., 31., 32.],
          [33., 34., 35.]]]])
ke_matrix=
tensor([[[[ 0.,  1.],
          [ 2.,  3.]],

         [[ 4.,  5.],
          [ 6.,  7.]]],


        [[[ 8.,  9.],
          [10., 11.]],

         [[12., 13.],
          [14., 15.]]]])
test_weight=
tensor([[[[0., 1.],
          [2., 3.]]],


        [[[4., 5.],
          [6., 7.]]]])
out_matrix=
tensor([[[[ 19.,  25.],
          [ 37.,  43.]],

         [[249., 271.],
          [315., 337.]]],


        [[[127., 133.],
          [145., 151.]],

         [[645., 667.],
          [711., 733.]]]], grad_fn=<ConvolutionBackward0>)
out_matrix.shape=torch.Size([2, 2, 2, 2])
相关推荐
kunge1v51 小时前
学习爬虫第三天:数据提取
前端·爬虫·python·学习
爱学习的小鱼gogo1 小时前
python 矩阵中寻找就接近的目标值 (矩阵-中等)含源码(八)
开发语言·经验分享·python·算法·职场和发展·矩阵
聚客AI1 小时前
系统提示的“消亡”?上下文工程正在重新定义人机交互规则
图像处理·人工智能·pytorch·语言模型·自然语言处理·chatgpt·gpt-3
Hello.Reader1 小时前
Flink 状态模式演进(State Schema Evolution)从原理到落地的一站式指南
python·flink·状态模式
红纸2811 小时前
Subword算法之WordPiece、Unigram与SentencePiece
人工智能·python·深度学习·神经网络·算法·机器学习·自然语言处理
golang学习记1 小时前
Crush:新一代基于Go语言构建的开源 AI 编程CLI工具
人工智能
一车小面包1 小时前
Subword-Based Tokenization策略之BPE与BBPE
人工智能·自然语言处理
红纸2811 小时前
Subword分词方法的BPE与BBPE
人工智能·python·深度学习·神经网络·自然语言处理
zy_destiny1 小时前
【工业场景】用YOLOv8实现反光衣识别
人工智能·python·yolo·机器学习·计算机视觉
zhangjipinggom1 小时前
QwenVL - 202310版-论文阅读
人工智能·深度学习