torch.tile 手动实现 kron+矩阵乘法

文章目录

  • [1. tile](#1. tile)
  • [2. pytorch](#2. pytorch)

1. tile

torch.tile 是对矩阵进行指定维度的复制数据,为了实现矩阵复制,使用kron 算子将对角矩阵I 复制后形成基于行变换和列变换的矩阵

2. pytorch

python 复制代码
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.set_printoptions(precision=3, sci_mode=False)

if __name__ == "__main__":
    run_code = 0
    a_matrix = torch.randn(2, 3)
    dim0 = 4
    dim1 = 3
    tile_matrix = torch.tile(a_matrix, dims=(dim0, dim1))
    print(f"a_matrix.shape=\n{a_matrix.shape}")
    print(f"tile_matrix.shape=\n{tile_matrix.shape}")
    print(f"a_matrix=\n{a_matrix}")
    print(f"tile_matrix=\n{tile_matrix}")
    my_one = torch.zeros(2 * dim0, 2)
    my_one[0::2, 0] = 1
    my_one[1::2, 1] = 1
    print(f"my_one=\n{my_one}")
    a_one = torch.ones(2).reshape(-1, 1)
    a_row = torch.eye(2)
    a_kron = torch.kron(a_one, a_row)
    print(f"a_kron=\n{a_kron}")
    a_co_one = torch.ones(3).reshape(1, -1)
    a_column = torch.eye(3)
    b_kron = torch.kron(a_co_one, a_column)
    print(f"b_kron=\n{b_kron}")
    my_one_result = my_one @ a_matrix @ b_kron
    print(f"my_one_result=\n{my_one_result}")
    m_check_result = torch.allclose(my_one_result,tile_matrix)
    print(f"m_check_result={m_check_result}")
  • 结果:
python 复制代码
a_matrix.shape=
torch.Size([2, 3])
tile_matrix.shape=
torch.Size([8, 9])
a_matrix=
tensor([[0.340, 0.766, 0.622],
        [0.366, 1.425, 0.886]])
tile_matrix=
tensor([[0.340, 0.766, 0.622, 0.340, 0.766, 0.622, 0.340, 0.766, 0.622],
        [0.366, 1.425, 0.886, 0.366, 1.425, 0.886, 0.366, 1.425, 0.886],
        [0.340, 0.766, 0.622, 0.340, 0.766, 0.622, 0.340, 0.766, 0.622],
        [0.366, 1.425, 0.886, 0.366, 1.425, 0.886, 0.366, 1.425, 0.886],
        [0.340, 0.766, 0.622, 0.340, 0.766, 0.622, 0.340, 0.766, 0.622],
        [0.366, 1.425, 0.886, 0.366, 1.425, 0.886, 0.366, 1.425, 0.886],
        [0.340, 0.766, 0.622, 0.340, 0.766, 0.622, 0.340, 0.766, 0.622],
        [0.366, 1.425, 0.886, 0.366, 1.425, 0.886, 0.366, 1.425, 0.886]])
my_one=
tensor([[1., 0.],
        [0., 1.],
        [1., 0.],
        [0., 1.],
        [1., 0.],
        [0., 1.],
        [1., 0.],
        [0., 1.]])
a_kron=
tensor([[1., 0.],
        [0., 1.],
        [1., 0.],
        [0., 1.]])
b_kron=
tensor([[1., 0., 0., 1., 0., 0., 1., 0., 0.],
        [0., 1., 0., 0., 1., 0., 0., 1., 0.],
        [0., 0., 1., 0., 0., 1., 0., 0., 1.]])
my_one_result=
tensor([[0.340, 0.766, 0.622, 0.340, 0.766, 0.622, 0.340, 0.766, 0.622],
        [0.366, 1.425, 0.886, 0.366, 1.425, 0.886, 0.366, 1.425, 0.886],
        [0.340, 0.766, 0.622, 0.340, 0.766, 0.622, 0.340, 0.766, 0.622],
        [0.366, 1.425, 0.886, 0.366, 1.425, 0.886, 0.366, 1.425, 0.886],
        [0.340, 0.766, 0.622, 0.340, 0.766, 0.622, 0.340, 0.766, 0.622],
        [0.366, 1.425, 0.886, 0.366, 1.425, 0.886, 0.366, 1.425, 0.886],
        [0.340, 0.766, 0.622, 0.340, 0.766, 0.622, 0.340, 0.766, 0.622],
        [0.366, 1.425, 0.886, 0.366, 1.425, 0.886, 0.366, 1.425, 0.886]])
m_check_result=True
相关推荐
CV缝合救星1 小时前
【Arxiv 2025 预发行论文】重磅突破!STAR-DSSA 模块横空出世:显著性+拓扑双重加持,小目标、大场景统统拿下!
人工智能·深度学习·计算机视觉·目标跟踪·即插即用模块
蓝桉8023 小时前
如何进行神经网络的模型训练(视频代码中的知识点记录)
人工智能·深度学习·神经网络
星期天要睡觉4 小时前
深度学习——数据增强(Data Augmentation)
人工智能·深度学习
笑脸惹桃花4 小时前
50系显卡训练深度学习YOLO等算法报错的解决方法
深度学习·算法·yolo·torch·cuda
anneCoder5 小时前
AI大模型应用研发工程师面试知识准备目录
人工智能·深度学习·机器学习
骑驴看星星a5 小时前
没有深度学习
人工智能·深度学习
THMAIL7 小时前
深度学习从入门到精通 - AutoML与神经网络搜索(NAS):自动化模型设计未来
人工智能·python·深度学习·神经网络·算法·机器学习·逻辑回归
山烛7 小时前
深度学习:残差网络ResNet与迁移学习
人工智能·python·深度学习·残差网络·resnet·迁移学习
THMAIL12 小时前
量化基金从小白到大师 - 金融数据获取大全:从免费API到Tick级数据实战指南
人工智能·python·深度学习·算法·机器学习·金融·kafka
Tiger Z13 小时前
《动手学深度学习v2》学习笔记 | 2.4 微积分 & 2.5 自动微分
pytorch·深度学习·ai