torch.cat 与torch.stack的区别

torch.cat 与torch.stack的区别
torch.cat
  • 定义:按照规定的维度进行拼接。
  • 实际使用: 例如使用BiLSTM时,将两个方向的向量进行叠加,就是用torch.cat。
python 复制代码
import torch

forward_lstm = torch.randn((2, 10, 3))
backward_lstm = torch.randn((2, 10, 3))

lstm_emd = torch.cat((forward_lstm, backward_lstm), dim=-1)

print(lstm_emd.size())
'''
torch.Size([2, 10, 6])
'''
torch.stack
  • 定义:官方解释是在新的dim上进行叠加。叠加的意思就是增加一个维度。
  • 本质:对张量进行unsqueeze(dim)之后,再进行torch.cat(dim=dim)操作。
  • 实际使用:将张量合在一起,形成一个batch。
python 复制代码
import torch

batch_1 = torch.randn((10, 3))
batch_2 = torch.randn((10, 3))
batch = torch.stack((batch_1, batch_2), dim=0) 
print(batch.size()) 
'''
torch.Size([2, 10, 3])
'''
  • 使用torch.unsqueeze 和torch.cat实现torch.stack功能
python 复制代码
import torch

batch_1 = torch.randn((10, 3))
batch_2 = torch.randn((10, 3))

batch_1 = torch.unsuqeeze(batch_1, dim=0)
batch_2 = torch.unsuqeeze(batch_2, dim=0)
batch = torch.cat((batch_1, batch_2), dim=0) 
print(batch.size()) 
'''
torch.Size([2, 10, 3])
'''
相关推荐
ZTLJQ7 小时前
序列化的艺术:Python JSON处理完全解析
开发语言·python·json
H5css�海秀7 小时前
今天是自学大模型的第一天(sanjose)
后端·python·node.js·php
阿贵---8 小时前
使用XGBoost赢得Kaggle比赛
jvm·数据库·python
无敌昊哥战神8 小时前
【LeetCode 257】二叉树的所有路径(回溯法/深度优先遍历)- Python/C/C++详细题解
c语言·c++·python·leetcode·深度优先
李昊哲小课9 小时前
第1章-PySide6 基础认知与环境配置
python·pyqt·pyside
老鱼说AI9 小时前
大规模并发处理器程序设计(PMPP)讲解(CUDA架构):第四期:计算架构与调度
c语言·深度学习·算法·架构·cuda
2401_8942419210 小时前
用Pygame开发你的第一个小游戏
jvm·数据库·python
Hello.Reader11 小时前
深度学习 — 从人工智能到深度学习的演进之路(一)
人工智能·深度学习
Zzzz_my11 小时前
正则表达式(RE)
pytorch·python·正则表达式
天天鸭11 小时前
前端仔写了个 AI Agent,才发现大模型只干了 10% 的活
前端·python·ai编程