torch.cat 与torch.stack的区别

torch.cat 与torch.stack的区别
torch.cat
  • 定义:按照规定的维度进行拼接。
  • 实际使用: 例如使用BiLSTM时,将两个方向的向量进行叠加,就是用torch.cat。
python 复制代码
import torch

forward_lstm = torch.randn((2, 10, 3))
backward_lstm = torch.randn((2, 10, 3))

lstm_emd = torch.cat((forward_lstm, backward_lstm), dim=-1)

print(lstm_emd.size())
'''
torch.Size([2, 10, 6])
'''
torch.stack
  • 定义:官方解释是在新的dim上进行叠加。叠加的意思就是增加一个维度。
  • 本质:对张量进行unsqueeze(dim)之后,再进行torch.cat(dim=dim)操作。
  • 实际使用:将张量合在一起,形成一个batch。
python 复制代码
import torch

batch_1 = torch.randn((10, 3))
batch_2 = torch.randn((10, 3))
batch = torch.stack((batch_1, batch_2), dim=0) 
print(batch.size()) 
'''
torch.Size([2, 10, 3])
'''
  • 使用torch.unsqueeze 和torch.cat实现torch.stack功能
python 复制代码
import torch

batch_1 = torch.randn((10, 3))
batch_2 = torch.randn((10, 3))

batch_1 = torch.unsuqeeze(batch_1, dim=0)
batch_2 = torch.unsuqeeze(batch_2, dim=0)
batch = torch.cat((batch_1, batch_2), dim=0) 
print(batch.size()) 
'''
torch.Size([2, 10, 3])
'''
相关推荐
船长Talk3 分钟前
NumPy+Pandas数据分析基础完全指南
python
Wyz201210243 分钟前
宝塔面板安装后显示无法连接数据库_检查MySQL服务状态
jvm·数据库·python
bryant_meng4 分钟前
【Reading Notes】(8.7)Favorite Articles from 2025 July
人工智能·深度学习·agi·资讯
2301_777599375 分钟前
Redis如何优化大量对象存储_利用Hash结构减少内存碎片占用
jvm·数据库·python
2301_777599376 分钟前
Python怎么解压tar.gz_tarfile模块提取打包文件操作
jvm·数据库·python
Satellite-GNSS8 分钟前
深度学习编程框架全体系详解(含选型指南+核心对比)
人工智能·深度学习
小白学大数据8 分钟前
Python 爬取图片攻略:告别水印,批量保存高清图片
开发语言·python
乔江seven9 分钟前
【李沐 | 动手学深度学习】11-1 现代卷积神经网络-AlexNet
人工智能·深度学习·卷积神经网络·alexnet·深度神经网络
2301_8152795210 分钟前
HTML怎么标注密钥权限范围_HTML “仅读取用户信息”说明【操作】
jvm·数据库·python
m0_6784854511 分钟前
Go语言怎么用Jaeger_Go语言Jaeger链路追踪教程【实用】
jvm·数据库·python