# Example (Python): grouped 2-D convolution with nn.Conv2d.
import torch
import torch.nn as nn
import torch.nn.functional as F
# Print tensors with 3 decimal places and without scientific notation.
torch.set_printoptions(precision=3, sci_mode=False)


def grouped_conv2d_demo(batch_size=2, in_channels=2, out_channels=2,
                        input_h=3, input_w=3, kernel_h=2, kernel_w=2,
                        groups=2, verbose=True):
    """Run a bias-free grouped ``nn.Conv2d`` on deterministic ``arange`` data.

    The input is filled with consecutive ``arange`` values in NCHW order and
    the layer's weight is overwritten with consecutive ``arange`` values, so
    every output element can be verified by hand.

    Args:
        batch_size: Number of samples (N).
        in_channels: Input channels (C_in); must be divisible by ``groups``.
        out_channels: Output channels (C_out); must be divisible by ``groups``.
        input_h: Input height (H).
        input_w: Input width (W).
        kernel_h: Kernel height.
        kernel_w: Kernel width.
        groups: Number of channel groups passed to ``nn.Conv2d``.
        verbose: If True, print the input, the dense reference kernel, the
            grouped test weight and the output.

    Returns:
        The convolution output, shaped
        ``(batch_size, out_channels, input_h - kernel_h + 1,
        input_w - kernel_w + 1)`` (``nn.Conv2d`` defaults: stride 1,
        no padding).

    Raises:
        ValueError: Raised by ``nn.Conv2d`` when the channel counts are not
            divisible by ``groups``.
    """
    # NCHW layout: height must come before width.
    input_shape = (batch_size, in_channels, input_h, input_w)
    in_matrix = torch.arange(
        batch_size * in_channels * input_h * input_w, dtype=torch.float
    ).reshape(input_shape)

    # Dense (groups=1) kernel layout (C_out, C_in, kH, kW). It is only
    # printed, for comparison with the smaller grouped weight used below.
    ke_shape = (out_channels, in_channels, kernel_h, kernel_w)
    ke_matrix = torch.arange(
        out_channels * in_channels * kernel_h * kernel_w, dtype=torch.float
    ).reshape(ke_shape)

    my_conv2d = nn.Conv2d(
        in_channels=in_channels,
        out_channels=out_channels,
        kernel_size=(kernel_h, kernel_w),
        groups=groups,
        bias=False,
    )
    # A grouped conv's weight is (C_out, C_in // groups, kH, kW). Derive the
    # test weight from the layer's own weight so the shapes always agree.
    weight_shape = my_conv2d.weight.shape
    test_weight = torch.arange(
        my_conv2d.weight.numel(), dtype=torch.float
    ).reshape(weight_shape)
    my_conv2d.weight = nn.Parameter(test_weight)

    out_matrix = my_conv2d(in_matrix)

    if verbose:
        print(f"in_matrix.shape=\n{in_matrix.shape}")
        print(f"ke_matrix.shape=\n{ke_matrix.shape}")
        print(f"in_matrix=\n{in_matrix}")
        print(f"ke_matrix=\n{ke_matrix}")
        print(f"test_weight=\n{test_weight}")
        print(f"out_matrix=\n{out_matrix}")
        print(f"out_matrix.shape={out_matrix.shape}")
    return out_matrix


if __name__ == "__main__":
    grouped_conv2d_demo()
# Output:
#
# in_matrix.shape=
# torch.Size([2, 2, 3, 3])
# ke_matrix.shape=
# torch.Size([2, 2, 2, 2])
# in_matrix=
# tensor([[[[ 0., 1., 2.],
# [ 3., 4., 5.],
# [ 6., 7., 8.]],
# [[ 9., 10., 11.],
# [12., 13., 14.],
# [15., 16., 17.]]],
# [[[18., 19., 20.],
# [21., 22., 23.],
# [24., 25., 26.]],
# [[27., 28., 29.],
# [30., 31., 32.],
# [33., 34., 35.]]]])
# ke_matrix=
# tensor([[[[ 0., 1.],
# [ 2., 3.]],
# [[ 4., 5.],
# [ 6., 7.]]],
# [[[ 8., 9.],
# [10., 11.]],
# [[12., 13.],
# [14., 15.]]]])
# test_weight=
# tensor([[[[0., 1.],
# [2., 3.]]],
# [[[4., 5.],
# [6., 7.]]]])
# out_matrix=
# tensor([[[[ 19., 25.],
# [ 37., 43.]],
# [[249., 271.],
# [315., 337.]]],
# [[[127., 133.],
# [145., 151.]],
# [[645., 667.],
# [711., 733.]]]], grad_fn=<ConvolutionBackward0>)
# out_matrix.shape=torch.Size([2, 2, 2, 2])