Conv2d中groups=2时手动计算及pytorch源码验证

文章目录

  • [1. excel 原理计算](#1. excel 原理计算)
  • [2. pytorch 源码](#2. pytorch 源码)

1. excel 原理计算

2. pytorch 源码

python 复制代码
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.set_printoptions(precision=3, sci_mode=False)

if __name__ == "__main__":
    run_code = 0
    batch_size = 2
    in_channels = 2
    out_channels = 2
    input_h = 3
    input_w = 3
    kernel_h = 2
    kernel_w = 2
    input_total = batch_size * in_channels * input_h * input_w
    input_shape = (batch_size, in_channels, input_w, input_h)
    ke_total = out_channels * in_channels * kernel_w * kernel_h
    ke_shape = (out_channels, in_channels, kernel_w, kernel_h)
    in_matrix = torch.arange(input_total).reshape(input_shape).to(torch.float)
    ke_matrix = torch.arange(ke_total).reshape(ke_shape).to(torch.float)
    print(f"in_matrix.shape=\n{in_matrix.shape}")
    print(f"ke_matrix.shape=\n{ke_matrix.shape}")
    print(f"in_matrix=\n{in_matrix}")
    print(f"ke_matrix=\n{ke_matrix}")
    my_conv2d = nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_h, groups=2,
                          bias=False)
    my_conv2d_weight = my_conv2d.weight
    test_weight = torch.arange(8).reshape((2, 1, 2, 2)).to(torch.float)
    my_conv2d.weight = nn.Parameter(test_weight)
   # print(f"my_conv2d_weight=\n{my_conv2d_weight}")
   # print(f"my_conv2d_weight.shape=\n{my_conv2d_weight.shape}")
    print(f"test_weight=\n{test_weight}")
    out_matrix = my_conv2d(in_matrix)
    print(f"out_matrix=\n{out_matrix}")
    print(f"out_matrix.shape={out_matrix.shape}")
  • 结果:
python 复制代码
in_matrix.shape=
torch.Size([2, 2, 3, 3])
ke_matrix.shape=
torch.Size([2, 2, 2, 2])
in_matrix=
tensor([[[[ 0.,  1.,  2.],
          [ 3.,  4.,  5.],
          [ 6.,  7.,  8.]],

         [[ 9., 10., 11.],
          [12., 13., 14.],
          [15., 16., 17.]]],


        [[[18., 19., 20.],
          [21., 22., 23.],
          [24., 25., 26.]],

         [[27., 28., 29.],
          [30., 31., 32.],
          [33., 34., 35.]]]])
ke_matrix=
tensor([[[[ 0.,  1.],
          [ 2.,  3.]],

         [[ 4.,  5.],
          [ 6.,  7.]]],


        [[[ 8.,  9.],
          [10., 11.]],

         [[12., 13.],
          [14., 15.]]]])
test_weight=
tensor([[[[0., 1.],
          [2., 3.]]],


        [[[4., 5.],
          [6., 7.]]]])
out_matrix=
tensor([[[[ 19.,  25.],
          [ 37.,  43.]],

         [[249., 271.],
          [315., 337.]]],


        [[[127., 133.],
          [145., 151.]],

         [[645., 667.],
          [711., 733.]]]], grad_fn=<ConvolutionBackward0>)
out_matrix.shape=torch.Size([2, 2, 2, 2])
相关推荐
美酒没故事°1 天前
Open WebUI安装指南。搭建自己的自托管 AI 平台
人工智能·windows·ai
云烟成雨TD1 天前
Spring AI Alibaba 1.x 系列【6】ReactAgent 同步执行 & 流式执行
java·人工智能·spring
Csvn1 天前
🌟 LangChain 30 天保姆级教程 · Day 13|OutputParser 进阶!让 AI 输出自动转为结构化对象,并支持自动重试!
python·langchain
AI攻城狮1 天前
用 Obsidian CLI + LLM 构建本地 RAG:让你的笔记真正「活」起来
人工智能·云原生·aigc
鸿乃江边鸟1 天前
Nanobot 从onboard启动命令来看个人助理Agent的实现
人工智能·ai
lpfasd1231 天前
基于Cloudflare生态的应用部署与开发全解
人工智能·agent·cloudflare
俞凡1 天前
DevOps 2.0:智能体如何接管故障修复和基础设施维护
人工智能
comedate1 天前
[OpenClaw] GLM 5 关于电影 - 人工智能 - 的思考
人工智能·电影评价
财迅通Ai1 天前
6000万吨产能承压 卫星化学迎来战略窗口期
大数据·人工智能·物联网·卫星化学
liliangcsdn1 天前
Agent Memory智能体记忆系统的示例分析
数据库·人工智能·全文检索