torch.cat 使用小节

torch.cat 使用小节

torch.cat 要求在所指定拼接维度之外的所有维度都要匹配,例如

python 复制代码
import torch
v1 = torch.tensor([[1, 2, 3], [4, 5, 6], [4, 5, 6]])  # 3*3
v2 = torch.tensor([[3, 6, 8]])  # 1*3
torch.cat([v1, v2], dim=1)

运行之后会报错

Sizes of tensors must match except in dimension 1. Expected size 3 but got size 1 for tensor number 1 in the list.

这就是因为这两个向量的第 0 个维度不相等,故无法完成 cat 操作,改成 dim=0 即可得到输出:

python 复制代码
tensor([[1, 2, 3],
        [4, 5, 6],
        [4, 5, 6],
        [3, 6, 8]])

这种操作实际上完成了多个 batch 间数据的合并,若想完成同个 batch 内数据的 cat,要保证第 0 个维度大小一致,即 batchsize 相等。

python 复制代码
import torch

v1 = torch.tensor([[1, 2, 3, 3], [4, 5, 6, 6]])  # 2*4
v2 = torch.tensor([[6, 6], [8, 8]])  # 1*3
torch.cat([v1, v2], dim=1)

可得到预期输出:

python 复制代码
tensor([[1, 2, 3, 3, 6, 6],
        [4, 5, 6, 6, 8, 8]])
相关推荐
2402_854808371 分钟前
CSS如何实现元素在容器内居中_利用margin-auto技巧
jvm·数据库·python
weixin_580614002 分钟前
html标签怎么表示用户输入_kbd标签键盘快捷键标注【介绍】
jvm·数据库·python
m0_716430073 分钟前
如何监控集群 interconnect_ping与traceroute验证心跳通畅.txt
jvm·数据库·python
m0_678485454 分钟前
如何通过 curl 调用 Go 标准库 RPC 服务(JSON-RPC 协议)
jvm·数据库·python
2401_8654396316 分钟前
HTML5中SVG原生动画标签Animate的基础用法
jvm·数据库·python
萝卜小白17 分钟前
算法实习day03-碎碎念
python·ai·实习
XY_墨莲伊19 分钟前
【实战项目】基于B/S结构Flask+Folium技术的出租车轨迹可视化分析系统(文末含完整源代码)
开发语言·后端·python·算法·机器学习·flask
Trisyp24 分钟前
使用 APScheduler 实现精细化的定时任务
python·apscheduler
z64943150828 分钟前
【Python开源-单目测距】单目无人机多视角测距:DJI RTK图像 → 地面目标3D坐标与距离,平均RE仅2.12%
python·计算机视觉·开源·无人机
Fleshy数模28 分钟前
PyQt5 登录界面开发全流程:从环境配置到可视化设计
开发语言·python·qt