【AI | pytorch】torch.view_as_real的使用

torch.view_as_real使用

python 复制代码
xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3)

这行代码涉及 PyTorch 张量的操作,逐步解析如下:

1. xq_ * freqs_cis

  • xq_freqs_cis 是两个 PyTorch 张量。
  • 假设 xq_ 是一个复数张量,形状为 (batch_size, ..., complex_dim),其中 complex_dim = 2(因为 PyTorch 中复数可以表示为两个实数的组合)。
  • freqs_cis 是一个与 xq_ 广播兼容的张量,也可能是复数。
  • * 是逐元素乘法。对于复数张量,乘法遵循复数乘法规则。

2. torch.view_as_real(...)

  • torch.view_as_real(tensor) 将复数张量转换为实数表示。对于形状为 (batch_size, ..., complex_dim) 的复数张量,其转换结果为 (batch_size, ..., 2),其中最后一个维度 2 表示实部和虚部。

例如:

python 复制代码
z = torch.tensor([1 + 2j, 3 + 4j], dtype=torch.complex64)
real_view = torch.view_as_real(z)
# real_view 的形状为 (2, 2)
# real_view 内容为 [[1.0, 2.0], [3.0, 4.0]],分别是实部和虚部

3. .flatten(3)

  • flatten(d) 方法会将张量的第 d 维及其之后的所有维度展平为一个维度。

假设 torch.view_as_real(xq_ * freqs_cis) 的输出形状是 (batch_size, dim1, dim2, 2)

  • .flatten(3) 会将第 3 维及之后的维度展平,结果是 (batch_size, dim1, dim2 * 2)

结合起来的解释

xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3) 的流程:

  1. xq_freqs_cis 做逐元素乘法,得到一个复数张量。
  2. 将结果转换为实数表示,形状最后增加一维用于存储实部和虚部。
  3. 将第 3 维及之后的维度展平,以减少维度数。

最终的 xq_out 是一个纯实数张量,其维度取决于输入张量的形状和广播规则。

复数相乘

在 PyTorch 中,xq_ * freqs_cis 的逐元素乘法是基于复数张量的复数乘法规则进行的。复数张量的每个元素由实部和虚部组成,计算时遵循标准复数乘法规则:

复数乘法公式

对于两个复数 ( z_1 = a + bi ) 和 ( z_2 = c + di ),它们的乘积是:

z_1 \\cdot z_2 = (a \\cdot c - b \\cdot d) + (a \\cdot d + b \\cdot c)i

其中:

  • ( a, b ) 是 ( z_1 ) 的实部和虚部。
  • ( c, d ) 是 ( z_2 ) 的实部和虚部。

PyTorch 的复数张量

在 PyTorch 中,复数张量的每个元素是复数,由实部和虚部组成。计算时,PyTorch 会根据复数的定义自动执行上述运算。

例如:

python 复制代码
xq_ = torch.tensor([1 + 2j, 3 + 4j], dtype=torch.complex64)
freqs_cis = torch.tensor([0.5 + 0.5j, 1 + 1j], dtype=torch.complex64)

result = xq_ * freqs_cis
print(result)
# 输出: tensor([ -0.5 + 1.5j, -1.0 + 7.0j], dtype=torch.complex64)

计算细节

逐元素计算的每一对元素如下:

  1. 对于第一个元素:

    (1 + 2i) \\cdot (0.5 + 0.5i) = (1 \\cdot 0.5 - 2 \\cdot 0.5) + (1 \\cdot 0.5 + 2 \\cdot 0.5)i = -0.5 + 1.5i

  2. 对于第二个元素:

    (3 + 4i) \\cdot (1 + 1i) = (3 \\cdot 1 - 4 \\cdot 1) + (3 \\cdot 1 + 4 \\cdot 1)i = -1 + 7i

广播规则

如果 xq_freqs_cis 的形状不同,PyTorch 会尝试广播它们,使得形状匹配。广播规则包括:

  1. 右对齐形状,从后向前匹配每个维度。
  2. 如果维度不一致且其中一个维度为 1,则扩展为另一个维度的大小。
  3. 如果两者维度都不匹配且不为 1,则广播失败。

例如:

python 复制代码
xq_ = torch.tensor([[1 + 2j, 3 + 4j]], dtype=torch.complex64)  # (1, 2)
freqs_cis = torch.tensor([0.5 + 0.5j], dtype=torch.complex64)  # (1,)

result = xq_ * freqs_cis  # 广播为 (1, 2)

总结

xq_ * freqs_cis 是基于复数逐元素相乘的操作:

  1. 每个对应元素按复数乘法规则计算。
  2. 如果形状不同,会应用广播规则,使张量形状匹配。
  3. 结果是一个新的复数张量,与广播后形状一致。
相关推荐
BeerBear30 分钟前
【保姆级教程-从0开始开发MCP服务器】一、MCP学习压根没有你想象得那么难!.md
人工智能·mcp
小气小憩1 小时前
“暗战”百度搜索页:Monica悬浮球被“围剿”,一场AI Agent与传统巨头的流量攻防战
前端·人工智能
数据智能老司机1 小时前
精通 Python 设计模式——创建型设计模式
python·设计模式·架构
神经星星1 小时前
准确度提升400%!印度季风预测模型基于36个气象站点,实现城区尺度精细预报
人工智能
数据智能老司机2 小时前
精通 Python 设计模式——SOLID 原则
python·设计模式·架构
c8i3 小时前
django中的FBV 和 CBV
python·django
c8i3 小时前
python中的闭包和装饰器
python
IT_陈寒4 小时前
JavaScript 性能优化:5 个被低估的 V8 引擎技巧让你的代码快 200%
前端·人工智能·后端
Juchecar4 小时前
一文讲清 PyTorch 中反向传播(Backpropagation)的实现原理
人工智能
黎燃4 小时前
游戏NPC的智能行为设计:从规则驱动到强化学习的演进
人工智能