PyTorch torch.cat
- [1. `torch.cat`](#1. torch.cat)
- [2. Example](#2. Example)
- [3. Example](#3. Example)
- References
torch
https://pytorch.org/docs/stable/torch.html
1. torch.cat
https://pytorch.org/docs/stable/generated/torch.cat.html
torch.cat(tensors, dim=0, *, out=None) -> Tensor
Concatenates the given sequence of tensors along the given dimension. All tensors must either have the same shape (except in the concatenating dimension) or be a 1-D empty tensor with size (0,).
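The shape rule can be checked with a small sketch. The tensor names `a`, `b`, and `empty` are illustrative only, and the behavior shown is what the rule above implies rather than an exhaustive specification.

```python
import torch

a = torch.zeros(2, 3)
b = torch.ones(4, 3)

# The shapes differ only in dim 0, so concatenating along dim 0 is allowed.
rows = torch.cat((a, b), dim=0)
print(rows.shape)

# Along dim 1 the sizes in dim 0 (2 vs 4) would have to match, so this fails.
try:
    torch.cat((a, b), dim=1)
    mismatch_raised = False
except RuntimeError:
    mismatch_raised = True
print(mismatch_raised)

# A 1-D empty tensor of size (0,) is accepted and contributes nothing.
empty = torch.empty(0)
with_empty = torch.cat((empty, a), dim=0)
print(with_empty.shape)
```

The result has 2 + 4 = 6 rows in the first case, and the 1-D empty tensor leaves `a` unchanged in the last case.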
torch.cat() can be seen as the inverse operation of torch.split() and torch.chunk().
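A minimal sketch of this round trip, using a small integer tensor `x` chosen here purely for illustration: splitting or chunking along a dimension and then concatenating the pieces along the same dimension should give back the original tensor.

```python
import torch

x = torch.arange(12).reshape(6, 2)

# torch.split with an int cuts dim 0 into pieces of at most that size.
pieces = torch.split(x, 4, dim=0)
piece_shapes = [tuple(p.shape) for p in pieces]
print(piece_shapes)

# Concatenating the pieces along the same dim restores the original.
restored_from_split = torch.cat(pieces, dim=0)
print(torch.equal(restored_from_split, x))

# torch.chunk takes a number of chunks instead of a chunk size.
chunks = torch.chunk(x, 2, dim=1)
restored_from_chunk = torch.cat(chunks, dim=1)
print(torch.equal(restored_from_chunk, x))
```

Note that the round trip only works when the same `dim` is used for both the split and the concatenation.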
torch.cat() can be best understood via examples.
torch.stack(), by contrast, concatenates the given sequence along a new dimension.
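The difference between the two functions is easiest to see in the output shapes. The following is a small sketch with illustrative tensors `a` and `b`; the final comparison assumes that stacking is equivalent to adding a leading dimension to each tensor and then concatenating.

```python
import torch

a = torch.zeros(2, 3)
b = torch.ones(2, 3)

# cat joins along an existing dimension: the rank stays at 2.
cat0 = torch.cat((a, b), dim=0)
print(cat0.shape)

# stack inserts a new dimension: the rank grows to 3.
stack0 = torch.stack((a, b), dim=0)
print(stack0.shape)

# Stacking matches unsqueezing each tensor and then concatenating.
manual = torch.cat((a.unsqueeze(0), b.unsqueeze(0)), dim=0)
print(torch.equal(stack0, manual))
```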
- Parameters
tensors (sequence of Tensors) - any Python sequence of tensors of the same type. Any non-empty tensors provided must have the same shape, except in the cat dimension.
dim (int, optional) - the dimension along which the tensors are concatenated. Defaults to 0.
- Keyword Arguments
out (Tensor, optional) - the output tensor.
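A short sketch of the `dim` and `out` arguments follows. The negative `dim` value relies on the usual PyTorch convention that negative indices count from the last dimension, which is not stated in the parameter list above; the buffer name `buf` is illustrative.

```python
import torch

a = torch.zeros(2, 3)
b = torch.ones(2, 3)

# For 2-D inputs, dim=-1 refers to the last dimension, i.e. dim=1.
last = torch.cat((a, b), dim=-1)
print(last.shape)
print(torch.equal(last, torch.cat((a, b), dim=1)))

# out= writes the result into a preallocated tensor of the right shape.
buf = torch.empty(4, 3)
result = torch.cat((a, b), dim=0, out=buf)
print(result.data_ptr() == buf.data_ptr())
print(torch.equal(buf, torch.cat((a, b), dim=0)))
```

Passing `out` is mainly useful when the same buffer is reused across calls; for one-off concatenations the plain return value is usually simpler.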
2. Example
(base) yongqiang@yongqiang:~$ python
Python 3.11.4 (main, Jul  5 2023, 13:45:01) [GCC 11.2.0] on linux
Type "help", "copyright", "credits" or "license" for more information.
>>> import torch
>>> x = torch.randn(2, 3)
>>> x
tensor([[ 0.0811,  0.4571, -1.5260],
        [ 1.4803, -0.0314, -1.5818]])
>>>
>>> torch.cat((x, x, x), 0)
tensor([[ 0.0811,  0.4571, -1.5260],
        [ 1.4803, -0.0314, -1.5818],
        [ 0.0811,  0.4571, -1.5260],
        [ 1.4803, -0.0314, -1.5818],
        [ 0.0811,  0.4571, -1.5260],
        [ 1.4803, -0.0314, -1.5818]])
>>>
>>> torch.cat((x, x, x), 1)
tensor([[ 0.0811,  0.4571, -1.5260,  0.0811,  0.4571, -1.5260,  0.0811,  0.4571, -1.5260],
        [ 1.4803, -0.0314, -1.5818,  1.4803, -0.0314, -1.5818,  1.4803, -0.0314, -1.5818]])
>>>
>>> exit()
(base) yongqiang@yongqiang:~$
3. Example
https://github.com/karpathy/llama2.c/blob/master/model.py
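import torch
# Existing sequence: a batch of 1 with 5 positions (random values for illustration).
idxs = torch.randn(1, 5)
print("idxs.shape:", idxs.shape)
print("idxs:\n", idxs)
# Element to append, shaped (1, 1) so that it matches idxs in dim 0.
next_idx = torch.randn(1, 1)
print("\nnext_idx.shape:", next_idx.shape)
print("next_idx:\n", next_idx)
print("\nidxs.size(1):", idxs.size(1))
# Append along dim 1: (1, 5) and (1, 1) give (1, 6).
idxs_set = torch.cat((idxs, next_idx), dim=1)
print("\nidxs_set.shape:", idxs_set.shape)
print("idxs_set:\n", idxs_set)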
/home/yongqiang/miniconda3/bin/python /home/yongqiang/llm_work/llama2.c/yongqiang.py 
idxs.shape: torch.Size([1, 5])
idxs:
 tensor([[-1.3383,  0.1427,  0.0857,  2.2887,  0.1691]])
next_idx.shape: torch.Size([1, 1])
next_idx:
 tensor([[0.4807]])
idxs.size(1): 5
idxs_set.shape: torch.Size([1, 6])
idxs_set:
 tensor([[-1.3383,  0.1427,  0.0857,  2.2887,  0.1691,  0.4807]])
Process finished with exit code 0
References
[1] Yongqiang Cheng,