文章目录
- [1. excel 示意](#1-excel-示意)
- [2. pytorch代码](#2-pytorch代码)
- [3. window mhsa](#3-window-mhsa)
1. excel 示意
将一个三维矩阵按照 window 的大小拆分成多块 2x2 的窗口矩阵,具体如下图所示。
2. pytorch代码
- pytorch源码
python
import torch
import torch.nn as nn
import torch.nn.functional as F
torch.set_printoptions(precision=3, sci_mode=False)
def split_into_windows(patch: torch.Tensor, window_size: int = 2) -> torch.Tensor:
    """Split a 3-D (C, H, W) tensor into non-overlapping square windows.

    Args:
        patch: 3-D tensor; ``F.unfold`` treats a 3-D input as an unbatched
            (C, H, W) image.
        window_size: side length of each square window. H and W should be
            divisible by it (unfold silently drops trailing rows/columns
            otherwise).

    Returns:
        Tensor of shape (C * num_windows, window_size, window_size), where
        num_windows = (H // window_size) * (W // window_size). All windows
        of channel 0 come first, then channel 1, etc.
    """
    channels = patch.shape[0]
    # Unbatched 3-D input -> output of shape (C * window_size**2, num_windows):
    # each column is one flattened window, rows ordered channel-major.
    unfolded = F.unfold(
        input=patch,
        kernel_size=(window_size, window_size),
        stride=(window_size, window_size),
    )
    num_windows = unfolded.shape[-1]
    # Regroup per channel, move the window axis forward, then carve each
    # flattened window back into a 2-D (window_size, window_size) tile.
    windows = unfolded.reshape(channels, window_size * window_size, num_windows)
    windows = windows.transpose(-1, -2)
    return windows.reshape(channels * num_windows, window_size, window_size)


if __name__ == "__main__":
    batch_size = 2
    seq_len = 4
    model_dim = 6
    patch_total = batch_size * seq_len * model_dim
    patch = torch.arange(patch_total).reshape((batch_size, seq_len, model_dim)).to(torch.float32)
    print(f"patch.shape=\n{patch.shape}")
    print(f"patch=\n{patch}")
    patch_nums_final = split_into_windows(patch, window_size=2)
    print(f"patch_nums_final.shape=\n{patch_nums_final.shape}")
    print(f"patch_nums_final=\n{patch_nums_final}")
- 结果:
python
patch.shape=
torch.Size([2, 4, 6])
patch=
tensor([[[ 0., 1., 2., 3., 4., 5.],
[ 6., 7., 8., 9., 10., 11.],
[12., 13., 14., 15., 16., 17.],
[18., 19., 20., 21., 22., 23.]],
[[24., 25., 26., 27., 28., 29.],
[30., 31., 32., 33., 34., 35.],
[36., 37., 38., 39., 40., 41.],
[42., 43., 44., 45., 46., 47.]]])
patch_unfold.shape=
torch.Size([8, 6])
patch_unfold=
tensor([[ 0., 2., 4., 12., 14., 16.],
[ 1., 3., 5., 13., 15., 17.],
[ 6., 8., 10., 18., 20., 22.],
[ 7., 9., 11., 19., 21., 23.],
[24., 26., 28., 36., 38., 40.],
[25., 27., 29., 37., 39., 41.],
[30., 32., 34., 42., 44., 46.],
[31., 33., 35., 43., 45., 47.]])
patch_unfold=
tensor([[ 0., 2., 4., 12., 14., 16.],
[ 1., 3., 5., 13., 15., 17.],
[ 6., 8., 10., 18., 20., 22.],
[ 7., 9., 11., 19., 21., 23.],
[24., 26., 28., 36., 38., 40.],
[25., 27., 29., 37., 39., 41.],
[30., 32., 34., 42., 44., 46.],
[31., 33., 35., 43., 45., 47.]])
patch_nums=
tensor([[[ 0., 2., 4., 12., 14., 16.],
[ 1., 3., 5., 13., 15., 17.],
[ 6., 8., 10., 18., 20., 22.],
[ 7., 9., 11., 19., 21., 23.]],
[[24., 26., 28., 36., 38., 40.],
[25., 27., 29., 37., 39., 41.],
[30., 32., 34., 42., 44., 46.],
[31., 33., 35., 43., 45., 47.]]])
patch_nums_new.shape=
torch.Size([2, 6, 4])
patch_nums_new=
tensor([[[ 0., 1., 6., 7.],
[ 2., 3., 8., 9.],
[ 4., 5., 10., 11.],
[12., 13., 18., 19.],
[14., 15., 20., 21.],
[16., 17., 22., 23.]],
[[24., 25., 30., 31.],
[26., 27., 32., 33.],
[28., 29., 34., 35.],
[36., 37., 42., 43.],
[38., 39., 44., 45.],
[40., 41., 46., 47.]]])
patch_nums_final.shape=
torch.Size([12, 2, 2])
patch_nums_final=
tensor([[[ 0., 1.],
[ 6., 7.]],
[[ 2., 3.],
[ 8., 9.]],
[[ 4., 5.],
[10., 11.]],
[[12., 13.],
[18., 19.]],
[[14., 15.],
[20., 21.]],
[[16., 17.],
[22., 23.]],
[[24., 25.],
[30., 31.]],
[[26., 27.],
[32., 33.]],
[[28., 29.],
[34., 35.]],
[[36., 37.],
[42., 43.]],
[[38., 39.],
[44., 45.]],
[[40., 41.],
[46., 47.]]])
3. window mhsa
- excel 示意图
- pytorch
python
import torch
import torch.nn as nn
import torch.nn.functional as F
torch.set_printoptions(precision=3, sci_mode=False)
def partition_windows(patch_embedding: torch.Tensor, window_size: int = 2) -> torch.Tensor:
    """Partition a patch-embedding sequence into window-local patch groups.

    Args:
        patch_embedding: tensor of shape (bs, num_patch, patch_depth).
            num_patch must be a perfect square — the patches are assumed to
            form a square grid, as in the demo below (16 -> 4x4).
        window_size: side length (in patches) of each square window.

    Returns:
        Tensor of shape (bs * num_windows, window_size**2, patch_depth):
        one (num_patch_in_window, patch_depth) slice per window, ready for
        window-local MHSA.
    """
    bs, num_patch, patch_depth = patch_embedding.shape
    side = int(round(num_patch ** 0.5))  # grid side length; assumes square grid
    num_patch_in_window = window_size * window_size
    # (bs, num_patch, depth) -> (bs, depth, side, side): channels-first image.
    patch = patch_embedding.transpose(-1, -2).reshape(bs, patch_depth, side, side)
    # unfold -> (bs, depth * window_size**2, num_windows); the transpose puts
    # one flattened window per row.
    window = F.unfold(
        patch,
        kernel_size=(window_size, window_size),
        stride=(window_size, window_size),
    ).transpose(-1, -2)
    num_window = window.shape[1]
    # Separate depth from within-window position, then order rows as
    # (position in window, depth).
    window = window.reshape(bs * num_window, patch_depth, num_patch_in_window)
    return window.transpose(-1, -2)


if __name__ == "__main__":
    bs = 2
    num_patch = 16
    patch_depth = 4
    window_size = 2
    patch_total = bs * num_patch * patch_depth
    patch_embedding = torch.arange(patch_total).reshape((bs, num_patch, patch_depth)).to(torch.float32)
    print(f"patch_embedding.shape=\n{patch_embedding.shape}")
    print(f"patch_embedding=\n{patch_embedding}")
    window = partition_windows(patch_embedding, window_size=window_size)
    print(f"window.shape=\n{window.shape}")
    print(f"window=\n{window}")
- 结果:
python
patch_embedding.shape=
torch.Size([2, 16, 4])
patch_embedding=
tensor([[[ 0., 1., 2., 3.],
[ 4., 5., 6., 7.],
[ 8., 9., 10., 11.],
[ 12., 13., 14., 15.],
[ 16., 17., 18., 19.],
[ 20., 21., 22., 23.],
[ 24., 25., 26., 27.],
[ 28., 29., 30., 31.],
[ 32., 33., 34., 35.],
[ 36., 37., 38., 39.],
[ 40., 41., 42., 43.],
[ 44., 45., 46., 47.],
[ 48., 49., 50., 51.],
[ 52., 53., 54., 55.],
[ 56., 57., 58., 59.],
[ 60., 61., 62., 63.]],
[[ 64., 65., 66., 67.],
[ 68., 69., 70., 71.],
[ 72., 73., 74., 75.],
[ 76., 77., 78., 79.],
[ 80., 81., 82., 83.],
[ 84., 85., 86., 87.],
[ 88., 89., 90., 91.],
[ 92., 93., 94., 95.],
[ 96., 97., 98., 99.],
[100., 101., 102., 103.],
[104., 105., 106., 107.],
[108., 109., 110., 111.],
[112., 113., 114., 115.],
[116., 117., 118., 119.],
[120., 121., 122., 123.],
[124., 125., 126., 127.]]])
patch=
tensor([[[[ 0., 4., 8., 12.],
[ 16., 20., 24., 28.],
[ 32., 36., 40., 44.],
[ 48., 52., 56., 60.]],
[[ 1., 5., 9., 13.],
[ 17., 21., 25., 29.],
[ 33., 37., 41., 45.],
[ 49., 53., 57., 61.]],
[[ 2., 6., 10., 14.],
[ 18., 22., 26., 30.],
[ 34., 38., 42., 46.],
[ 50., 54., 58., 62.]],
[[ 3., 7., 11., 15.],
[ 19., 23., 27., 31.],
[ 35., 39., 43., 47.],
[ 51., 55., 59., 63.]]],
[[[ 64., 68., 72., 76.],
[ 80., 84., 88., 92.],
[ 96., 100., 104., 108.],
[112., 116., 120., 124.]],
[[ 65., 69., 73., 77.],
[ 81., 85., 89., 93.],
[ 97., 101., 105., 109.],
[113., 117., 121., 125.]],
[[ 66., 70., 74., 78.],
[ 82., 86., 90., 94.],
[ 98., 102., 106., 110.],
[114., 118., 122., 126.]],
[[ 67., 71., 75., 79.],
[ 83., 87., 91., 95.],
[ 99., 103., 107., 111.],
[115., 119., 123., 127.]]]])
window.shape=
torch.Size([2, 4, 16])
window=
tensor([[[ 0., 4., 16., 20., 1., 5., 17., 21., 2., 6., 18.,
22., 3., 7., 19., 23.],
[ 8., 12., 24., 28., 9., 13., 25., 29., 10., 14., 26.,
30., 11., 15., 27., 31.],
[ 32., 36., 48., 52., 33., 37., 49., 53., 34., 38., 50.,
54., 35., 39., 51., 55.],
[ 40., 44., 56., 60., 41., 45., 57., 61., 42., 46., 58.,
62., 43., 47., 59., 63.]],
[[ 64., 68., 80., 84., 65., 69., 81., 85., 66., 70., 82.,
86., 67., 71., 83., 87.],
[ 72., 76., 88., 92., 73., 77., 89., 93., 74., 78., 90.,
94., 75., 79., 91., 95.],
[ 96., 100., 112., 116., 97., 101., 113., 117., 98., 102., 114.,
118., 99., 103., 115., 119.],
[104., 108., 120., 124., 105., 109., 121., 125., 106., 110., 122.,
126., 107., 111., 123., 127.]]])
window.shape=
torch.Size([8, 4, 4])
window=
tensor([[[ 0., 1., 2., 3.],
[ 4., 5., 6., 7.],
[ 16., 17., 18., 19.],
[ 20., 21., 22., 23.]],
[[ 8., 9., 10., 11.],
[ 12., 13., 14., 15.],
[ 24., 25., 26., 27.],
[ 28., 29., 30., 31.]],
[[ 32., 33., 34., 35.],
[ 36., 37., 38., 39.],
[ 48., 49., 50., 51.],
[ 52., 53., 54., 55.]],
[[ 40., 41., 42., 43.],
[ 44., 45., 46., 47.],
[ 56., 57., 58., 59.],
[ 60., 61., 62., 63.]],
[[ 64., 65., 66., 67.],
[ 68., 69., 70., 71.],
[ 80., 81., 82., 83.],
[ 84., 85., 86., 87.]],
[[ 72., 73., 74., 75.],
[ 76., 77., 78., 79.],
[ 88., 89., 90., 91.],
[ 92., 93., 94., 95.]],
[[ 96., 97., 98., 99.],
[100., 101., 102., 103.],
[112., 113., 114., 115.],
[116., 117., 118., 119.]],
[[104., 105., 106., 107.],
[108., 109., 110., 111.],
[120., 121., 122., 123.],
[124., 125., 126., 127.]]])