pytorch计算张量中三维向量的欧式距离

如果 X 是一个包含多个三维向量的张量,形状为 [b, n, 3],其中 b 是批次大小,n 是每个批次中的向量数量,那么可以使用类似的广播机制来计算同一批次内不同位置的三维向量之间的欧式距离。

以下是具体实现步骤:

  1. 扩展张量的维度 :需要将 X 的维度扩展,以便能够利用广播机制计算每对向量之间的差值。

  2. 计算差值并求平方和:计算向量之间的差值,并对差值的平方求和。

  3. 计算欧式距离:对平方和取平方根,得到欧式距离。

    import torch

    假设 X 是形状为 [b, n, 3] 的张量,b 是批次大小,n 是向量的数量

    b = 128
    n = 100
    X = torch.randn(b, n, 3) # 示例输入

    第一步:扩展维度

    X_expanded_1 = X.unsqueeze(2) # 形状为 [b, n, 1, 3]
    X_expanded_2 = X.unsqueeze(1) # 形状为 [b, 1, n, 3]

    第二步:计算每对向量之间的差值的平方和

    dX = X_expanded_1 - X_expanded_2 # 形状为 [b, n, n, 3]
    dX_squared_sum = torch.sum(dX**2, dim=3) # 形状为 [b, n, n]

    第三步:计算欧式距离

    distances = torch.sqrt(dX_squared_sum) # 形状为 [b, n, n]

    distances[k, i, j] 表示批次 k 中位置 i 和位置 j 之间的欧式距离

    print(distances)

解释:

  1. 扩展维度X.unsqueeze(2)X 的形状从 [b, n, 3] 变为 [b, n, 1, 3],而 X.unsqueeze(1) 将其形状变为 [b, 1, n, 3]。通过这种扩展,每个批次内的所有位置对可以使用广播机制进行差值计算。

  2. 计算差值并求平方和dX 是一个形状为 [b, n, n, 3] 的张量,表示每个批次内的每对位置之间的差值。torch.sum(dX**2, dim=3) 对最后一个维度(即三维坐标的维度)求和,得到每对位置之间的平方距离,形状为 [b, n, n]

  3. 计算欧式距离 :最后,使用 torch.sqrt 对平方距离取平方根,得到最终的欧式距离矩阵 distances,其形状为 [b, n, n],表示每个批次内所有位置对之间的欧式距离。

这个 distances 张量的形状为 [b, n, n],其中 distances[k, i, j] 表示批次 k 中位置 i 和位置 j 之间的欧式距离。

相关推荐
聆风吟º11 小时前
CANN runtime 实战指南:异构计算场景中运行时组件的部署、调优与扩展技巧
人工智能·神经网络·cann·异构计算
寻星探路12 小时前
【深度长文】万字攻克网络原理:从 HTTP 报文解构到 HTTPS 终极加密逻辑
java·开发语言·网络·python·http·ai·https
Codebee13 小时前
能力中心 (Agent SkillCenter):开启AI技能管理新时代
人工智能
聆风吟º14 小时前
CANN runtime 全链路拆解:AI 异构计算运行时的任务管理与功能适配技术路径
人工智能·深度学习·神经网络·cann
uesowys14 小时前
Apache Spark算法开发指导-One-vs-Rest classifier
人工智能·算法·spark
AI_567814 小时前
AWS EC2新手入门:6步带你从零启动实例
大数据·数据库·人工智能·机器学习·aws
User_芊芊君子14 小时前
CANN大模型推理加速引擎ascend-transformer-boost深度解析:毫秒级响应的Transformer优化方案
人工智能·深度学习·transformer
ValhallaCoder14 小时前
hot100-二叉树I
数据结构·python·算法·二叉树
智驱力人工智能15 小时前
小区高空抛物AI实时预警方案 筑牢社区头顶安全的实践 高空抛物检测 高空抛物监控安装教程 高空抛物误报率优化方案 高空抛物监控案例分享
人工智能·深度学习·opencv·算法·安全·yolo·边缘计算
qq_1601448715 小时前
亲测!2026年零基础学AI的入门干货,新手照做就能上手
人工智能