【AI | pytorch】torch.view_as_real的使用

torch.view_as_real使用

python 复制代码
xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3)

这行代码涉及 PyTorch 张量的操作,逐步解析如下:

1. xq_ * freqs_cis

  • xq_freqs_cis 是两个 PyTorch 张量。
  • 假设 xq_ 是一个复数张量,形状为 (batch_size, ..., complex_dim),其中 complex_dim = 2(因为 PyTorch 中复数可以表示为两个实数的组合)。
  • freqs_cis 是一个与 xq_ 广播兼容的张量,也可能是复数。
  • * 是逐元素乘法。对于复数张量,乘法遵循复数乘法规则。

2. torch.view_as_real(...)

  • torch.view_as_real(tensor) 将复数张量转换为实数表示。对于形状为 (batch_size, ..., complex_dim) 的复数张量,其转换结果为 (batch_size, ..., 2),其中最后一个维度 2 表示实部和虚部。

例如:

python 复制代码
z = torch.tensor([1 + 2j, 3 + 4j], dtype=torch.complex64)
real_view = torch.view_as_real(z)
# real_view 的形状为 (2, 2)
# real_view 内容为 [[1.0, 2.0], [3.0, 4.0]],分别是实部和虚部

3. .flatten(3)

  • flatten(d) 方法会将张量的第 d 维及其之后的所有维度展平为一个维度。

假设 torch.view_as_real(xq_ * freqs_cis) 的输出形状是 (batch_size, dim1, dim2, 2)

  • .flatten(3) 会将第 3 维及之后的维度展平,结果是 (batch_size, dim1, dim2 * 2)

结合起来的解释

xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3) 的流程:

  1. xq_freqs_cis 做逐元素乘法,得到一个复数张量。
  2. 将结果转换为实数表示,形状最后增加一维用于存储实部和虚部。
  3. 将第 3 维及之后的维度展平,以减少维度数。

最终的 xq_out 是一个纯实数张量,其维度取决于输入张量的形状和广播规则。

复数相乘

在 PyTorch 中,xq_ * freqs_cis 的逐元素乘法是基于复数张量的复数乘法规则进行的。复数张量的每个元素由实部和虚部组成,计算时遵循标准复数乘法规则:

复数乘法公式

对于两个复数 ( z_1 = a + bi ) 和 ( z_2 = c + di ),它们的乘积是:

z_1 \\cdot z_2 = (a \\cdot c - b \\cdot d) + (a \\cdot d + b \\cdot c)i

其中:

  • ( a, b ) 是 ( z_1 ) 的实部和虚部。
  • ( c, d ) 是 ( z_2 ) 的实部和虚部。

PyTorch 的复数张量

在 PyTorch 中,复数张量的每个元素是复数,由实部和虚部组成。计算时,PyTorch 会根据复数的定义自动执行上述运算。

例如:

python 复制代码
xq_ = torch.tensor([1 + 2j, 3 + 4j], dtype=torch.complex64)
freqs_cis = torch.tensor([0.5 + 0.5j, 1 + 1j], dtype=torch.complex64)

result = xq_ * freqs_cis
print(result)
# 输出: tensor([ -0.5 + 1.5j, -1.0 + 7.0j], dtype=torch.complex64)

计算细节

逐元素计算的每一对元素如下:

  1. 对于第一个元素:

    (1 + 2i) \\cdot (0.5 + 0.5i) = (1 \\cdot 0.5 - 2 \\cdot 0.5) + (1 \\cdot 0.5 + 2 \\cdot 0.5)i = -0.5 + 1.5i

  2. 对于第二个元素:

    (3 + 4i) \\cdot (1 + 1i) = (3 \\cdot 1 - 4 \\cdot 1) + (3 \\cdot 1 + 4 \\cdot 1)i = -1 + 7i

广播规则

如果 xq_freqs_cis 的形状不同,PyTorch 会尝试广播它们,使得形状匹配。广播规则包括:

  1. 右对齐形状,从后向前匹配每个维度。
  2. 如果维度不一致且其中一个维度为 1,则扩展为另一个维度的大小。
  3. 如果两者维度都不匹配且不为 1,则广播失败。

例如:

python 复制代码
xq_ = torch.tensor([[1 + 2j, 3 + 4j]], dtype=torch.complex64)  # (1, 2)
freqs_cis = torch.tensor([0.5 + 0.5j], dtype=torch.complex64)  # (1,)

result = xq_ * freqs_cis  # 广播为 (1, 2)

总结

xq_ * freqs_cis 是基于复数逐元素相乘的操作:

  1. 每个对应元素按复数乘法规则计算。
  2. 如果形状不同,会应用广播规则,使张量形状匹配。
  3. 结果是一个新的复数张量,与广播后形状一致。
相关推荐
Tech Synapse17 分钟前
端到端自动驾驶系统实战指南:从Comma.ai架构到PyTorch部署
人工智能·pytorch·自动驾驶·carla·end-to-end
红衣小蛇妖21 分钟前
Python基础学习-Day30
开发语言·python·学习
珂朵莉MM22 分钟前
2023 睿抗机器人开发者大赛CAIP-编程技能赛-本科组(国赛) 解题报告 | 珂学家
人工智能·算法·职场和发展·深度优先·图论
闭月之泪舞23 分钟前
OpenCv高阶(十五)——EigenFace人脸识别
人工智能·opencv·计算机视觉
CodeCraft Studio32 分钟前
PDF处理控件Aspose.PDF教程:以编程方式将 PDF 导出为 JPG
java·python·pdf·.net
于归pro44 分钟前
Python环境管理工具深度指南:pip、Poetry、uv、Conda
python·pip·uv
追光天使1 小时前
如何利用 Conda 安装 Pytorch 教程 ?
人工智能·pytorch·conda
鸭鸭鸭进京赶烤1 小时前
第九届电子信息技术与计算机工程国际学术会议(EITCE 2025)
人工智能·计算机视觉·ai·云计算·aigc·mybatis·制造
LabVIEW开发1 小时前
LabVIEW下AI开发
人工智能·labview
视觉&物联智能1 小时前
【杂谈】-智领安全新篇:人工智能驱动现代物理安全防护体系
人工智能·深度学习·安全·aigc·agi