PyTorch torch.cat

PyTorch torch.cat

  • [1. `torch.cat`](#1. torch.cat)
  • [2. Example](#2. Example)
  • [3. Example](#3. Example)
  • References

torch
https://pytorch.org/docs/stable/torch.html

  • torch.cat (Python function, in torch.cat)

1. torch.cat

https://pytorch.org/docs/stable/generated/torch.cat.html

复制代码
torch.cat(tensors, dim=0, *, out=None) -> Tensor

Concatenates the given sequence of seq tensors in the given dimension. All tensors must either have the same shape (except in the concatenating dimension) or be a 1-D empty tensor with size (0,).

在给定维度上连接给定的 seq 张量序列。所有张量必须具有相同的形状 (连接维度除外),或者是一个大小为 (0,) 的一维空张量。

torch.cat() can be seen as an inverse operation for torch.split() and torch.chunk().
torch.cat() 可以看作是 torch.split()torch.chunk() 的逆运算。

torch.cat() can be best understood via examples.

torch.stack() concatenates the given sequence along a new dimension.
torch.stack() 沿着新维度连接给定的序列。

  • Parameters

tensors (sequence of Tensors) - any python sequence of tensors of the same type. Non-empty tensors provided must have the same shape, except in the cat dimension.

任何相同类型的张量 Python 序列。提供的非空张量必须具有相同的形状,连接维度除外。

dim (int, optional) - the dimension over which the tensors are concatenated

连接张量的维度

  • Keyword Arguments

out (Tensor, optional) - the output tensor.

2. Example

复制代码
(base) yongqiang@yongqiang:~$ python
Python 3.11.4 (main, Jul  5 2023, 13:45:01) [GCC 11.2.0] on linux
Type "help", "copyright", "credits" or "license" for more information.
>>> import torch
>>> x = torch.randn(2, 3)
>>> x
tensor([[ 0.0811,  0.4571, -1.5260],
        [ 1.4803, -0.0314, -1.5818]])
>>>
>>> torch.cat((x, x, x), 0)
tensor([[ 0.0811,  0.4571, -1.5260],
        [ 1.4803, -0.0314, -1.5818],
        [ 0.0811,  0.4571, -1.5260],
        [ 1.4803, -0.0314, -1.5818],
        [ 0.0811,  0.4571, -1.5260],
        [ 1.4803, -0.0314, -1.5818]])
>>>
>>> torch.cat((x, x, x), 1)
tensor([[ 0.0811,  0.4571, -1.5260,  0.0811,  0.4571, -1.5260,  0.0811,  0.4571, -1.5260],
        [ 1.4803, -0.0314, -1.5818,  1.4803, -0.0314, -1.5818,  1.4803, -0.0314, -1.5818]])
>>>
>>> exit()
(base) yongqiang@yongqiang:~$

3. Example

https://github.com/karpathy/llama2.c/blob/master/model.py

复制代码
import torch

idxs = torch.randn(1, 5)
print("idxs.shape:", idxs.shape)
print("idxs:\n", idxs)

next_idx = torch.randn(1, 1)
print("\nnext_idx.shape:", next_idx.shape)
print("next_idx:\n", next_idx)

print("\nidxs.size(1):", idxs.size(1))
idxs_set = torch.cat((idxs, next_idx), dim=1)
print("\nidxs_set.shape:", idxs_set.shape)
print("idxs_set:\n", idxs_set)

/home/yongqiang/miniconda3/bin/python /home/yongqiang/llm_work/llama2.c/yongqiang.py 
idxs.shape: torch.Size([1, 5])
idxs:
 tensor([[-1.3383,  0.1427,  0.0857,  2.2887,  0.1691]])

next_idx.shape: torch.Size([1, 1])
next_idx:
 tensor([[0.4807]])

idxs.size(1): 5

idxs_set.shape: torch.Size([1, 6])
idxs_set:
 tensor([[-1.3383,  0.1427,  0.0857,  2.2887,  0.1691,  0.4807]])

Process finished with exit code 0

References

1\] Yongqiang Cheng,

相关推荐
伊织code4 小时前
PyTorch API 6
pytorch·api·ddp
范男7 小时前
基于Pytochvideo训练自己的的视频分类模型
人工智能·pytorch·python·深度学习·计算机视觉·3d·视频
伊织code12 小时前
PyTorch API 7
pytorch·api·张量·稀疏
聚客AI14 小时前
深度拆解AI大模型从训练框架、推理优化到市场趋势与基础设施挑战
图像处理·人工智能·pytorch·深度学习·机器学习·自然语言处理·transformer
大力水手(Popeye)17 小时前
Pytorch——tensor
人工智能·pytorch·python
Caven771 天前
【pytorch】reshape的使用
pytorch·python
无规则ai1 天前
动手学深度学习(pytorch版):第四章节—多层感知机(5)权重衰减
人工智能·pytorch·python·深度学习
雷达学弱狗1 天前
backward怎么计算的是torch.tensor(2.0, requires_grad=True)变量的梯度
人工智能·pytorch·深度学习
抠头专注python环境配置2 天前
Pytorch GPU版本安装保姆级教程
pytorch·python·深度学习·conda
爱分享的飘哥2 天前
第七十章:告别“手写循环”噩梦!Trainer结构搭建:PyTorch Lightning让你“一键炼丹”!
人工智能·pytorch·分布式训练·lightning·accelerate·训练框架·trainer