【AI | pytorch】torch.view_as_real的使用

torch.view_as_real使用

python 复制代码
xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3)

这行代码涉及 PyTorch 张量的操作,逐步解析如下:

1. xq_ * freqs_cis

  • xq_freqs_cis 是两个 PyTorch 张量。
  • 假设 xq_ 是一个复数张量,形状为 (batch_size, ..., complex_dim),其中 complex_dim = 2(因为 PyTorch 中复数可以表示为两个实数的组合)。
  • freqs_cis 是一个与 xq_ 广播兼容的张量,也可能是复数。
  • * 是逐元素乘法。对于复数张量,乘法遵循复数乘法规则。

2. torch.view_as_real(...)

  • torch.view_as_real(tensor) 将复数张量转换为实数表示。对于形状为 (batch_size, ..., complex_dim) 的复数张量,其转换结果为 (batch_size, ..., 2),其中最后一个维度 2 表示实部和虚部。

例如:

python 复制代码
z = torch.tensor([1 + 2j, 3 + 4j], dtype=torch.complex64)
real_view = torch.view_as_real(z)
# real_view 的形状为 (2, 2)
# real_view 内容为 [[1.0, 2.0], [3.0, 4.0]],分别是实部和虚部

3. .flatten(3)

  • flatten(d) 方法会将张量的第 d 维及其之后的所有维度展平为一个维度。

假设 torch.view_as_real(xq_ * freqs_cis) 的输出形状是 (batch_size, dim1, dim2, 2)

  • .flatten(3) 会将第 3 维及之后的维度展平,结果是 (batch_size, dim1, dim2 * 2)

结合起来的解释

xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3) 的流程:

  1. xq_freqs_cis 做逐元素乘法,得到一个复数张量。
  2. 将结果转换为实数表示,形状最后增加一维用于存储实部和虚部。
  3. 将第 3 维及之后的维度展平,以减少维度数。

最终的 xq_out 是一个纯实数张量,其维度取决于输入张量的形状和广播规则。

复数相乘

在 PyTorch 中,xq_ * freqs_cis 的逐元素乘法是基于复数张量的复数乘法规则进行的。复数张量的每个元素由实部和虚部组成,计算时遵循标准复数乘法规则:

复数乘法公式

对于两个复数 ( z_1 = a + bi ) 和 ( z_2 = c + di ),它们的乘积是:

z_1 \\cdot z_2 = (a \\cdot c - b \\cdot d) + (a \\cdot d + b \\cdot c)i

其中:

  • ( a, b ) 是 ( z_1 ) 的实部和虚部。
  • ( c, d ) 是 ( z_2 ) 的实部和虚部。

PyTorch 的复数张量

在 PyTorch 中,复数张量的每个元素是复数,由实部和虚部组成。计算时,PyTorch 会根据复数的定义自动执行上述运算。

例如:

python 复制代码
xq_ = torch.tensor([1 + 2j, 3 + 4j], dtype=torch.complex64)
freqs_cis = torch.tensor([0.5 + 0.5j, 1 + 1j], dtype=torch.complex64)

result = xq_ * freqs_cis
print(result)
# 输出: tensor([ -0.5 + 1.5j, -1.0 + 7.0j], dtype=torch.complex64)

计算细节

逐元素计算的每一对元素如下:

  1. 对于第一个元素:

    (1 + 2i) \\cdot (0.5 + 0.5i) = (1 \\cdot 0.5 - 2 \\cdot 0.5) + (1 \\cdot 0.5 + 2 \\cdot 0.5)i = -0.5 + 1.5i

  2. 对于第二个元素:

    (3 + 4i) \\cdot (1 + 1i) = (3 \\cdot 1 - 4 \\cdot 1) + (3 \\cdot 1 + 4 \\cdot 1)i = -1 + 7i

广播规则

如果 xq_freqs_cis 的形状不同,PyTorch 会尝试广播它们,使得形状匹配。广播规则包括:

  1. 右对齐形状,从后向前匹配每个维度。
  2. 如果维度不一致且其中一个维度为 1,则扩展为另一个维度的大小。
  3. 如果两者维度都不匹配且不为 1,则广播失败。

例如:

python 复制代码
xq_ = torch.tensor([[1 + 2j, 3 + 4j]], dtype=torch.complex64)  # (1, 2)
freqs_cis = torch.tensor([0.5 + 0.5j], dtype=torch.complex64)  # (1,)

result = xq_ * freqs_cis  # 广播为 (1, 2)

总结

xq_ * freqs_cis 是基于复数逐元素相乘的操作:

  1. 每个对应元素按复数乘法规则计算。
  2. 如果形状不同,会应用广播规则,使张量形状匹配。
  3. 结果是一个新的复数张量,与广播后形状一致。
相关推荐
七月shi人2 小时前
【AI编程工具IDE/CLI/插件专栏】-国外IDE与Cursor能力对比
ide·人工智能·ai编程·代码助手
橙 子_3 小时前
基于 Amazon Nova Sonic 和 MCP 构建语音交互 Agent
python
2zcode4 小时前
基于Matlab的深度学习智能行人检测与统计系统
人工智能·深度学习·目标跟踪
宇寒风暖4 小时前
Flask 框架全面详解
笔记·后端·python·学习·flask·知识
哪 吒5 小时前
【2025C卷】华为OD机试九日集训第3期 - 按算法分类,由易到难,提升编程能力和解题技巧
python·算法·华为od·华为od机试·2025c卷
weixin_464078075 小时前
机器学习sklearn:过滤
人工智能·机器学习·sklearn
weixin_464078075 小时前
机器学习sklearn:降维
人工智能·机器学习·sklearn
数据与人工智能律师5 小时前
智能合约漏洞导致的损失,法律责任应如何分配
大数据·网络·人工智能·算法·区块链
张艾拉 Fun AI Everyday5 小时前
小宿科技:AI Agent 的卖铲人
人工智能·aigc·创业创新·ai-native
zhongqu_3dnest5 小时前
三维火灾调查重建:科技赋能,探寻真相
人工智能