pytorch笔记:split

torch.split 是 PyTorch 中的一个函数,用于将张量按指定的大小或张量数量进行分割

1 基本使用方法

python 复制代码
torch.split(tensor, split_size_or_sections, dim=0)

|----------------------------|---------------------------------------------------------------------------------------------------------------------------------------------|
| tensor | 要分割的输入张量 |
| split_size_or_sections | 以是整数或整数列表。 * 如果是整数,那么它表示每个分割的大小。如果张量在给定维度上的大小不能被该值整除,最后一段会小于其他段。 * 如果是整数列表,那么它表示每个分割的确切大小。列表的总和必须等于张量在给定维度上的大小。 * 使用整数列表时,确保其元素之和等于所分割维度的大小 |
| dim | 要分割的维度,默认值为0 |

返回一个张量的元组,其中每个张量是原始张量的一个分割。

也可以直接tensor.split(...)

2 举例

python 复制代码
import torch

x = torch.tensor([1, 2, 3, 4, 5, 6, 7, 8])
torch.split(x, 3)
#(tensor([1, 2, 3]), tensor([4, 5, 6]), tensor([7, 8]))


torch.split(x, [2, 4, 2])
#(tensor([1, 2]), tensor([3, 4, 5, 6]), tensor([7, 8]))
python 复制代码
a = torch.arange(12).reshape(3,4)
a
'''
tensor([[ 0,  1,  2,  3],
        [ 4,  5,  6,  7],
        [ 8,  9, 10, 11]])
'''

torch.split(a,2)
'''
(tensor([[0, 1, 2, 3],
         [4, 5, 6, 7]]),
 tensor([[ 8,  9, 10, 11]]))
'''

torch.split(a,2,dim=1)
'''
(tensor([[0, 1],
         [4, 5],
         [8, 9]]),
 tensor([[ 2,  3],
         [ 6,  7],
         [10, 11]]))
'''
相关推荐
是上好佳佳佳呀22 分钟前
【数据分析|Day02】Matplotlib 数据可视化笔记
笔记·matplotlib
WPF工业上位机7 小时前
YXGK.FakeVM深度学习之5语义分割
人工智能·深度学习
落叶无情7 小时前
ICEF认知操作系统:四类约束全维度全覆盖,是全谱系系统化约束体系
人工智能
碳基硅坊7 小时前
Gemma 4 12B 让AI创作更私密更高效
人工智能·gemma-4-12b
weixin_468466857 小时前
大模型新手入门与实战指南
人工智能·深度学习·ai·大模型
装不满的克莱因瓶7 小时前
掌握 RNN 与 LSTM 模型结构
人工智能·python·rnn·深度学习·神经网络·ai·lstm
jeffer_liu7 小时前
Spring AI 生产级实战:裁判员
java·人工智能·后端·spring·大模型
weixin_446260857 小时前
Agent 会自行回避吗?测量 LLM 智能体合规性的带内访问拒绝信号
人工智能
努力学习_小白7 小时前
ResNeXt-50——学习记录
pytorch·深度学习·学习