pytorch笔记:split

torch.split 是 PyTorch 中的一个函数,用于将张量按指定的大小或张量数量进行分割

1 基本使用方法

python 复制代码
torch.split(tensor, split_size_or_sections, dim=0)

|----------------------------|---------------------------------------------------------------------------------------------------------------------------------------------|
| tensor | 要分割的输入张量 |
| split_size_or_sections | 以是整数或整数列表。 * 如果是整数,那么它表示每个分割的大小。如果张量在给定维度上的大小不能被该值整除,最后一段会小于其他段。 * 如果是整数列表,那么它表示每个分割的确切大小。列表的总和必须等于张量在给定维度上的大小。 * 使用整数列表时,确保其元素之和等于所分割维度的大小 |
| dim | 要分割的维度,默认值为0 |

返回一个张量的元组,其中每个张量是原始张量的一个分割。

也可以直接tensor.split(...)

2 举例

python 复制代码
import torch

x = torch.tensor([1, 2, 3, 4, 5, 6, 7, 8])
torch.split(x, 3)
#(tensor([1, 2, 3]), tensor([4, 5, 6]), tensor([7, 8]))


torch.split(x, [2, 4, 2])
#(tensor([1, 2]), tensor([3, 4, 5, 6]), tensor([7, 8]))
python 复制代码
a = torch.arange(12).reshape(3,4)
a
'''
tensor([[ 0,  1,  2,  3],
        [ 4,  5,  6,  7],
        [ 8,  9, 10, 11]])
'''

torch.split(a,2)
'''
(tensor([[0, 1, 2, 3],
         [4, 5, 6, 7]]),
 tensor([[ 8,  9, 10, 11]]))
'''

torch.split(a,2,dim=1)
'''
(tensor([[0, 1],
         [4, 5],
         [8, 9]]),
 tensor([[ 2,  3],
         [ 6,  7],
         [10, 11]]))
'''
相关推荐
多巴胺与内啡肽.8 分钟前
OpenCV进阶操作:风格迁移以及DNN模块解析
人工智能·opencv·dnn
szxinmai主板定制专家41 分钟前
基于TI AM6442+FPGA解决方案,支持6网口,4路CAN,8个串口
arm开发·人工智能·fpga开发
龙湾开发1 小时前
轻量级高性能推理引擎MNN 学习笔记 02.MNN主要API
人工智能·笔记·学习·机器学习·mnn
CopyLower1 小时前
Java与AI技术结合:从机器学习到生成式AI的实践
java·人工智能·机器学习
Tech Synapse1 小时前
联邦学习图像分类实战:基于FATE与PyTorch的隐私保护机器学习系统构建指南
pytorch·机器学习·分类
workflower2 小时前
使用谱聚类将相似度矩阵分为2类
人工智能·深度学习·算法·机器学习·设计模式·软件工程·软件需求
jndingxin2 小时前
OpenCV CUDA 模块中在 GPU 上对图像或矩阵进行 翻转(镜像)操作的一个函数 flip()
人工智能·opencv
HappyAcmen2 小时前
线代第二章矩阵第八节逆矩阵、解矩阵方程
笔记·学习·线性代数·矩阵
囚生CY2 小时前
【速写】TRL:Trainer的细节与思考(PPO/DPO+LoRA可行性)
人工智能
杨德兴2 小时前
3.3 阶数的作用
人工智能·学习