torch.cat 使用小节

torch.cat 使用小节

torch.cat 要求在所指定拼接维度之外的所有维度都要匹配,例如

python 复制代码
import torch
v1 = torch.tensor([[1, 2, 3], [4, 5, 6], [4, 5, 6]])  # 3*3
v2 = torch.tensor([[3, 6, 8]])  # 1*3
torch.cat([v1, v2], dim=1)

运行之后会报错

Sizes of tensors must match except in dimension 1. Expected size 3 but got size 1 for tensor number 1 in the list.

这就是因为这两个向量的第 0 个维度不相等,故无法完成 cat 操作,改成 dim=0 即可得到输出:

python 复制代码
tensor([[1, 2, 3],
        [4, 5, 6],
        [4, 5, 6],
        [3, 6, 8]])

这种操作实际上完成了多个 batch 间数据的合并,若想完成同个 batch 内数据的 cat,要保证第 0 个维度大小一致,即 batchsize 相等。

python 复制代码
import torch

v1 = torch.tensor([[1, 2, 3, 3], [4, 5, 6, 6]])  # 2*4
v2 = torch.tensor([[6, 6], [8, 8]])  # 1*3
torch.cat([v1, v2], dim=1)

可得到预期输出:

python 复制代码
tensor([[1, 2, 3, 3, 6, 6],
        [4, 5, 6, 6, 8, 8]])
相关推荐
星星也在雾里3 分钟前
Anaconda命令行配置Jupyter Notebook虚拟环境
python·jupyter
极光代码工作室3 分钟前
基于机器学习的信用卡欺诈检测系统设计
人工智能·python·深度学习·机器学习
quetalangtaosha5 分钟前
Anomaly Detection系列(CVPR2025 EG-MPC论文解读)
人工智能·深度学习·计算机视觉
迷藏4949 分钟前
**超融合架构下的Go语言实践:从零搭建高性能容器化微服务集群**在现代云原生时代,*
java·python·云原生·架构·golang
深度学习lover14 分钟前
<数据集>yolo 船舶识别<目标检测>
人工智能·python·yolo·目标检测·计算机视觉·船舶分类识别
测试秃头怪16 分钟前
Python+selenium搭建Web自动化测试框架
自动化测试·软件测试·python·selenium·测试工具·职场和发展·测试用例
Irene199116 分钟前
PyCharm 改字体大小
python·pycharm
昆曲之源_娄江河畔17 分钟前
婴儿版GPT
python·gpt·ai·transformer
无边风月-风之羽翼20 分钟前
omnilingual_asr在Nvidia Spark DGX中部署
python
蓝天守卫者联盟124 分钟前
烧结机一氧化碳治理厂家技术路线与市场格局分析
大数据·人工智能·python