相对位置2d矩阵和kron运算的思考

文章目录

  • [1. 相对位置矩阵2d](#1. 相对位置矩阵2d)
  • [2. kron运算](#2. kron运算)

1. 相对位置矩阵2d

在swin-transformer中,我们会计算每个patch之间的相对位置,那么我们看到有一连串的拉伸和相减,直接贴代码:

python 复制代码
import torch
import torch.nn as nn

torch.set_printoptions(precision=3, sci_mode=False,threshold=torch.inf)

if __name__ == "__main__":
    run_code = 2
    x_len = 5
    y_len = 5
    x_tensor = torch.arange(x_len)
    y_tensor = torch.arange(y_len)
    x_meshgrid, y_meshgrid = torch.meshgrid(x_tensor, y_tensor)
    print(f"x_tensor=\n{x_tensor}")
    print(f"y_tensor=\n{y_tensor}")
    print(f"x_meshgrid=\n{x_meshgrid}")
    print(f"x_meshgrid.shape=\n{x_meshgrid.shape}")
    print(f"y_meshgrid.shape=\n{y_meshgrid.shape}")
    print(f"y_meshgrid=\n{y_meshgrid}")
    stack_meshgrid = torch.stack(torch.meshgrid(x_tensor, y_tensor))
    print(f"stack_meshgrid.shape=\n{stack_meshgrid.shape}")
    print(f"stack_meshgrid=\n{stack_meshgrid}")
    stack_meshgrid_flatten = torch.flatten(stack_meshgrid, 1)
    print(f"stack_meshgrid_flatten.shape=\n{stack_meshgrid_flatten.shape}")
    print(f"stack_meshgrid_flatten=\n{stack_meshgrid_flatten}")
    stack_meshgrid_flatten_1 = stack_meshgrid_flatten[:, None, :]
    stack_meshgrid_flatten_2 = stack_meshgrid_flatten[:, :, None]
    relative_coords_bias = stack_meshgrid_flatten_2 - stack_meshgrid_flatten_1
    print(f"stack_meshgrid_flatten_1=\n{stack_meshgrid_flatten_1}")
    print(f"stack_meshgrid_flatten_2=\n{stack_meshgrid_flatten_2}")
    print(f"relative_coords_bias=\n{relative_coords_bias}")
    relative_coords_bias[0, :, :] += x_len
    relative_coords_bias[1, :, :] += y_len
    print(f"relative_coords_bias=\n{relative_coords_bias}")
  • result:
python 复制代码
x_tensor=
tensor([0, 1, 2, 3, 4])
y_tensor=
tensor([0, 1, 2, 3, 4])
x_meshgrid=
tensor([[0, 0, 0, 0, 0],
        [1, 1, 1, 1, 1],
        [2, 2, 2, 2, 2],
        [3, 3, 3, 3, 3],
        [4, 4, 4, 4, 4]])
x_meshgrid.shape=
torch.Size([5, 5])
y_meshgrid.shape=
torch.Size([5, 5])
y_meshgrid=
tensor([[0, 1, 2, 3, 4],
        [0, 1, 2, 3, 4],
        [0, 1, 2, 3, 4],
        [0, 1, 2, 3, 4],
        [0, 1, 2, 3, 4]])
stack_meshgrid.shape=
torch.Size([2, 5, 5])
stack_meshgrid=
tensor([[[0, 0, 0, 0, 0],
         [1, 1, 1, 1, 1],
         [2, 2, 2, 2, 2],
         [3, 3, 3, 3, 3],
         [4, 4, 4, 4, 4]],

        [[0, 1, 2, 3, 4],
         [0, 1, 2, 3, 4],
         [0, 1, 2, 3, 4],
         [0, 1, 2, 3, 4],
         [0, 1, 2, 3, 4]]])
stack_meshgrid_flatten.shape=
torch.Size([2, 25])
stack_meshgrid_flatten=
tensor([[0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 4, 4, 4, 4,
         4],
        [0, 1, 2, 3, 4, 0, 1, 2, 3, 4, 0, 1, 2, 3, 4, 0, 1, 2, 3, 4, 0, 1, 2, 3,
         4]])
stack_meshgrid_flatten_1=
tensor([[[0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 4, 4, 4,
          4, 4]],

        [[0, 1, 2, 3, 4, 0, 1, 2, 3, 4, 0, 1, 2, 3, 4, 0, 1, 2, 3, 4, 0, 1, 2,
          3, 4]]])
stack_meshgrid_flatten_2=
tensor([[[0],
         [0],
         [0],
         [0],
         [0],
         [1],
         [1],
         [1],
         [1],
         [1],
         [2],
         [2],
         [2],
         [2],
         [2],
         [3],
         [3],
         [3],
         [3],
         [3],
         [4],
         [4],
         [4],
         [4],
         [4]],

        [[0],
         [1],
         [2],
         [3],
         [4],
         [0],
         [1],
         [2],
         [3],
         [4],
         [0],
         [1],
         [2],
         [3],
         [4],
         [0],
         [1],
         [2],
         [3],
         [4],
         [0],
         [1],
         [2],
         [3],
         [4]]])
relative_coords_bias=
tensor([[[ 0,  0,  0,  0,  0, -1, -1, -1, -1, -1, -2, -2, -2, -2, -2, -3, -3,
          -3, -3, -3, -4, -4, -4, -4, -4],
         [ 0,  0,  0,  0,  0, -1, -1, -1, -1, -1, -2, -2, -2, -2, -2, -3, -3,
          -3, -3, -3, -4, -4, -4, -4, -4],
         [ 0,  0,  0,  0,  0, -1, -1, -1, -1, -1, -2, -2, -2, -2, -2, -3, -3,
          -3, -3, -3, -4, -4, -4, -4, -4],
         [ 0,  0,  0,  0,  0, -1, -1, -1, -1, -1, -2, -2, -2, -2, -2, -3, -3,
          -3, -3, -3, -4, -4, -4, -4, -4],
         [ 0,  0,  0,  0,  0, -1, -1, -1, -1, -1, -2, -2, -2, -2, -2, -3, -3,
          -3, -3, -3, -4, -4, -4, -4, -4],
         [ 1,  1,  1,  1,  1,  0,  0,  0,  0,  0, -1, -1, -1, -1, -1, -2, -2,
          -2, -2, -2, -3, -3, -3, -3, -3],
         [ 1,  1,  1,  1,  1,  0,  0,  0,  0,  0, -1, -1, -1, -1, -1, -2, -2,
          -2, -2, -2, -3, -3, -3, -3, -3],
         [ 1,  1,  1,  1,  1,  0,  0,  0,  0,  0, -1, -1, -1, -1, -1, -2, -2,
          -2, -2, -2, -3, -3, -3, -3, -3],
         [ 1,  1,  1,  1,  1,  0,  0,  0,  0,  0, -1, -1, -1, -1, -1, -2, -2,
          -2, -2, -2, -3, -3, -3, -3, -3],
         [ 1,  1,  1,  1,  1,  0,  0,  0,  0,  0, -1, -1, -1, -1, -1, -2, -2,
          -2, -2, -2, -3, -3, -3, -3, -3],
         [ 2,  2,  2,  2,  2,  1,  1,  1,  1,  1,  0,  0,  0,  0,  0, -1, -1,
          -1, -1, -1, -2, -2, -2, -2, -2],
         [ 2,  2,  2,  2,  2,  1,  1,  1,  1,  1,  0,  0,  0,  0,  0, -1, -1,
          -1, -1, -1, -2, -2, -2, -2, -2],
         [ 2,  2,  2,  2,  2,  1,  1,  1,  1,  1,  0,  0,  0,  0,  0, -1, -1,
          -1, -1, -1, -2, -2, -2, -2, -2],
         [ 2,  2,  2,  2,  2,  1,  1,  1,  1,  1,  0,  0,  0,  0,  0, -1, -1,
          -1, -1, -1, -2, -2, -2, -2, -2],
         [ 2,  2,  2,  2,  2,  1,  1,  1,  1,  1,  0,  0,  0,  0,  0, -1, -1,
          -1, -1, -1, -2, -2, -2, -2, -2],
         [ 3,  3,  3,  3,  3,  2,  2,  2,  2,  2,  1,  1,  1,  1,  1,  0,  0,
           0,  0,  0, -1, -1, -1, -1, -1],
         [ 3,  3,  3,  3,  3,  2,  2,  2,  2,  2,  1,  1,  1,  1,  1,  0,  0,
           0,  0,  0, -1, -1, -1, -1, -1],
         [ 3,  3,  3,  3,  3,  2,  2,  2,  2,  2,  1,  1,  1,  1,  1,  0,  0,
           0,  0,  0, -1, -1, -1, -1, -1],
         [ 3,  3,  3,  3,  3,  2,  2,  2,  2,  2,  1,  1,  1,  1,  1,  0,  0,
           0,  0,  0, -1, -1, -1, -1, -1],
         [ 3,  3,  3,  3,  3,  2,  2,  2,  2,  2,  1,  1,  1,  1,  1,  0,  0,
           0,  0,  0, -1, -1, -1, -1, -1],
         [ 4,  4,  4,  4,  4,  3,  3,  3,  3,  3,  2,  2,  2,  2,  2,  1,  1,
           1,  1,  1,  0,  0,  0,  0,  0],
         [ 4,  4,  4,  4,  4,  3,  3,  3,  3,  3,  2,  2,  2,  2,  2,  1,  1,
           1,  1,  1,  0,  0,  0,  0,  0],
         [ 4,  4,  4,  4,  4,  3,  3,  3,  3,  3,  2,  2,  2,  2,  2,  1,  1,
           1,  1,  1,  0,  0,  0,  0,  0],
         [ 4,  4,  4,  4,  4,  3,  3,  3,  3,  3,  2,  2,  2,  2,  2,  1,  1,
           1,  1,  1,  0,  0,  0,  0,  0],
         [ 4,  4,  4,  4,  4,  3,  3,  3,  3,  3,  2,  2,  2,  2,  2,  1,  1,
           1,  1,  1,  0,  0,  0,  0,  0]],

        [[ 0, -1, -2, -3, -4,  0, -1, -2, -3, -4,  0, -1, -2, -3, -4,  0, -1,
          -2, -3, -4,  0, -1, -2, -3, -4],
         [ 1,  0, -1, -2, -3,  1,  0, -1, -2, -3,  1,  0, -1, -2, -3,  1,  0,
          -1, -2, -3,  1,  0, -1, -2, -3],
         [ 2,  1,  0, -1, -2,  2,  1,  0, -1, -2,  2,  1,  0, -1, -2,  2,  1,
           0, -1, -2,  2,  1,  0, -1, -2],
         [ 3,  2,  1,  0, -1,  3,  2,  1,  0, -1,  3,  2,  1,  0, -1,  3,  2,
           1,  0, -1,  3,  2,  1,  0, -1],
         [ 4,  3,  2,  1,  0,  4,  3,  2,  1,  0,  4,  3,  2,  1,  0,  4,  3,
           2,  1,  0,  4,  3,  2,  1,  0],
         [ 0, -1, -2, -3, -4,  0, -1, -2, -3, -4,  0, -1, -2, -3, -4,  0, -1,
          -2, -3, -4,  0, -1, -2, -3, -4],
         [ 1,  0, -1, -2, -3,  1,  0, -1, -2, -3,  1,  0, -1, -2, -3,  1,  0,
          -1, -2, -3,  1,  0, -1, -2, -3],
         [ 2,  1,  0, -1, -2,  2,  1,  0, -1, -2,  2,  1,  0, -1, -2,  2,  1,
           0, -1, -2,  2,  1,  0, -1, -2],
         [ 3,  2,  1,  0, -1,  3,  2,  1,  0, -1,  3,  2,  1,  0, -1,  3,  2,
           1,  0, -1,  3,  2,  1,  0, -1],
         [ 4,  3,  2,  1,  0,  4,  3,  2,  1,  0,  4,  3,  2,  1,  0,  4,  3,
           2,  1,  0,  4,  3,  2,  1,  0],
         [ 0, -1, -2, -3, -4,  0, -1, -2, -3, -4,  0, -1, -2, -3, -4,  0, -1,
          -2, -3, -4,  0, -1, -2, -3, -4],
         [ 1,  0, -1, -2, -3,  1,  0, -1, -2, -3,  1,  0, -1, -2, -3,  1,  0,
          -1, -2, -3,  1,  0, -1, -2, -3],
         [ 2,  1,  0, -1, -2,  2,  1,  0, -1, -2,  2,  1,  0, -1, -2,  2,  1,
           0, -1, -2,  2,  1,  0, -1, -2],
         [ 3,  2,  1,  0, -1,  3,  2,  1,  0, -1,  3,  2,  1,  0, -1,  3,  2,
           1,  0, -1,  3,  2,  1,  0, -1],
         [ 4,  3,  2,  1,  0,  4,  3,  2,  1,  0,  4,  3,  2,  1,  0,  4,  3,
           2,  1,  0,  4,  3,  2,  1,  0],
         [ 0, -1, -2, -3, -4,  0, -1, -2, -3, -4,  0, -1, -2, -3, -4,  0, -1,
          -2, -3, -4,  0, -1, -2, -3, -4],
         [ 1,  0, -1, -2, -3,  1,  0, -1, -2, -3,  1,  0, -1, -2, -3,  1,  0,
          -1, -2, -3,  1,  0, -1, -2, -3],
         [ 2,  1,  0, -1, -2,  2,  1,  0, -1, -2,  2,  1,  0, -1, -2,  2,  1,
           0, -1, -2,  2,  1,  0, -1, -2],
         [ 3,  2,  1,  0, -1,  3,  2,  1,  0, -1,  3,  2,  1,  0, -1,  3,  2,
           1,  0, -1,  3,  2,  1,  0, -1],
         [ 4,  3,  2,  1,  0,  4,  3,  2,  1,  0,  4,  3,  2,  1,  0,  4,  3,
           2,  1,  0,  4,  3,  2,  1,  0],
         [ 0, -1, -2, -3, -4,  0, -1, -2, -3, -4,  0, -1, -2, -3, -4,  0, -1,
          -2, -3, -4,  0, -1, -2, -3, -4],
         [ 1,  0, -1, -2, -3,  1,  0, -1, -2, -3,  1,  0, -1, -2, -3,  1,  0,
          -1, -2, -3,  1,  0, -1, -2, -3],
         [ 2,  1,  0, -1, -2,  2,  1,  0, -1, -2,  2,  1,  0, -1, -2,  2,  1,
           0, -1, -2,  2,  1,  0, -1, -2],
         [ 3,  2,  1,  0, -1,  3,  2,  1,  0, -1,  3,  2,  1,  0, -1,  3,  2,
           1,  0, -1,  3,  2,  1,  0, -1],
         [ 4,  3,  2,  1,  0,  4,  3,  2,  1,  0,  4,  3,  2,  1,  0,  4,  3,
           2,  1,  0,  4,  3,  2,  1,  0]]])
relative_coords_bias=
tensor([[[5, 5, 5, 5, 5, 4, 4, 4, 4, 4, 3, 3, 3, 3, 3, 2, 2, 2, 2, 2, 1, 1, 1,
          1, 1],
         [5, 5, 5, 5, 5, 4, 4, 4, 4, 4, 3, 3, 3, 3, 3, 2, 2, 2, 2, 2, 1, 1, 1,
          1, 1],
         [5, 5, 5, 5, 5, 4, 4, 4, 4, 4, 3, 3, 3, 3, 3, 2, 2, 2, 2, 2, 1, 1, 1,
          1, 1],
         [5, 5, 5, 5, 5, 4, 4, 4, 4, 4, 3, 3, 3, 3, 3, 2, 2, 2, 2, 2, 1, 1, 1,
          1, 1],
         [5, 5, 5, 5, 5, 4, 4, 4, 4, 4, 3, 3, 3, 3, 3, 2, 2, 2, 2, 2, 1, 1, 1,
          1, 1],
         [6, 6, 6, 6, 6, 5, 5, 5, 5, 5, 4, 4, 4, 4, 4, 3, 3, 3, 3, 3, 2, 2, 2,
          2, 2],
         [6, 6, 6, 6, 6, 5, 5, 5, 5, 5, 4, 4, 4, 4, 4, 3, 3, 3, 3, 3, 2, 2, 2,
          2, 2],
         [6, 6, 6, 6, 6, 5, 5, 5, 5, 5, 4, 4, 4, 4, 4, 3, 3, 3, 3, 3, 2, 2, 2,
          2, 2],
         [6, 6, 6, 6, 6, 5, 5, 5, 5, 5, 4, 4, 4, 4, 4, 3, 3, 3, 3, 3, 2, 2, 2,
          2, 2],
         [6, 6, 6, 6, 6, 5, 5, 5, 5, 5, 4, 4, 4, 4, 4, 3, 3, 3, 3, 3, 2, 2, 2,
          2, 2],
         [7, 7, 7, 7, 7, 6, 6, 6, 6, 6, 5, 5, 5, 5, 5, 4, 4, 4, 4, 4, 3, 3, 3,
          3, 3],
         [7, 7, 7, 7, 7, 6, 6, 6, 6, 6, 5, 5, 5, 5, 5, 4, 4, 4, 4, 4, 3, 3, 3,
          3, 3],
         [7, 7, 7, 7, 7, 6, 6, 6, 6, 6, 5, 5, 5, 5, 5, 4, 4, 4, 4, 4, 3, 3, 3,
          3, 3],
         [7, 7, 7, 7, 7, 6, 6, 6, 6, 6, 5, 5, 5, 5, 5, 4, 4, 4, 4, 4, 3, 3, 3,
          3, 3],
         [7, 7, 7, 7, 7, 6, 6, 6, 6, 6, 5, 5, 5, 5, 5, 4, 4, 4, 4, 4, 3, 3, 3,
          3, 3],
         [8, 8, 8, 8, 8, 7, 7, 7, 7, 7, 6, 6, 6, 6, 6, 5, 5, 5, 5, 5, 4, 4, 4,
          4, 4],
         [8, 8, 8, 8, 8, 7, 7, 7, 7, 7, 6, 6, 6, 6, 6, 5, 5, 5, 5, 5, 4, 4, 4,
          4, 4],
         [8, 8, 8, 8, 8, 7, 7, 7, 7, 7, 6, 6, 6, 6, 6, 5, 5, 5, 5, 5, 4, 4, 4,
          4, 4],
         [8, 8, 8, 8, 8, 7, 7, 7, 7, 7, 6, 6, 6, 6, 6, 5, 5, 5, 5, 5, 4, 4, 4,
          4, 4],
         [8, 8, 8, 8, 8, 7, 7, 7, 7, 7, 6, 6, 6, 6, 6, 5, 5, 5, 5, 5, 4, 4, 4,
          4, 4],
         [9, 9, 9, 9, 9, 8, 8, 8, 8, 8, 7, 7, 7, 7, 7, 6, 6, 6, 6, 6, 5, 5, 5,
          5, 5],
         [9, 9, 9, 9, 9, 8, 8, 8, 8, 8, 7, 7, 7, 7, 7, 6, 6, 6, 6, 6, 5, 5, 5,
          5, 5],
         [9, 9, 9, 9, 9, 8, 8, 8, 8, 8, 7, 7, 7, 7, 7, 6, 6, 6, 6, 6, 5, 5, 5,
          5, 5],
         [9, 9, 9, 9, 9, 8, 8, 8, 8, 8, 7, 7, 7, 7, 7, 6, 6, 6, 6, 6, 5, 5, 5,
          5, 5],
         [9, 9, 9, 9, 9, 8, 8, 8, 8, 8, 7, 7, 7, 7, 7, 6, 6, 6, 6, 6, 5, 5, 5,
          5, 5]],

        [[5, 4, 3, 2, 1, 5, 4, 3, 2, 1, 5, 4, 3, 2, 1, 5, 4, 3, 2, 1, 5, 4, 3,
          2, 1],
         [6, 5, 4, 3, 2, 6, 5, 4, 3, 2, 6, 5, 4, 3, 2, 6, 5, 4, 3, 2, 6, 5, 4,
          3, 2],
         [7, 6, 5, 4, 3, 7, 6, 5, 4, 3, 7, 6, 5, 4, 3, 7, 6, 5, 4, 3, 7, 6, 5,
          4, 3],
         [8, 7, 6, 5, 4, 8, 7, 6, 5, 4, 8, 7, 6, 5, 4, 8, 7, 6, 5, 4, 8, 7, 6,
          5, 4],
         [9, 8, 7, 6, 5, 9, 8, 7, 6, 5, 9, 8, 7, 6, 5, 9, 8, 7, 6, 5, 9, 8, 7,
          6, 5],
         [5, 4, 3, 2, 1, 5, 4, 3, 2, 1, 5, 4, 3, 2, 1, 5, 4, 3, 2, 1, 5, 4, 3,
          2, 1],
         [6, 5, 4, 3, 2, 6, 5, 4, 3, 2, 6, 5, 4, 3, 2, 6, 5, 4, 3, 2, 6, 5, 4,
          3, 2],
         [7, 6, 5, 4, 3, 7, 6, 5, 4, 3, 7, 6, 5, 4, 3, 7, 6, 5, 4, 3, 7, 6, 5,
          4, 3],
         [8, 7, 6, 5, 4, 8, 7, 6, 5, 4, 8, 7, 6, 5, 4, 8, 7, 6, 5, 4, 8, 7, 6,
          5, 4],
         [9, 8, 7, 6, 5, 9, 8, 7, 6, 5, 9, 8, 7, 6, 5, 9, 8, 7, 6, 5, 9, 8, 7,
          6, 5],
         [5, 4, 3, 2, 1, 5, 4, 3, 2, 1, 5, 4, 3, 2, 1, 5, 4, 3, 2, 1, 5, 4, 3,
          2, 1],
         [6, 5, 4, 3, 2, 6, 5, 4, 3, 2, 6, 5, 4, 3, 2, 6, 5, 4, 3, 2, 6, 5, 4,
          3, 2],
         [7, 6, 5, 4, 3, 7, 6, 5, 4, 3, 7, 6, 5, 4, 3, 7, 6, 5, 4, 3, 7, 6, 5,
          4, 3],
         [8, 7, 6, 5, 4, 8, 7, 6, 5, 4, 8, 7, 6, 5, 4, 8, 7, 6, 5, 4, 8, 7, 6,
          5, 4],
         [9, 8, 7, 6, 5, 9, 8, 7, 6, 5, 9, 8, 7, 6, 5, 9, 8, 7, 6, 5, 9, 8, 7,
          6, 5],
         [5, 4, 3, 2, 1, 5, 4, 3, 2, 1, 5, 4, 3, 2, 1, 5, 4, 3, 2, 1, 5, 4, 3,
          2, 1],
         [6, 5, 4, 3, 2, 6, 5, 4, 3, 2, 6, 5, 4, 3, 2, 6, 5, 4, 3, 2, 6, 5, 4,
          3, 2],
         [7, 6, 5, 4, 3, 7, 6, 5, 4, 3, 7, 6, 5, 4, 3, 7, 6, 5, 4, 3, 7, 6, 5,
          4, 3],
         [8, 7, 6, 5, 4, 8, 7, 6, 5, 4, 8, 7, 6, 5, 4, 8, 7, 6, 5, 4, 8, 7, 6,
          5, 4],
         [9, 8, 7, 6, 5, 9, 8, 7, 6, 5, 9, 8, 7, 6, 5, 9, 8, 7, 6, 5, 9, 8, 7,
          6, 5],
         [5, 4, 3, 2, 1, 5, 4, 3, 2, 1, 5, 4, 3, 2, 1, 5, 4, 3, 2, 1, 5, 4, 3,
          2, 1],
         [6, 5, 4, 3, 2, 6, 5, 4, 3, 2, 6, 5, 4, 3, 2, 6, 5, 4, 3, 2, 6, 5, 4,
          3, 2],
         [7, 6, 5, 4, 3, 7, 6, 5, 4, 3, 7, 6, 5, 4, 3, 7, 6, 5, 4, 3, 7, 6, 5,
          4, 3],
         [8, 7, 6, 5, 4, 8, 7, 6, 5, 4, 8, 7, 6, 5, 4, 8, 7, 6, 5, 4, 8, 7, 6,
          5, 4],
         [9, 8, 7, 6, 5, 9, 8, 7, 6, 5, 9, 8, 7, 6, 5, 9, 8, 7, 6, 5, 9, 8, 7,
          6, 5]]])

2. kron运算

在结果中,我们发现很多重复的值,这就让我联想到kron运算。

  • step1:形成子矩阵
  • step2: kron
  • pytorch
python 复制代码
import torch
import torch.nn as nn

torch.set_printoptions(precision=3, sci_mode=False)

if __name__ == '__main__':
    run_code = 0
    height = 5
    width = 5
    a_vector = torch.arange(width).to(torch.float).reshape(-1, 1)
    a_ones = torch.ones(1, width)
    a_matrix = a_vector @ a_ones
    print(f"a_matrix=\n{a_matrix}")
    b_matrix = a_matrix - a_matrix.T
    print(f"b_matrix=\n{b_matrix}")
    b_matrix_ones = torch.ones_like(b_matrix)
    ab_kron = torch.kron(b_matrix,b_matrix_ones)
    print(f"ab_kron=\n{ab_kron}")
    final_ab = ab_kron+5
    print(f"final_ab=\n{final_ab}")
  • result:
python 复制代码
a_matrix=
tensor([[0., 0., 0., 0., 0.],
        [1., 1., 1., 1., 1.],
        [2., 2., 2., 2., 2.],
        [3., 3., 3., 3., 3.],
        [4., 4., 4., 4., 4.]])
b_matrix=
tensor([[ 0., -1., -2., -3., -4.],
        [ 1.,  0., -1., -2., -3.],
        [ 2.,  1.,  0., -1., -2.],
        [ 3.,  2.,  1.,  0., -1.],
        [ 4.,  3.,  2.,  1.,  0.]])
ab_kron=
tensor([[ 0.,  0.,  0.,  0.,  0., -1., -1., -1., -1., -1., -2., -2., -2., -2.,
         -2., -3., -3., -3., -3., -3., -4., -4., -4., -4., -4.],
        [ 0.,  0.,  0.,  0.,  0., -1., -1., -1., -1., -1., -2., -2., -2., -2.,
         -2., -3., -3., -3., -3., -3., -4., -4., -4., -4., -4.],
        [ 0.,  0.,  0.,  0.,  0., -1., -1., -1., -1., -1., -2., -2., -2., -2.,
         -2., -3., -3., -3., -3., -3., -4., -4., -4., -4., -4.],
        [ 0.,  0.,  0.,  0.,  0., -1., -1., -1., -1., -1., -2., -2., -2., -2.,
         -2., -3., -3., -3., -3., -3., -4., -4., -4., -4., -4.],
        [ 0.,  0.,  0.,  0.,  0., -1., -1., -1., -1., -1., -2., -2., -2., -2.,
         -2., -3., -3., -3., -3., -3., -4., -4., -4., -4., -4.],
        [ 1.,  1.,  1.,  1.,  1.,  0.,  0.,  0.,  0.,  0., -1., -1., -1., -1.,
         -1., -2., -2., -2., -2., -2., -3., -3., -3., -3., -3.],
        [ 1.,  1.,  1.,  1.,  1.,  0.,  0.,  0.,  0.,  0., -1., -1., -1., -1.,
         -1., -2., -2., -2., -2., -2., -3., -3., -3., -3., -3.],
        [ 1.,  1.,  1.,  1.,  1.,  0.,  0.,  0.,  0.,  0., -1., -1., -1., -1.,
         -1., -2., -2., -2., -2., -2., -3., -3., -3., -3., -3.],
        [ 1.,  1.,  1.,  1.,  1.,  0.,  0.,  0.,  0.,  0., -1., -1., -1., -1.,
         -1., -2., -2., -2., -2., -2., -3., -3., -3., -3., -3.],
        [ 1.,  1.,  1.,  1.,  1.,  0.,  0.,  0.,  0.,  0., -1., -1., -1., -1.,
         -1., -2., -2., -2., -2., -2., -3., -3., -3., -3., -3.],
        [ 2.,  2.,  2.,  2.,  2.,  1.,  1.,  1.,  1.,  1.,  0.,  0.,  0.,  0.,
          0., -1., -1., -1., -1., -1., -2., -2., -2., -2., -2.],
        [ 2.,  2.,  2.,  2.,  2.,  1.,  1.,  1.,  1.,  1.,  0.,  0.,  0.,  0.,
          0., -1., -1., -1., -1., -1., -2., -2., -2., -2., -2.],
        [ 2.,  2.,  2.,  2.,  2.,  1.,  1.,  1.,  1.,  1.,  0.,  0.,  0.,  0.,
          0., -1., -1., -1., -1., -1., -2., -2., -2., -2., -2.],
        [ 2.,  2.,  2.,  2.,  2.,  1.,  1.,  1.,  1.,  1.,  0.,  0.,  0.,  0.,
          0., -1., -1., -1., -1., -1., -2., -2., -2., -2., -2.],
        [ 2.,  2.,  2.,  2.,  2.,  1.,  1.,  1.,  1.,  1.,  0.,  0.,  0.,  0.,
          0., -1., -1., -1., -1., -1., -2., -2., -2., -2., -2.],
        [ 3.,  3.,  3.,  3.,  3.,  2.,  2.,  2.,  2.,  2.,  1.,  1.,  1.,  1.,
          1.,  0.,  0.,  0.,  0.,  0., -1., -1., -1., -1., -1.],
        [ 3.,  3.,  3.,  3.,  3.,  2.,  2.,  2.,  2.,  2.,  1.,  1.,  1.,  1.,
          1.,  0.,  0.,  0.,  0.,  0., -1., -1., -1., -1., -1.],
        [ 3.,  3.,  3.,  3.,  3.,  2.,  2.,  2.,  2.,  2.,  1.,  1.,  1.,  1.,
          1.,  0.,  0.,  0.,  0.,  0., -1., -1., -1., -1., -1.],
        [ 3.,  3.,  3.,  3.,  3.,  2.,  2.,  2.,  2.,  2.,  1.,  1.,  1.,  1.,
          1.,  0.,  0.,  0.,  0.,  0., -1., -1., -1., -1., -1.],
        [ 3.,  3.,  3.,  3.,  3.,  2.,  2.,  2.,  2.,  2.,  1.,  1.,  1.,  1.,
          1.,  0.,  0.,  0.,  0.,  0., -1., -1., -1., -1., -1.],
        [ 4.,  4.,  4.,  4.,  4.,  3.,  3.,  3.,  3.,  3.,  2.,  2.,  2.,  2.,
          2.,  1.,  1.,  1.,  1.,  1.,  0.,  0.,  0.,  0.,  0.],
        [ 4.,  4.,  4.,  4.,  4.,  3.,  3.,  3.,  3.,  3.,  2.,  2.,  2.,  2.,
          2.,  1.,  1.,  1.,  1.,  1.,  0.,  0.,  0.,  0.,  0.],
        [ 4.,  4.,  4.,  4.,  4.,  3.,  3.,  3.,  3.,  3.,  2.,  2.,  2.,  2.,
          2.,  1.,  1.,  1.,  1.,  1.,  0.,  0.,  0.,  0.,  0.],
        [ 4.,  4.,  4.,  4.,  4.,  3.,  3.,  3.,  3.,  3.,  2.,  2.,  2.,  2.,
          2.,  1.,  1.,  1.,  1.,  1.,  0.,  0.,  0.,  0.,  0.],
        [ 4.,  4.,  4.,  4.,  4.,  3.,  3.,  3.,  3.,  3.,  2.,  2.,  2.,  2.,
          2.,  1.,  1.,  1.,  1.,  1.,  0.,  0.,  0.,  0.,  0.]])
final_ab=
tensor([[5., 5., 5., 5., 5., 4., 4., 4., 4., 4., 3., 3., 3., 3., 3., 2., 2., 2.,
         2., 2., 1., 1., 1., 1., 1.],
        [5., 5., 5., 5., 5., 4., 4., 4., 4., 4., 3., 3., 3., 3., 3., 2., 2., 2.,
         2., 2., 1., 1., 1., 1., 1.],
        [5., 5., 5., 5., 5., 4., 4., 4., 4., 4., 3., 3., 3., 3., 3., 2., 2., 2.,
         2., 2., 1., 1., 1., 1., 1.],
        [5., 5., 5., 5., 5., 4., 4., 4., 4., 4., 3., 3., 3., 3., 3., 2., 2., 2.,
         2., 2., 1., 1., 1., 1., 1.],
        [5., 5., 5., 5., 5., 4., 4., 4., 4., 4., 3., 3., 3., 3., 3., 2., 2., 2.,
         2., 2., 1., 1., 1., 1., 1.],
        [6., 6., 6., 6., 6., 5., 5., 5., 5., 5., 4., 4., 4., 4., 4., 3., 3., 3.,
         3., 3., 2., 2., 2., 2., 2.],
        [6., 6., 6., 6., 6., 5., 5., 5., 5., 5., 4., 4., 4., 4., 4., 3., 3., 3.,
         3., 3., 2., 2., 2., 2., 2.],
        [6., 6., 6., 6., 6., 5., 5., 5., 5., 5., 4., 4., 4., 4., 4., 3., 3., 3.,
         3., 3., 2., 2., 2., 2., 2.],
        [6., 6., 6., 6., 6., 5., 5., 5., 5., 5., 4., 4., 4., 4., 4., 3., 3., 3.,
         3., 3., 2., 2., 2., 2., 2.],
        [6., 6., 6., 6., 6., 5., 5., 5., 5., 5., 4., 4., 4., 4., 4., 3., 3., 3.,
         3., 3., 2., 2., 2., 2., 2.],
        [7., 7., 7., 7., 7., 6., 6., 6., 6., 6., 5., 5., 5., 5., 5., 4., 4., 4.,
         4., 4., 3., 3., 3., 3., 3.],
        [7., 7., 7., 7., 7., 6., 6., 6., 6., 6., 5., 5., 5., 5., 5., 4., 4., 4.,
         4., 4., 3., 3., 3., 3., 3.],
        [7., 7., 7., 7., 7., 6., 6., 6., 6., 6., 5., 5., 5., 5., 5., 4., 4., 4.,
         4., 4., 3., 3., 3., 3., 3.],
        [7., 7., 7., 7., 7., 6., 6., 6., 6., 6., 5., 5., 5., 5., 5., 4., 4., 4.,
         4., 4., 3., 3., 3., 3., 3.],
        [7., 7., 7., 7., 7., 6., 6., 6., 6., 6., 5., 5., 5., 5., 5., 4., 4., 4.,
         4., 4., 3., 3., 3., 3., 3.],
        [8., 8., 8., 8., 8., 7., 7., 7., 7., 7., 6., 6., 6., 6., 6., 5., 5., 5.,
         5., 5., 4., 4., 4., 4., 4.],
        [8., 8., 8., 8., 8., 7., 7., 7., 7., 7., 6., 6., 6., 6., 6., 5., 5., 5.,
         5., 5., 4., 4., 4., 4., 4.],
        [8., 8., 8., 8., 8., 7., 7., 7., 7., 7., 6., 6., 6., 6., 6., 5., 5., 5.,
         5., 5., 4., 4., 4., 4., 4.],
        [8., 8., 8., 8., 8., 7., 7., 7., 7., 7., 6., 6., 6., 6., 6., 5., 5., 5.,
         5., 5., 4., 4., 4., 4., 4.],
        [8., 8., 8., 8., 8., 7., 7., 7., 7., 7., 6., 6., 6., 6., 6., 5., 5., 5.,
         5., 5., 4., 4., 4., 4., 4.],
        [9., 9., 9., 9., 9., 8., 8., 8., 8., 8., 7., 7., 7., 7., 7., 6., 6., 6.,
         6., 6., 5., 5., 5., 5., 5.],
        [9., 9., 9., 9., 9., 8., 8., 8., 8., 8., 7., 7., 7., 7., 7., 6., 6., 6.,
         6., 6., 5., 5., 5., 5., 5.],
        [9., 9., 9., 9., 9., 8., 8., 8., 8., 8., 7., 7., 7., 7., 7., 6., 6., 6.,
         6., 6., 5., 5., 5., 5., 5.],
        [9., 9., 9., 9., 9., 8., 8., 8., 8., 8., 7., 7., 7., 7., 7., 6., 6., 6.,
         6., 6., 5., 5., 5., 5., 5.],
        [9., 9., 9., 9., 9., 8., 8., 8., 8., 8., 7., 7., 7., 7., 7., 6., 6., 6.,
         6., 6., 5., 5., 5., 5., 5.]])
相关推荐
老胖闲聊1 小时前
Python Rio 【图像处理】库简介
开发语言·图像处理·python
码界奇点1 小时前
Python Flask文件处理与异常处理实战指南
开发语言·python·自然语言处理·flask·python3.11
浠寒AI1 小时前
智能体模式篇(上)- 深入 ReAct:LangGraph构建能自主思考与行动的 AI
人工智能·python
weixin_505154462 小时前
数字孪生在建设智慧城市中可以起到哪些作用或帮助?
大数据·人工智能·智慧城市·数字孪生·数据可视化
Best_Me072 小时前
深度学习模块缝合
人工智能·深度学习
YuTaoShao2 小时前
【论文阅读】YOLOv8在单目下视多车目标检测中的应用
人工智能·yolo·目标检测
行云流水剑2 小时前
【学习记录】如何使用 Python 提取 PDF 文件中的内容
python·学习·pdf
算家计算2 小时前
字节开源代码模型——Seed-Coder 本地部署教程,模型自驱动数据筛选,让每行代码都精准落位!
人工智能·开源
伪_装3 小时前
大语言模型(LLM)面试问题集
人工智能·语言模型·自然语言处理
gs801403 小时前
Tavily 技术详解:为大模型提供实时搜索增强的利器
人工智能·rag