Conv2d中groups=2时手动计算及pytorch源码验证

文章目录

  • [1. excel 原理计算](#1. excel 原理计算)
  • [2. pytorch 源码](#2. pytorch 源码)

1. excel 原理计算

2. pytorch 源码

python 复制代码
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.set_printoptions(precision=3, sci_mode=False)

if __name__ == "__main__":
    run_code = 0
    batch_size = 2
    in_channels = 2
    out_channels = 2
    input_h = 3
    input_w = 3
    kernel_h = 2
    kernel_w = 2
    input_total = batch_size * in_channels * input_h * input_w
    input_shape = (batch_size, in_channels, input_w, input_h)
    ke_total = out_channels * in_channels * kernel_w * kernel_h
    ke_shape = (out_channels, in_channels, kernel_w, kernel_h)
    in_matrix = torch.arange(input_total).reshape(input_shape).to(torch.float)
    ke_matrix = torch.arange(ke_total).reshape(ke_shape).to(torch.float)
    print(f"in_matrix.shape=\n{in_matrix.shape}")
    print(f"ke_matrix.shape=\n{ke_matrix.shape}")
    print(f"in_matrix=\n{in_matrix}")
    print(f"ke_matrix=\n{ke_matrix}")
    my_conv2d = nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_h, groups=2,
                          bias=False)
    my_conv2d_weight = my_conv2d.weight
    test_weight = torch.arange(8).reshape((2, 1, 2, 2)).to(torch.float)
    my_conv2d.weight = nn.Parameter(test_weight)
   # print(f"my_conv2d_weight=\n{my_conv2d_weight}")
   # print(f"my_conv2d_weight.shape=\n{my_conv2d_weight.shape}")
    print(f"test_weight=\n{test_weight}")
    out_matrix = my_conv2d(in_matrix)
    print(f"out_matrix=\n{out_matrix}")
    print(f"out_matrix.shape={out_matrix.shape}")
  • 结果:
python 复制代码
in_matrix.shape=
torch.Size([2, 2, 3, 3])
ke_matrix.shape=
torch.Size([2, 2, 2, 2])
in_matrix=
tensor([[[[ 0.,  1.,  2.],
          [ 3.,  4.,  5.],
          [ 6.,  7.,  8.]],

         [[ 9., 10., 11.],
          [12., 13., 14.],
          [15., 16., 17.]]],


        [[[18., 19., 20.],
          [21., 22., 23.],
          [24., 25., 26.]],

         [[27., 28., 29.],
          [30., 31., 32.],
          [33., 34., 35.]]]])
ke_matrix=
tensor([[[[ 0.,  1.],
          [ 2.,  3.]],

         [[ 4.,  5.],
          [ 6.,  7.]]],


        [[[ 8.,  9.],
          [10., 11.]],

         [[12., 13.],
          [14., 15.]]]])
test_weight=
tensor([[[[0., 1.],
          [2., 3.]]],


        [[[4., 5.],
          [6., 7.]]]])
out_matrix=
tensor([[[[ 19.,  25.],
          [ 37.,  43.]],

         [[249., 271.],
          [315., 337.]]],


        [[[127., 133.],
          [145., 151.]],

         [[645., 667.],
          [711., 733.]]]], grad_fn=<ConvolutionBackward0>)
out_matrix.shape=torch.Size([2, 2, 2, 2])
相关推荐
马拉AI2 分钟前
解锁Nature发文小Tips:LSTM、CNN与Attention的创新融合之路
人工智能·cnn·lstm
sufu10653 分钟前
SpringAI更新:废弃tools方法、正式支持DeepSeek!
人工智能·后端
知舟不叙18 分钟前
基于OpenCV中的图像拼接方法详解
人工智能·opencv·计算机视觉·图像拼接
Jamence21 分钟前
多模态大语言模型arxiv论文略读(七十五)
人工智能·语言模型·自然语言处理
Code_流苏24 分钟前
《Python星球日记》 第71天:命名实体识别(NER)与关系抽取
python·深度学习·ner·预训练语言模型·关系抽取·统计机器学习·标注方式
点云SLAM24 分钟前
Python中列表(list)知识详解(2)和注意事项以及应用示例
开发语言·人工智能·python·python学习·数据结果·list数据结果
放飞自我的Coder24 分钟前
【NLP 计算句子之间的BLEU和ROUGE分数】
人工智能·自然语言处理
国强_dev24 分钟前
任意复杂度的 JSON 数据转换为多个结构化的 Pandas DataFrame 表格
开发语言·python
小众AI26 分钟前
UI-TARS: 基于视觉语言模型的多模式代理
人工智能·ui·语言模型
北京地铁1号线43 分钟前
卷积神经网络(CNN)前向传播手撕
人工智能·pytorch·深度学习