相对位置2d矩阵和kron运算的思考

文章目录

  • [1. 相对位置矩阵2d](#1. 相对位置矩阵2d)
  • [2. kron运算](#2. kron运算)

1. 相对位置矩阵2d

在swin-transformer中,我们会计算每个patch之间的相对位置,那么我们看到有一连串的拉伸和相减,直接贴代码:

python 复制代码
import torch
import torch.nn as nn

torch.set_printoptions(precision=3, sci_mode=False,threshold=torch.inf)

if __name__ == "__main__":
    run_code = 2
    x_len = 5
    y_len = 5
    x_tensor = torch.arange(x_len)
    y_tensor = torch.arange(y_len)
    x_meshgrid, y_meshgrid = torch.meshgrid(x_tensor, y_tensor)
    print(f"x_tensor=\n{x_tensor}")
    print(f"y_tensor=\n{y_tensor}")
    print(f"x_meshgrid=\n{x_meshgrid}")
    print(f"x_meshgrid.shape=\n{x_meshgrid.shape}")
    print(f"y_meshgrid.shape=\n{y_meshgrid.shape}")
    print(f"y_meshgrid=\n{y_meshgrid}")
    stack_meshgrid = torch.stack(torch.meshgrid(x_tensor, y_tensor))
    print(f"stack_meshgrid.shape=\n{stack_meshgrid.shape}")
    print(f"stack_meshgrid=\n{stack_meshgrid}")
    stack_meshgrid_flatten = torch.flatten(stack_meshgrid, 1)
    print(f"stack_meshgrid_flatten.shape=\n{stack_meshgrid_flatten.shape}")
    print(f"stack_meshgrid_flatten=\n{stack_meshgrid_flatten}")
    stack_meshgrid_flatten_1 = stack_meshgrid_flatten[:, None, :]
    stack_meshgrid_flatten_2 = stack_meshgrid_flatten[:, :, None]
    relative_coords_bias = stack_meshgrid_flatten_2 - stack_meshgrid_flatten_1
    print(f"stack_meshgrid_flatten_1=\n{stack_meshgrid_flatten_1}")
    print(f"stack_meshgrid_flatten_2=\n{stack_meshgrid_flatten_2}")
    print(f"relative_coords_bias=\n{relative_coords_bias}")
    relative_coords_bias[0, :, :] += x_len
    relative_coords_bias[1, :, :] += y_len
    print(f"relative_coords_bias=\n{relative_coords_bias}")
  • result:
python 复制代码
x_tensor=
tensor([0, 1, 2, 3, 4])
y_tensor=
tensor([0, 1, 2, 3, 4])
x_meshgrid=
tensor([[0, 0, 0, 0, 0],
        [1, 1, 1, 1, 1],
        [2, 2, 2, 2, 2],
        [3, 3, 3, 3, 3],
        [4, 4, 4, 4, 4]])
x_meshgrid.shape=
torch.Size([5, 5])
y_meshgrid.shape=
torch.Size([5, 5])
y_meshgrid=
tensor([[0, 1, 2, 3, 4],
        [0, 1, 2, 3, 4],
        [0, 1, 2, 3, 4],
        [0, 1, 2, 3, 4],
        [0, 1, 2, 3, 4]])
stack_meshgrid.shape=
torch.Size([2, 5, 5])
stack_meshgrid=
tensor([[[0, 0, 0, 0, 0],
         [1, 1, 1, 1, 1],
         [2, 2, 2, 2, 2],
         [3, 3, 3, 3, 3],
         [4, 4, 4, 4, 4]],

        [[0, 1, 2, 3, 4],
         [0, 1, 2, 3, 4],
         [0, 1, 2, 3, 4],
         [0, 1, 2, 3, 4],
         [0, 1, 2, 3, 4]]])
stack_meshgrid_flatten.shape=
torch.Size([2, 25])
stack_meshgrid_flatten=
tensor([[0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 4, 4, 4, 4,
         4],
        [0, 1, 2, 3, 4, 0, 1, 2, 3, 4, 0, 1, 2, 3, 4, 0, 1, 2, 3, 4, 0, 1, 2, 3,
         4]])
stack_meshgrid_flatten_1=
tensor([[[0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 4, 4, 4,
          4, 4]],

        [[0, 1, 2, 3, 4, 0, 1, 2, 3, 4, 0, 1, 2, 3, 4, 0, 1, 2, 3, 4, 0, 1, 2,
          3, 4]]])
stack_meshgrid_flatten_2=
tensor([[[0],
         [0],
         [0],
         [0],
         [0],
         [1],
         [1],
         [1],
         [1],
         [1],
         [2],
         [2],
         [2],
         [2],
         [2],
         [3],
         [3],
         [3],
         [3],
         [3],
         [4],
         [4],
         [4],
         [4],
         [4]],

        [[0],
         [1],
         [2],
         [3],
         [4],
         [0],
         [1],
         [2],
         [3],
         [4],
         [0],
         [1],
         [2],
         [3],
         [4],
         [0],
         [1],
         [2],
         [3],
         [4],
         [0],
         [1],
         [2],
         [3],
         [4]]])
relative_coords_bias=
tensor([[[ 0,  0,  0,  0,  0, -1, -1, -1, -1, -1, -2, -2, -2, -2, -2, -3, -3,
          -3, -3, -3, -4, -4, -4, -4, -4],
         [ 0,  0,  0,  0,  0, -1, -1, -1, -1, -1, -2, -2, -2, -2, -2, -3, -3,
          -3, -3, -3, -4, -4, -4, -4, -4],
         [ 0,  0,  0,  0,  0, -1, -1, -1, -1, -1, -2, -2, -2, -2, -2, -3, -3,
          -3, -3, -3, -4, -4, -4, -4, -4],
         [ 0,  0,  0,  0,  0, -1, -1, -1, -1, -1, -2, -2, -2, -2, -2, -3, -3,
          -3, -3, -3, -4, -4, -4, -4, -4],
         [ 0,  0,  0,  0,  0, -1, -1, -1, -1, -1, -2, -2, -2, -2, -2, -3, -3,
          -3, -3, -3, -4, -4, -4, -4, -4],
         [ 1,  1,  1,  1,  1,  0,  0,  0,  0,  0, -1, -1, -1, -1, -1, -2, -2,
          -2, -2, -2, -3, -3, -3, -3, -3],
         [ 1,  1,  1,  1,  1,  0,  0,  0,  0,  0, -1, -1, -1, -1, -1, -2, -2,
          -2, -2, -2, -3, -3, -3, -3, -3],
         [ 1,  1,  1,  1,  1,  0,  0,  0,  0,  0, -1, -1, -1, -1, -1, -2, -2,
          -2, -2, -2, -3, -3, -3, -3, -3],
         [ 1,  1,  1,  1,  1,  0,  0,  0,  0,  0, -1, -1, -1, -1, -1, -2, -2,
          -2, -2, -2, -3, -3, -3, -3, -3],
         [ 1,  1,  1,  1,  1,  0,  0,  0,  0,  0, -1, -1, -1, -1, -1, -2, -2,
          -2, -2, -2, -3, -3, -3, -3, -3],
         [ 2,  2,  2,  2,  2,  1,  1,  1,  1,  1,  0,  0,  0,  0,  0, -1, -1,
          -1, -1, -1, -2, -2, -2, -2, -2],
         [ 2,  2,  2,  2,  2,  1,  1,  1,  1,  1,  0,  0,  0,  0,  0, -1, -1,
          -1, -1, -1, -2, -2, -2, -2, -2],
         [ 2,  2,  2,  2,  2,  1,  1,  1,  1,  1,  0,  0,  0,  0,  0, -1, -1,
          -1, -1, -1, -2, -2, -2, -2, -2],
         [ 2,  2,  2,  2,  2,  1,  1,  1,  1,  1,  0,  0,  0,  0,  0, -1, -1,
          -1, -1, -1, -2, -2, -2, -2, -2],
         [ 2,  2,  2,  2,  2,  1,  1,  1,  1,  1,  0,  0,  0,  0,  0, -1, -1,
          -1, -1, -1, -2, -2, -2, -2, -2],
         [ 3,  3,  3,  3,  3,  2,  2,  2,  2,  2,  1,  1,  1,  1,  1,  0,  0,
           0,  0,  0, -1, -1, -1, -1, -1],
         [ 3,  3,  3,  3,  3,  2,  2,  2,  2,  2,  1,  1,  1,  1,  1,  0,  0,
           0,  0,  0, -1, -1, -1, -1, -1],
         [ 3,  3,  3,  3,  3,  2,  2,  2,  2,  2,  1,  1,  1,  1,  1,  0,  0,
           0,  0,  0, -1, -1, -1, -1, -1],
         [ 3,  3,  3,  3,  3,  2,  2,  2,  2,  2,  1,  1,  1,  1,  1,  0,  0,
           0,  0,  0, -1, -1, -1, -1, -1],
         [ 3,  3,  3,  3,  3,  2,  2,  2,  2,  2,  1,  1,  1,  1,  1,  0,  0,
           0,  0,  0, -1, -1, -1, -1, -1],
         [ 4,  4,  4,  4,  4,  3,  3,  3,  3,  3,  2,  2,  2,  2,  2,  1,  1,
           1,  1,  1,  0,  0,  0,  0,  0],
         [ 4,  4,  4,  4,  4,  3,  3,  3,  3,  3,  2,  2,  2,  2,  2,  1,  1,
           1,  1,  1,  0,  0,  0,  0,  0],
         [ 4,  4,  4,  4,  4,  3,  3,  3,  3,  3,  2,  2,  2,  2,  2,  1,  1,
           1,  1,  1,  0,  0,  0,  0,  0],
         [ 4,  4,  4,  4,  4,  3,  3,  3,  3,  3,  2,  2,  2,  2,  2,  1,  1,
           1,  1,  1,  0,  0,  0,  0,  0],
         [ 4,  4,  4,  4,  4,  3,  3,  3,  3,  3,  2,  2,  2,  2,  2,  1,  1,
           1,  1,  1,  0,  0,  0,  0,  0]],

        [[ 0, -1, -2, -3, -4,  0, -1, -2, -3, -4,  0, -1, -2, -3, -4,  0, -1,
          -2, -3, -4,  0, -1, -2, -3, -4],
         [ 1,  0, -1, -2, -3,  1,  0, -1, -2, -3,  1,  0, -1, -2, -3,  1,  0,
          -1, -2, -3,  1,  0, -1, -2, -3],
         [ 2,  1,  0, -1, -2,  2,  1,  0, -1, -2,  2,  1,  0, -1, -2,  2,  1,
           0, -1, -2,  2,  1,  0, -1, -2],
         [ 3,  2,  1,  0, -1,  3,  2,  1,  0, -1,  3,  2,  1,  0, -1,  3,  2,
           1,  0, -1,  3,  2,  1,  0, -1],
         [ 4,  3,  2,  1,  0,  4,  3,  2,  1,  0,  4,  3,  2,  1,  0,  4,  3,
           2,  1,  0,  4,  3,  2,  1,  0],
         [ 0, -1, -2, -3, -4,  0, -1, -2, -3, -4,  0, -1, -2, -3, -4,  0, -1,
          -2, -3, -4,  0, -1, -2, -3, -4],
         [ 1,  0, -1, -2, -3,  1,  0, -1, -2, -3,  1,  0, -1, -2, -3,  1,  0,
          -1, -2, -3,  1,  0, -1, -2, -3],
         [ 2,  1,  0, -1, -2,  2,  1,  0, -1, -2,  2,  1,  0, -1, -2,  2,  1,
           0, -1, -2,  2,  1,  0, -1, -2],
         [ 3,  2,  1,  0, -1,  3,  2,  1,  0, -1,  3,  2,  1,  0, -1,  3,  2,
           1,  0, -1,  3,  2,  1,  0, -1],
         [ 4,  3,  2,  1,  0,  4,  3,  2,  1,  0,  4,  3,  2,  1,  0,  4,  3,
           2,  1,  0,  4,  3,  2,  1,  0],
         [ 0, -1, -2, -3, -4,  0, -1, -2, -3, -4,  0, -1, -2, -3, -4,  0, -1,
          -2, -3, -4,  0, -1, -2, -3, -4],
         [ 1,  0, -1, -2, -3,  1,  0, -1, -2, -3,  1,  0, -1, -2, -3,  1,  0,
          -1, -2, -3,  1,  0, -1, -2, -3],
         [ 2,  1,  0, -1, -2,  2,  1,  0, -1, -2,  2,  1,  0, -1, -2,  2,  1,
           0, -1, -2,  2,  1,  0, -1, -2],
         [ 3,  2,  1,  0, -1,  3,  2,  1,  0, -1,  3,  2,  1,  0, -1,  3,  2,
           1,  0, -1,  3,  2,  1,  0, -1],
         [ 4,  3,  2,  1,  0,  4,  3,  2,  1,  0,  4,  3,  2,  1,  0,  4,  3,
           2,  1,  0,  4,  3,  2,  1,  0],
         [ 0, -1, -2, -3, -4,  0, -1, -2, -3, -4,  0, -1, -2, -3, -4,  0, -1,
          -2, -3, -4,  0, -1, -2, -3, -4],
         [ 1,  0, -1, -2, -3,  1,  0, -1, -2, -3,  1,  0, -1, -2, -3,  1,  0,
          -1, -2, -3,  1,  0, -1, -2, -3],
         [ 2,  1,  0, -1, -2,  2,  1,  0, -1, -2,  2,  1,  0, -1, -2,  2,  1,
           0, -1, -2,  2,  1,  0, -1, -2],
         [ 3,  2,  1,  0, -1,  3,  2,  1,  0, -1,  3,  2,  1,  0, -1,  3,  2,
           1,  0, -1,  3,  2,  1,  0, -1],
         [ 4,  3,  2,  1,  0,  4,  3,  2,  1,  0,  4,  3,  2,  1,  0,  4,  3,
           2,  1,  0,  4,  3,  2,  1,  0],
         [ 0, -1, -2, -3, -4,  0, -1, -2, -3, -4,  0, -1, -2, -3, -4,  0, -1,
          -2, -3, -4,  0, -1, -2, -3, -4],
         [ 1,  0, -1, -2, -3,  1,  0, -1, -2, -3,  1,  0, -1, -2, -3,  1,  0,
          -1, -2, -3,  1,  0, -1, -2, -3],
         [ 2,  1,  0, -1, -2,  2,  1,  0, -1, -2,  2,  1,  0, -1, -2,  2,  1,
           0, -1, -2,  2,  1,  0, -1, -2],
         [ 3,  2,  1,  0, -1,  3,  2,  1,  0, -1,  3,  2,  1,  0, -1,  3,  2,
           1,  0, -1,  3,  2,  1,  0, -1],
         [ 4,  3,  2,  1,  0,  4,  3,  2,  1,  0,  4,  3,  2,  1,  0,  4,  3,
           2,  1,  0,  4,  3,  2,  1,  0]]])
relative_coords_bias=
tensor([[[5, 5, 5, 5, 5, 4, 4, 4, 4, 4, 3, 3, 3, 3, 3, 2, 2, 2, 2, 2, 1, 1, 1,
          1, 1],
         [5, 5, 5, 5, 5, 4, 4, 4, 4, 4, 3, 3, 3, 3, 3, 2, 2, 2, 2, 2, 1, 1, 1,
          1, 1],
         [5, 5, 5, 5, 5, 4, 4, 4, 4, 4, 3, 3, 3, 3, 3, 2, 2, 2, 2, 2, 1, 1, 1,
          1, 1],
         [5, 5, 5, 5, 5, 4, 4, 4, 4, 4, 3, 3, 3, 3, 3, 2, 2, 2, 2, 2, 1, 1, 1,
          1, 1],
         [5, 5, 5, 5, 5, 4, 4, 4, 4, 4, 3, 3, 3, 3, 3, 2, 2, 2, 2, 2, 1, 1, 1,
          1, 1],
         [6, 6, 6, 6, 6, 5, 5, 5, 5, 5, 4, 4, 4, 4, 4, 3, 3, 3, 3, 3, 2, 2, 2,
          2, 2],
         [6, 6, 6, 6, 6, 5, 5, 5, 5, 5, 4, 4, 4, 4, 4, 3, 3, 3, 3, 3, 2, 2, 2,
          2, 2],
         [6, 6, 6, 6, 6, 5, 5, 5, 5, 5, 4, 4, 4, 4, 4, 3, 3, 3, 3, 3, 2, 2, 2,
          2, 2],
         [6, 6, 6, 6, 6, 5, 5, 5, 5, 5, 4, 4, 4, 4, 4, 3, 3, 3, 3, 3, 2, 2, 2,
          2, 2],
         [6, 6, 6, 6, 6, 5, 5, 5, 5, 5, 4, 4, 4, 4, 4, 3, 3, 3, 3, 3, 2, 2, 2,
          2, 2],
         [7, 7, 7, 7, 7, 6, 6, 6, 6, 6, 5, 5, 5, 5, 5, 4, 4, 4, 4, 4, 3, 3, 3,
          3, 3],
         [7, 7, 7, 7, 7, 6, 6, 6, 6, 6, 5, 5, 5, 5, 5, 4, 4, 4, 4, 4, 3, 3, 3,
          3, 3],
         [7, 7, 7, 7, 7, 6, 6, 6, 6, 6, 5, 5, 5, 5, 5, 4, 4, 4, 4, 4, 3, 3, 3,
          3, 3],
         [7, 7, 7, 7, 7, 6, 6, 6, 6, 6, 5, 5, 5, 5, 5, 4, 4, 4, 4, 4, 3, 3, 3,
          3, 3],
         [7, 7, 7, 7, 7, 6, 6, 6, 6, 6, 5, 5, 5, 5, 5, 4, 4, 4, 4, 4, 3, 3, 3,
          3, 3],
         [8, 8, 8, 8, 8, 7, 7, 7, 7, 7, 6, 6, 6, 6, 6, 5, 5, 5, 5, 5, 4, 4, 4,
          4, 4],
         [8, 8, 8, 8, 8, 7, 7, 7, 7, 7, 6, 6, 6, 6, 6, 5, 5, 5, 5, 5, 4, 4, 4,
          4, 4],
         [8, 8, 8, 8, 8, 7, 7, 7, 7, 7, 6, 6, 6, 6, 6, 5, 5, 5, 5, 5, 4, 4, 4,
          4, 4],
         [8, 8, 8, 8, 8, 7, 7, 7, 7, 7, 6, 6, 6, 6, 6, 5, 5, 5, 5, 5, 4, 4, 4,
          4, 4],
         [8, 8, 8, 8, 8, 7, 7, 7, 7, 7, 6, 6, 6, 6, 6, 5, 5, 5, 5, 5, 4, 4, 4,
          4, 4],
         [9, 9, 9, 9, 9, 8, 8, 8, 8, 8, 7, 7, 7, 7, 7, 6, 6, 6, 6, 6, 5, 5, 5,
          5, 5],
         [9, 9, 9, 9, 9, 8, 8, 8, 8, 8, 7, 7, 7, 7, 7, 6, 6, 6, 6, 6, 5, 5, 5,
          5, 5],
         [9, 9, 9, 9, 9, 8, 8, 8, 8, 8, 7, 7, 7, 7, 7, 6, 6, 6, 6, 6, 5, 5, 5,
          5, 5],
         [9, 9, 9, 9, 9, 8, 8, 8, 8, 8, 7, 7, 7, 7, 7, 6, 6, 6, 6, 6, 5, 5, 5,
          5, 5],
         [9, 9, 9, 9, 9, 8, 8, 8, 8, 8, 7, 7, 7, 7, 7, 6, 6, 6, 6, 6, 5, 5, 5,
          5, 5]],

        [[5, 4, 3, 2, 1, 5, 4, 3, 2, 1, 5, 4, 3, 2, 1, 5, 4, 3, 2, 1, 5, 4, 3,
          2, 1],
         [6, 5, 4, 3, 2, 6, 5, 4, 3, 2, 6, 5, 4, 3, 2, 6, 5, 4, 3, 2, 6, 5, 4,
          3, 2],
         [7, 6, 5, 4, 3, 7, 6, 5, 4, 3, 7, 6, 5, 4, 3, 7, 6, 5, 4, 3, 7, 6, 5,
          4, 3],
         [8, 7, 6, 5, 4, 8, 7, 6, 5, 4, 8, 7, 6, 5, 4, 8, 7, 6, 5, 4, 8, 7, 6,
          5, 4],
         [9, 8, 7, 6, 5, 9, 8, 7, 6, 5, 9, 8, 7, 6, 5, 9, 8, 7, 6, 5, 9, 8, 7,
          6, 5],
         [5, 4, 3, 2, 1, 5, 4, 3, 2, 1, 5, 4, 3, 2, 1, 5, 4, 3, 2, 1, 5, 4, 3,
          2, 1],
         [6, 5, 4, 3, 2, 6, 5, 4, 3, 2, 6, 5, 4, 3, 2, 6, 5, 4, 3, 2, 6, 5, 4,
          3, 2],
         [7, 6, 5, 4, 3, 7, 6, 5, 4, 3, 7, 6, 5, 4, 3, 7, 6, 5, 4, 3, 7, 6, 5,
          4, 3],
         [8, 7, 6, 5, 4, 8, 7, 6, 5, 4, 8, 7, 6, 5, 4, 8, 7, 6, 5, 4, 8, 7, 6,
          5, 4],
         [9, 8, 7, 6, 5, 9, 8, 7, 6, 5, 9, 8, 7, 6, 5, 9, 8, 7, 6, 5, 9, 8, 7,
          6, 5],
         [5, 4, 3, 2, 1, 5, 4, 3, 2, 1, 5, 4, 3, 2, 1, 5, 4, 3, 2, 1, 5, 4, 3,
          2, 1],
         [6, 5, 4, 3, 2, 6, 5, 4, 3, 2, 6, 5, 4, 3, 2, 6, 5, 4, 3, 2, 6, 5, 4,
          3, 2],
         [7, 6, 5, 4, 3, 7, 6, 5, 4, 3, 7, 6, 5, 4, 3, 7, 6, 5, 4, 3, 7, 6, 5,
          4, 3],
         [8, 7, 6, 5, 4, 8, 7, 6, 5, 4, 8, 7, 6, 5, 4, 8, 7, 6, 5, 4, 8, 7, 6,
          5, 4],
         [9, 8, 7, 6, 5, 9, 8, 7, 6, 5, 9, 8, 7, 6, 5, 9, 8, 7, 6, 5, 9, 8, 7,
          6, 5],
         [5, 4, 3, 2, 1, 5, 4, 3, 2, 1, 5, 4, 3, 2, 1, 5, 4, 3, 2, 1, 5, 4, 3,
          2, 1],
         [6, 5, 4, 3, 2, 6, 5, 4, 3, 2, 6, 5, 4, 3, 2, 6, 5, 4, 3, 2, 6, 5, 4,
          3, 2],
         [7, 6, 5, 4, 3, 7, 6, 5, 4, 3, 7, 6, 5, 4, 3, 7, 6, 5, 4, 3, 7, 6, 5,
          4, 3],
         [8, 7, 6, 5, 4, 8, 7, 6, 5, 4, 8, 7, 6, 5, 4, 8, 7, 6, 5, 4, 8, 7, 6,
          5, 4],
         [9, 8, 7, 6, 5, 9, 8, 7, 6, 5, 9, 8, 7, 6, 5, 9, 8, 7, 6, 5, 9, 8, 7,
          6, 5],
         [5, 4, 3, 2, 1, 5, 4, 3, 2, 1, 5, 4, 3, 2, 1, 5, 4, 3, 2, 1, 5, 4, 3,
          2, 1],
         [6, 5, 4, 3, 2, 6, 5, 4, 3, 2, 6, 5, 4, 3, 2, 6, 5, 4, 3, 2, 6, 5, 4,
          3, 2],
         [7, 6, 5, 4, 3, 7, 6, 5, 4, 3, 7, 6, 5, 4, 3, 7, 6, 5, 4, 3, 7, 6, 5,
          4, 3],
         [8, 7, 6, 5, 4, 8, 7, 6, 5, 4, 8, 7, 6, 5, 4, 8, 7, 6, 5, 4, 8, 7, 6,
          5, 4],
         [9, 8, 7, 6, 5, 9, 8, 7, 6, 5, 9, 8, 7, 6, 5, 9, 8, 7, 6, 5, 9, 8, 7,
          6, 5]]])

2. kron运算

在结果中,我们发现很多重复的值,这就让我联想到kron运算。

  • step1:形成子矩阵
  • step2: kron
  • pytorch
python 复制代码
import torch
import torch.nn as nn

torch.set_printoptions(precision=3, sci_mode=False)

if __name__ == '__main__':
    run_code = 0
    height = 5
    width = 5
    a_vector = torch.arange(width).to(torch.float).reshape(-1, 1)
    a_ones = torch.ones(1, width)
    a_matrix = a_vector @ a_ones
    print(f"a_matrix=\n{a_matrix}")
    b_matrix = a_matrix - a_matrix.T
    print(f"b_matrix=\n{b_matrix}")
    b_matrix_ones = torch.ones_like(b_matrix)
    ab_kron = torch.kron(b_matrix,b_matrix_ones)
    print(f"ab_kron=\n{ab_kron}")
    final_ab = ab_kron+5
    print(f"final_ab=\n{final_ab}")
  • result:
python 复制代码
a_matrix=
tensor([[0., 0., 0., 0., 0.],
        [1., 1., 1., 1., 1.],
        [2., 2., 2., 2., 2.],
        [3., 3., 3., 3., 3.],
        [4., 4., 4., 4., 4.]])
b_matrix=
tensor([[ 0., -1., -2., -3., -4.],
        [ 1.,  0., -1., -2., -3.],
        [ 2.,  1.,  0., -1., -2.],
        [ 3.,  2.,  1.,  0., -1.],
        [ 4.,  3.,  2.,  1.,  0.]])
ab_kron=
tensor([[ 0.,  0.,  0.,  0.,  0., -1., -1., -1., -1., -1., -2., -2., -2., -2.,
         -2., -3., -3., -3., -3., -3., -4., -4., -4., -4., -4.],
        [ 0.,  0.,  0.,  0.,  0., -1., -1., -1., -1., -1., -2., -2., -2., -2.,
         -2., -3., -3., -3., -3., -3., -4., -4., -4., -4., -4.],
        [ 0.,  0.,  0.,  0.,  0., -1., -1., -1., -1., -1., -2., -2., -2., -2.,
         -2., -3., -3., -3., -3., -3., -4., -4., -4., -4., -4.],
        [ 0.,  0.,  0.,  0.,  0., -1., -1., -1., -1., -1., -2., -2., -2., -2.,
         -2., -3., -3., -3., -3., -3., -4., -4., -4., -4., -4.],
        [ 0.,  0.,  0.,  0.,  0., -1., -1., -1., -1., -1., -2., -2., -2., -2.,
         -2., -3., -3., -3., -3., -3., -4., -4., -4., -4., -4.],
        [ 1.,  1.,  1.,  1.,  1.,  0.,  0.,  0.,  0.,  0., -1., -1., -1., -1.,
         -1., -2., -2., -2., -2., -2., -3., -3., -3., -3., -3.],
        [ 1.,  1.,  1.,  1.,  1.,  0.,  0.,  0.,  0.,  0., -1., -1., -1., -1.,
         -1., -2., -2., -2., -2., -2., -3., -3., -3., -3., -3.],
        [ 1.,  1.,  1.,  1.,  1.,  0.,  0.,  0.,  0.,  0., -1., -1., -1., -1.,
         -1., -2., -2., -2., -2., -2., -3., -3., -3., -3., -3.],
        [ 1.,  1.,  1.,  1.,  1.,  0.,  0.,  0.,  0.,  0., -1., -1., -1., -1.,
         -1., -2., -2., -2., -2., -2., -3., -3., -3., -3., -3.],
        [ 1.,  1.,  1.,  1.,  1.,  0.,  0.,  0.,  0.,  0., -1., -1., -1., -1.,
         -1., -2., -2., -2., -2., -2., -3., -3., -3., -3., -3.],
        [ 2.,  2.,  2.,  2.,  2.,  1.,  1.,  1.,  1.,  1.,  0.,  0.,  0.,  0.,
          0., -1., -1., -1., -1., -1., -2., -2., -2., -2., -2.],
        [ 2.,  2.,  2.,  2.,  2.,  1.,  1.,  1.,  1.,  1.,  0.,  0.,  0.,  0.,
          0., -1., -1., -1., -1., -1., -2., -2., -2., -2., -2.],
        [ 2.,  2.,  2.,  2.,  2.,  1.,  1.,  1.,  1.,  1.,  0.,  0.,  0.,  0.,
          0., -1., -1., -1., -1., -1., -2., -2., -2., -2., -2.],
        [ 2.,  2.,  2.,  2.,  2.,  1.,  1.,  1.,  1.,  1.,  0.,  0.,  0.,  0.,
          0., -1., -1., -1., -1., -1., -2., -2., -2., -2., -2.],
        [ 2.,  2.,  2.,  2.,  2.,  1.,  1.,  1.,  1.,  1.,  0.,  0.,  0.,  0.,
          0., -1., -1., -1., -1., -1., -2., -2., -2., -2., -2.],
        [ 3.,  3.,  3.,  3.,  3.,  2.,  2.,  2.,  2.,  2.,  1.,  1.,  1.,  1.,
          1.,  0.,  0.,  0.,  0.,  0., -1., -1., -1., -1., -1.],
        [ 3.,  3.,  3.,  3.,  3.,  2.,  2.,  2.,  2.,  2.,  1.,  1.,  1.,  1.,
          1.,  0.,  0.,  0.,  0.,  0., -1., -1., -1., -1., -1.],
        [ 3.,  3.,  3.,  3.,  3.,  2.,  2.,  2.,  2.,  2.,  1.,  1.,  1.,  1.,
          1.,  0.,  0.,  0.,  0.,  0., -1., -1., -1., -1., -1.],
        [ 3.,  3.,  3.,  3.,  3.,  2.,  2.,  2.,  2.,  2.,  1.,  1.,  1.,  1.,
          1.,  0.,  0.,  0.,  0.,  0., -1., -1., -1., -1., -1.],
        [ 3.,  3.,  3.,  3.,  3.,  2.,  2.,  2.,  2.,  2.,  1.,  1.,  1.,  1.,
          1.,  0.,  0.,  0.,  0.,  0., -1., -1., -1., -1., -1.],
        [ 4.,  4.,  4.,  4.,  4.,  3.,  3.,  3.,  3.,  3.,  2.,  2.,  2.,  2.,
          2.,  1.,  1.,  1.,  1.,  1.,  0.,  0.,  0.,  0.,  0.],
        [ 4.,  4.,  4.,  4.,  4.,  3.,  3.,  3.,  3.,  3.,  2.,  2.,  2.,  2.,
          2.,  1.,  1.,  1.,  1.,  1.,  0.,  0.,  0.,  0.,  0.],
        [ 4.,  4.,  4.,  4.,  4.,  3.,  3.,  3.,  3.,  3.,  2.,  2.,  2.,  2.,
          2.,  1.,  1.,  1.,  1.,  1.,  0.,  0.,  0.,  0.,  0.],
        [ 4.,  4.,  4.,  4.,  4.,  3.,  3.,  3.,  3.,  3.,  2.,  2.,  2.,  2.,
          2.,  1.,  1.,  1.,  1.,  1.,  0.,  0.,  0.,  0.,  0.],
        [ 4.,  4.,  4.,  4.,  4.,  3.,  3.,  3.,  3.,  3.,  2.,  2.,  2.,  2.,
          2.,  1.,  1.,  1.,  1.,  1.,  0.,  0.,  0.,  0.,  0.]])
final_ab=
tensor([[5., 5., 5., 5., 5., 4., 4., 4., 4., 4., 3., 3., 3., 3., 3., 2., 2., 2.,
         2., 2., 1., 1., 1., 1., 1.],
        [5., 5., 5., 5., 5., 4., 4., 4., 4., 4., 3., 3., 3., 3., 3., 2., 2., 2.,
         2., 2., 1., 1., 1., 1., 1.],
        [5., 5., 5., 5., 5., 4., 4., 4., 4., 4., 3., 3., 3., 3., 3., 2., 2., 2.,
         2., 2., 1., 1., 1., 1., 1.],
        [5., 5., 5., 5., 5., 4., 4., 4., 4., 4., 3., 3., 3., 3., 3., 2., 2., 2.,
         2., 2., 1., 1., 1., 1., 1.],
        [5., 5., 5., 5., 5., 4., 4., 4., 4., 4., 3., 3., 3., 3., 3., 2., 2., 2.,
         2., 2., 1., 1., 1., 1., 1.],
        [6., 6., 6., 6., 6., 5., 5., 5., 5., 5., 4., 4., 4., 4., 4., 3., 3., 3.,
         3., 3., 2., 2., 2., 2., 2.],
        [6., 6., 6., 6., 6., 5., 5., 5., 5., 5., 4., 4., 4., 4., 4., 3., 3., 3.,
         3., 3., 2., 2., 2., 2., 2.],
        [6., 6., 6., 6., 6., 5., 5., 5., 5., 5., 4., 4., 4., 4., 4., 3., 3., 3.,
         3., 3., 2., 2., 2., 2., 2.],
        [6., 6., 6., 6., 6., 5., 5., 5., 5., 5., 4., 4., 4., 4., 4., 3., 3., 3.,
         3., 3., 2., 2., 2., 2., 2.],
        [6., 6., 6., 6., 6., 5., 5., 5., 5., 5., 4., 4., 4., 4., 4., 3., 3., 3.,
         3., 3., 2., 2., 2., 2., 2.],
        [7., 7., 7., 7., 7., 6., 6., 6., 6., 6., 5., 5., 5., 5., 5., 4., 4., 4.,
         4., 4., 3., 3., 3., 3., 3.],
        [7., 7., 7., 7., 7., 6., 6., 6., 6., 6., 5., 5., 5., 5., 5., 4., 4., 4.,
         4., 4., 3., 3., 3., 3., 3.],
        [7., 7., 7., 7., 7., 6., 6., 6., 6., 6., 5., 5., 5., 5., 5., 4., 4., 4.,
         4., 4., 3., 3., 3., 3., 3.],
        [7., 7., 7., 7., 7., 6., 6., 6., 6., 6., 5., 5., 5., 5., 5., 4., 4., 4.,
         4., 4., 3., 3., 3., 3., 3.],
        [7., 7., 7., 7., 7., 6., 6., 6., 6., 6., 5., 5., 5., 5., 5., 4., 4., 4.,
         4., 4., 3., 3., 3., 3., 3.],
        [8., 8., 8., 8., 8., 7., 7., 7., 7., 7., 6., 6., 6., 6., 6., 5., 5., 5.,
         5., 5., 4., 4., 4., 4., 4.],
        [8., 8., 8., 8., 8., 7., 7., 7., 7., 7., 6., 6., 6., 6., 6., 5., 5., 5.,
         5., 5., 4., 4., 4., 4., 4.],
        [8., 8., 8., 8., 8., 7., 7., 7., 7., 7., 6., 6., 6., 6., 6., 5., 5., 5.,
         5., 5., 4., 4., 4., 4., 4.],
        [8., 8., 8., 8., 8., 7., 7., 7., 7., 7., 6., 6., 6., 6., 6., 5., 5., 5.,
         5., 5., 4., 4., 4., 4., 4.],
        [8., 8., 8., 8., 8., 7., 7., 7., 7., 7., 6., 6., 6., 6., 6., 5., 5., 5.,
         5., 5., 4., 4., 4., 4., 4.],
        [9., 9., 9., 9., 9., 8., 8., 8., 8., 8., 7., 7., 7., 7., 7., 6., 6., 6.,
         6., 6., 5., 5., 5., 5., 5.],
        [9., 9., 9., 9., 9., 8., 8., 8., 8., 8., 7., 7., 7., 7., 7., 6., 6., 6.,
         6., 6., 5., 5., 5., 5., 5.],
        [9., 9., 9., 9., 9., 8., 8., 8., 8., 8., 7., 7., 7., 7., 7., 6., 6., 6.,
         6., 6., 5., 5., 5., 5., 5.],
        [9., 9., 9., 9., 9., 8., 8., 8., 8., 8., 7., 7., 7., 7., 7., 6., 6., 6.,
         6., 6., 5., 5., 5., 5., 5.],
        [9., 9., 9., 9., 9., 8., 8., 8., 8., 8., 7., 7., 7., 7., 7., 6., 6., 6.,
         6., 6., 5., 5., 5., 5., 5.]])
相关推荐
本本的小橙子3 分钟前
第38周:文献阅读
人工智能·深度学习·tensorflow
LoserChaser6 分钟前
李宏毅机器学习笔记(1)—机器学习基本概念+深度学习基本概念
笔记·深度学习·机器学习
Wnq100729 分钟前
智慧城市智慧调度系统的架构与关键技术研究
人工智能·架构·智慧城市·big data
星辰大海的精灵9 分钟前
SpringAI轻松构建MCP Client-Server架构
人工智能·后端·架构
果冻人工智能11 分钟前
判断 Python 代码是不是 AI 写的几个简单方法
人工智能
奶油话梅糖27 分钟前
TensorFlow 深度学习框架详解
人工智能·深度学习·tensorflow
arbboter33 分钟前
【AI工具开发】Notepad++插件开发实践:从基础交互到ScintillaCall集成
人工智能·编辑器·notepad++·插件开发·scintilla·scintillacall·scintilla类封装
正经教主34 分钟前
【AI语音】edge-tts实现文本转语音,免费且音质不错
ide·人工智能·语音识别
go546315846538 分钟前
使用Python和PyTorch库实现基于DNN、CNN、LSTM的极化码译码器模型的代码示例
pytorch·python·dnn
勤劳打代码42 分钟前
端倪无际 —— cursor 配置 mcp 保姆级攻略
人工智能·mcp