PyTorch torch.cat

PyTorch torch.cat

  • [1. `torch.cat`](#1. torch.cat)
  • [2. Example](#2. Example)
  • [3. Example](#3. Example)
  • References

torch
https://pytorch.org/docs/stable/torch.html

  • torch.cat (Python function, in torch.cat)

1. torch.cat

https://pytorch.org/docs/stable/generated/torch.cat.html

复制代码
torch.cat(tensors, dim=0, *, out=None) -> Tensor

Concatenates the given sequence of seq tensors in the given dimension. All tensors must either have the same shape (except in the concatenating dimension) or be a 1-D empty tensor with size (0,).

在给定维度上连接给定的 seq 张量序列。所有张量必须具有相同的形状 (连接维度除外),或者是一个大小为 (0,) 的一维空张量。

torch.cat() can be seen as an inverse operation for torch.split() and torch.chunk().
torch.cat() 可以看作是 torch.split()torch.chunk() 的逆运算。

torch.cat() can be best understood via examples.

torch.stack() concatenates the given sequence along a new dimension.
torch.stack() 沿着新维度连接给定的序列。

  • Parameters

tensors (sequence of Tensors) - any python sequence of tensors of the same type. Non-empty tensors provided must have the same shape, except in the cat dimension.

任何相同类型的张量 Python 序列。提供的非空张量必须具有相同的形状,连接维度除外。

dim (int, optional) - the dimension over which the tensors are concatenated

连接张量的维度

  • Keyword Arguments

out (Tensor, optional) - the output tensor.

2. Example

复制代码
(base) yongqiang@yongqiang:~$ python
Python 3.11.4 (main, Jul  5 2023, 13:45:01) [GCC 11.2.0] on linux
Type "help", "copyright", "credits" or "license" for more information.
>>> import torch
>>> x = torch.randn(2, 3)
>>> x
tensor([[ 0.0811,  0.4571, -1.5260],
        [ 1.4803, -0.0314, -1.5818]])
>>>
>>> torch.cat((x, x, x), 0)
tensor([[ 0.0811,  0.4571, -1.5260],
        [ 1.4803, -0.0314, -1.5818],
        [ 0.0811,  0.4571, -1.5260],
        [ 1.4803, -0.0314, -1.5818],
        [ 0.0811,  0.4571, -1.5260],
        [ 1.4803, -0.0314, -1.5818]])
>>>
>>> torch.cat((x, x, x), 1)
tensor([[ 0.0811,  0.4571, -1.5260,  0.0811,  0.4571, -1.5260,  0.0811,  0.4571, -1.5260],
        [ 1.4803, -0.0314, -1.5818,  1.4803, -0.0314, -1.5818,  1.4803, -0.0314, -1.5818]])
>>>
>>> exit()
(base) yongqiang@yongqiang:~$

3. Example

https://github.com/karpathy/llama2.c/blob/master/model.py

复制代码
import torch

idxs = torch.randn(1, 5)
print("idxs.shape:", idxs.shape)
print("idxs:\n", idxs)

next_idx = torch.randn(1, 1)
print("\nnext_idx.shape:", next_idx.shape)
print("next_idx:\n", next_idx)

print("\nidxs.size(1):", idxs.size(1))
idxs_set = torch.cat((idxs, next_idx), dim=1)
print("\nidxs_set.shape:", idxs_set.shape)
print("idxs_set:\n", idxs_set)

/home/yongqiang/miniconda3/bin/python /home/yongqiang/llm_work/llama2.c/yongqiang.py 
idxs.shape: torch.Size([1, 5])
idxs:
 tensor([[-1.3383,  0.1427,  0.0857,  2.2887,  0.1691]])

next_idx.shape: torch.Size([1, 1])
next_idx:
 tensor([[0.4807]])

idxs.size(1): 5

idxs_set.shape: torch.Size([1, 6])
idxs_set:
 tensor([[-1.3383,  0.1427,  0.0857,  2.2887,  0.1691,  0.4807]])

Process finished with exit code 0

References

1\] Yongqiang Cheng,

相关推荐
隔壁大炮12 小时前
Day06-08.CNN概述介绍
人工智能·pytorch·深度学习·算法·计算机视觉·cnn·numpy
QiZhang | UESTC12 小时前
从基础 RoPE 到 YaRN:源码学习路线揭秘
pytorch·深度学习·学习
光之后裔14 小时前
Numpy以及Pytorch中多维数组的维度数与维度值以及轴axis理解
pytorch·python·numpy
Jmayday16 小时前
Pytorch:神经网络基础
人工智能·pytorch·神经网络
Cho1yon17 小时前
【AI Agent 第十期:基于 scrcpy + PyTorch 的车载系统多屏自动化测试工具开发】
人工智能·pytorch·ui·车载系统·自动化
蓝博AI19 小时前
基于深度学习的蔬菜识别系统,resnet50,vgg16,resnet34【pytorch框架,python代码】
人工智能·pytorch·python·深度学习·机器学习·cnn
努力学习_小白20 小时前
DenseNet——Pytorch学习记录
人工智能·pytorch·机器学习·densenet
用户990193052451 天前
Nano-vLLM-MS:基于 nano-vLLM ,支持 MoE 模型和 Speculative Decoding
pytorch·llm
eqwaak02 天前
PyTorch入门:10分钟搭建首个神经网络
开发语言·人工智能·pytorch·python
IRevers2 天前
【Agent】基于Langchain的Agent数据库查询助手
数据库·人工智能·pytorch·sql·深度学习·langchain·agent