PyTorch torch.cat

PyTorch torch.cat

  • [1. `torch.cat`](#1. torch.cat)
  • [2. Example](#2. Example)
  • [3. Example](#3. Example)
  • References

torch
https://pytorch.org/docs/stable/torch.html

  • torch.cat (Python function, in torch.cat)

1. torch.cat

https://pytorch.org/docs/stable/generated/torch.cat.html

复制代码
torch.cat(tensors, dim=0, *, out=None) -> Tensor

Concatenates the given sequence of seq tensors in the given dimension. All tensors must either have the same shape (except in the concatenating dimension) or be a 1-D empty tensor with size (0,).

在给定维度上连接给定的 seq 张量序列。所有张量必须具有相同的形状 (连接维度除外),或者是一个大小为 (0,) 的一维空张量。

torch.cat() can be seen as an inverse operation for torch.split() and torch.chunk().
torch.cat() 可以看作是 torch.split()torch.chunk() 的逆运算。

torch.cat() can be best understood via examples.

torch.stack() concatenates the given sequence along a new dimension.
torch.stack() 沿着新维度连接给定的序列。

  • Parameters

tensors (sequence of Tensors) - any python sequence of tensors of the same type. Non-empty tensors provided must have the same shape, except in the cat dimension.

任何相同类型的张量 Python 序列。提供的非空张量必须具有相同的形状,连接维度除外。

dim (int, optional) - the dimension over which the tensors are concatenated

连接张量的维度

  • Keyword Arguments

out (Tensor, optional) - the output tensor.

2. Example

复制代码
(base) yongqiang@yongqiang:~$ python
Python 3.11.4 (main, Jul  5 2023, 13:45:01) [GCC 11.2.0] on linux
Type "help", "copyright", "credits" or "license" for more information.
>>> import torch
>>> x = torch.randn(2, 3)
>>> x
tensor([[ 0.0811,  0.4571, -1.5260],
        [ 1.4803, -0.0314, -1.5818]])
>>>
>>> torch.cat((x, x, x), 0)
tensor([[ 0.0811,  0.4571, -1.5260],
        [ 1.4803, -0.0314, -1.5818],
        [ 0.0811,  0.4571, -1.5260],
        [ 1.4803, -0.0314, -1.5818],
        [ 0.0811,  0.4571, -1.5260],
        [ 1.4803, -0.0314, -1.5818]])
>>>
>>> torch.cat((x, x, x), 1)
tensor([[ 0.0811,  0.4571, -1.5260,  0.0811,  0.4571, -1.5260,  0.0811,  0.4571, -1.5260],
        [ 1.4803, -0.0314, -1.5818,  1.4803, -0.0314, -1.5818,  1.4803, -0.0314, -1.5818]])
>>>
>>> exit()
(base) yongqiang@yongqiang:~$

3. Example

https://github.com/karpathy/llama2.c/blob/master/model.py

复制代码
import torch

idxs = torch.randn(1, 5)
print("idxs.shape:", idxs.shape)
print("idxs:\n", idxs)

next_idx = torch.randn(1, 1)
print("\nnext_idx.shape:", next_idx.shape)
print("next_idx:\n", next_idx)

print("\nidxs.size(1):", idxs.size(1))
idxs_set = torch.cat((idxs, next_idx), dim=1)
print("\nidxs_set.shape:", idxs_set.shape)
print("idxs_set:\n", idxs_set)

/home/yongqiang/miniconda3/bin/python /home/yongqiang/llm_work/llama2.c/yongqiang.py 
idxs.shape: torch.Size([1, 5])
idxs:
 tensor([[-1.3383,  0.1427,  0.0857,  2.2887,  0.1691]])

next_idx.shape: torch.Size([1, 1])
next_idx:
 tensor([[0.4807]])

idxs.size(1): 5

idxs_set.shape: torch.Size([1, 6])
idxs_set:
 tensor([[-1.3383,  0.1427,  0.0857,  2.2887,  0.1691,  0.4807]])

Process finished with exit code 0

References

1\] Yongqiang Cheng,

相关推荐
誉鏐2 小时前
PyTorch复现逻辑回归
人工智能·pytorch·逻辑回归
意.远3 小时前
在PyTorch中使用GPU加速:从基础操作到模型部署
人工智能·pytorch·python·深度学习
byxdaz11 小时前
PyTorch中Linear全连接层
pytorch
Start_Present11 小时前
Pytorch 第十二回:循环神经网络——LSTM模型
pytorch·rnn·神经网络·数据分析·lstm
船长@Quant15 小时前
PyTorch量化进阶教程:第六章 模型部署与生产化
pytorch·python·深度学习·transformer·量化交易·sklearn·ta-lib
byxdaz15 小时前
PyTorch中卷积层torch.nn.Conv2d
pytorch
进取星辰17 小时前
PyTorch 深度学习实战(32):多模态学习与CLIP模型
pytorch·深度学习·学习
带娃的IT创业者18 小时前
《Python实战进阶》No39:模型部署——TensorFlow Serving 与 ONNX
pytorch·python·tensorflow·持续部署
iiimZoey1 天前
配置晟腾910b的PyTorch torch_npu环境
pytorch
Start_Present1 天前
Pytorch 第十三回:神经网络编码器——自动编解码器
pytorch·python·深度学习·神经网络