PyTorch torch.cat

  • [1. torch.cat](#1-torchcat)
  • [2. Example](#2-example)
  • [3. Example](#3-example)

torch
https://pytorch.org/docs/stable/torch.html

1. torch.cat

https://pytorch.org/docs/stable/generated/torch.cat.html

torch.cat(tensors, dim=0, *, out=None) -> Tensor

Concatenates the given sequence of seq tensors in the given dimension. All tensors must either have the same shape (except in the concatenating dimension) or be a 1-D empty tensor with size (0,).

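As a quick check of the shape rule, here is a minimal sketch: a (2, 3) and a (2, 5) tensor can be concatenated along dim=1, because they differ only in the concatenating dimension, but not along dim=0.

import torch

a = torch.randn(2, 3)
b = torch.randn(2, 5)

# Valid: shapes agree in every dimension except dim=1.
print(torch.cat((a, b), dim=1).shape)  # torch.Size([2, 8])

# Invalid: along dim=0 the remaining sizes (3 vs. 5) differ.
try:
    torch.cat((a, b), dim=0)
except RuntimeError as err:
    print("RuntimeError:", err)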

torch.cat() can be seen as an inverse operation for torch.split() and torch.chunk().
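
To see the inverse relationship concretely, a minimal sketch: splitting a tensor and concatenating the pieces along the same dimension reproduces the original.

import torch

x = torch.randn(4, 3)

# torch.split and torch.chunk break x into pieces along dim=0 ...
pieces = torch.split(x, 2, dim=0)   # two tensors of shape (2, 3)
chunks = torch.chunk(x, 4, dim=0)   # four tensors of shape (1, 3)

# ... and torch.cat along the same dimension puts them back together.
print(torch.equal(torch.cat(pieces, dim=0), x))  # True
print(torch.equal(torch.cat(chunks, dim=0), x))  # True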

torch.cat() can be best understood via examples.

torch.stack() concatenates the given sequence along a new dimension.
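
The difference is easiest to see from the result shapes; a minimal sketch:

import torch

x = torch.randn(2, 3)

# cat joins along an existing dimension; stack adds a new one.
print(torch.cat((x, x), dim=0).shape)    # torch.Size([4, 3])
print(torch.stack((x, x), dim=0).shape)  # torch.Size([2, 2, 3])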

  • Parameters

tensors (sequence of Tensors) - any python sequence of tensors of the same type. Non-empty tensors provided must have the same shape, except in the cat dimension.

dim (int, optional) - the dimension over which the tensors are concatenated

  • Keyword Arguments

out (Tensor, optional) - the output tensor (see the usage sketch below).
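
A minimal sketch of both arguments: dim may also be negative, counting from the last dimension, and out= writes the result into a preallocated tensor instead of allocating a new one.

import torch

x = torch.randn(2, 3)

# dim=-1 is the last dimension; for a 2-D tensor it is equivalent to dim=1.
print(torch.cat((x, x), dim=-1).shape)  # torch.Size([2, 6])

# out= stores the result in an existing tensor of the right shape.
result = torch.empty(4, 3)
torch.cat((x, x), dim=0, out=result)
print(result.shape)  # torch.Size([4, 3])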

2. Example

(base) yongqiang@yongqiang:~$ python
Python 3.11.4 (main, Jul  5 2023, 13:45:01) [GCC 11.2.0] on linux
Type "help", "copyright", "credits" or "license" for more information.
>>> import torch
>>> x = torch.randn(2, 3)
>>> x
tensor([[ 0.0811,  0.4571, -1.5260],
        [ 1.4803, -0.0314, -1.5818]])
>>>
>>> torch.cat((x, x, x), 0)
tensor([[ 0.0811,  0.4571, -1.5260],
        [ 1.4803, -0.0314, -1.5818],
        [ 0.0811,  0.4571, -1.5260],
        [ 1.4803, -0.0314, -1.5818],
        [ 0.0811,  0.4571, -1.5260],
        [ 1.4803, -0.0314, -1.5818]])
>>>
>>> torch.cat((x, x, x), 1)
tensor([[ 0.0811,  0.4571, -1.5260,  0.0811,  0.4571, -1.5260,  0.0811,  0.4571, -1.5260],
        [ 1.4803, -0.0314, -1.5818,  1.4803, -0.0314, -1.5818,  1.4803, -0.0314, -1.5818]])
>>>
>>> exit()
(base) yongqiang@yongqiang:~$

3. Example

https://github.com/karpathy/llama2.c/blob/master/model.py

The pattern below mirrors the token-by-token generation loop in llama2.c's model.py, where each newly sampled token is appended to the running sequence with torch.cat along dim=1.

import torch

# A running sequence of shape (1, 5): batch of 1, sequence length 5.
idxs = torch.randn(1, 5)
print("idxs.shape:", idxs.shape)
print("idxs:\n", idxs)

# The next element to append, shape (1, 1).
next_idx = torch.randn(1, 1)
print("\nnext_idx.shape:", next_idx.shape)
print("next_idx:\n", next_idx)

print("\nidxs.size(1):", idxs.size(1))
# Concatenate along dim=1: (1, 5) + (1, 1) -> (1, 6).
idxs_set = torch.cat((idxs, next_idx), dim=1)
print("\nidxs_set.shape:", idxs_set.shape)
print("idxs_set:\n", idxs_set)

/home/yongqiang/miniconda3/bin/python /home/yongqiang/llm_work/llama2.c/yongqiang.py 
idxs.shape: torch.Size([1, 5])
idxs:
 tensor([[-1.3383,  0.1427,  0.0857,  2.2887,  0.1691]])

next_idx.shape: torch.Size([1, 1])
next_idx:
 tensor([[0.4807]])

idxs.size(1): 5

idxs_set.shape: torch.Size([1, 6])
idxs_set:
 tensor([[-1.3383,  0.1427,  0.0857,  2.2887,  0.1691,  0.4807]])

Process finished with exit code 0
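
In llama2.c's model.py, this same cat call drives autoregressive generation: each sampled token id is appended along the sequence dimension. A minimal sketch of that loop, with model as a hypothetical stand-in for the real forward pass (assumed to return logits of shape (batch, seq_len, vocab_size)) and greedy argmax in place of the actual sampler:

import torch

def generate(model, idx, max_new_tokens):
    # idx: (batch, seq_len) tensor of token ids.
    for _ in range(max_new_tokens):
        logits = model(idx)  # hypothetical forward pass
        # Greedy choice from the logits of the last position: (batch, 1).
        next_id = torch.argmax(logits[:, -1, :], dim=-1, keepdim=True)
        # Append along dim=1: (batch, t) + (batch, 1) -> (batch, t + 1).
        idx = torch.cat((idx, next_id), dim=1)
    return idx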

