torch.cat 与torch.stack的区别

torch.cat 与torch.stack的区别
torch.cat
  • 定义:按照规定的维度进行拼接。
  • 实际使用: 例如使用BiLSTM时,将两个方向的向量进行叠加,就是用torch.cat。
python 复制代码
import torch

forward_lstm = torch.randn((2, 10, 3))
backward_lstm = torch.randn((2, 10, 3))

lstm_emd = torch.cat((forward_lstm, backward_lstm), dim=-1)

print(lstm_emd.size())
'''
torch.Size([2, 10, 6])
'''
torch.stack
  • 定义:官方解释是在新的dim上进行叠加。叠加的意思就是增加一个维度。
  • 本质:对张量进行unsqueeze(dim)之后,再进行torch.cat(dim=dim)操作。
  • 实际使用:将张量合在一起,形成一个batch。
python 复制代码
import torch

batch_1 = torch.randn((10, 3))
batch_2 = torch.randn((10, 3))
batch = torch.stack((batch_1, batch_2), dim=0) 
print(batch.size()) 
'''
torch.Size([2, 10, 3])
'''
  • 使用torch.unsqueeze 和torch.cat实现torch.stack功能
python 复制代码
import torch

batch_1 = torch.randn((10, 3))
batch_2 = torch.randn((10, 3))

batch_1 = torch.unsuqeeze(batch_1, dim=0)
batch_2 = torch.unsuqeeze(batch_2, dim=0)
batch = torch.cat((batch_1, batch_2), dim=0) 
print(batch.size()) 
'''
torch.Size([2, 10, 3])
'''
相关推荐
聆风吟º2 小时前
CANN runtime 全链路拆解:AI 异构计算运行时的任务管理与功能适配技术路径
人工智能·深度学习·神经网络·cann
User_芊芊君子2 小时前
CANN大模型推理加速引擎ascend-transformer-boost深度解析:毫秒级响应的Transformer优化方案
人工智能·深度学习·transformer
ValhallaCoder2 小时前
hot100-二叉树I
数据结构·python·算法·二叉树
智驱力人工智能3 小时前
小区高空抛物AI实时预警方案 筑牢社区头顶安全的实践 高空抛物检测 高空抛物监控安装教程 高空抛物误报率优化方案 高空抛物监控案例分享
人工智能·深度学习·opencv·算法·安全·yolo·边缘计算
人工不智能5773 小时前
拆解 BERT:Output 中的 Hidden States 到底藏了什么秘密?
人工智能·深度学习·bert
猫头虎3 小时前
如何排查并解决项目启动时报错Error encountered while processing: java.io.IOException: closed 的问题
java·开发语言·jvm·spring boot·python·开源·maven
h64648564h3 小时前
CANN 性能剖析与调优全指南:从 Profiling 到 Kernel 级优化
人工智能·深度学习
心疼你的一切3 小时前
解密CANN仓库:AIGC的算力底座、关键应用与API实战解析
数据仓库·深度学习·aigc·cann
八零后琐话4 小时前
干货:程序员必备性能分析工具——Arthas火焰图
开发语言·python