torch.split
是 PyTorch 中的一个函数,用于将张量按指定的大小或张量数量进行分割
1 基本使用方法
python
torch.split(tensor, split_size_or_sections, dim=0)
|----------------------------|---------------------------------------------------------------------------------------------------------------------------------------------|
| tensor | 要分割的输入张量 |
| split_size_or_sections | 以是整数或整数列表。 * 如果是整数,那么它表示每个分割的大小。如果张量在给定维度上的大小不能被该值整除,最后一段会小于其他段。 * 如果是整数列表,那么它表示每个分割的确切大小。列表的总和必须等于张量在给定维度上的大小。 * 使用整数列表时,确保其元素之和等于所分割维度的大小 |
| dim | 要分割的维度,默认值为0 |
返回一个张量的元组,其中每个张量是原始张量的一个分割。
也可以直接tensor.split(...)
2 举例
python
import torch
x = torch.tensor([1, 2, 3, 4, 5, 6, 7, 8])
torch.split(x, 3)
#(tensor([1, 2, 3]), tensor([4, 5, 6]), tensor([7, 8]))
torch.split(x, [2, 4, 2])
#(tensor([1, 2]), tensor([3, 4, 5, 6]), tensor([7, 8]))
python
a = torch.arange(12).reshape(3,4)
a
'''
tensor([[ 0, 1, 2, 3],
[ 4, 5, 6, 7],
[ 8, 9, 10, 11]])
'''
torch.split(a,2)
'''
(tensor([[0, 1, 2, 3],
[4, 5, 6, 7]]),
tensor([[ 8, 9, 10, 11]]))
'''
torch.split(a,2,dim=1)
'''
(tensor([[0, 1],
[4, 5],
[8, 9]]),
tensor([[ 2, 3],
[ 6, 7],
[10, 11]]))
'''