torch.view_as_real使用
python
xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3)
这行代码涉及 PyTorch 张量的操作,逐步解析如下:
1. xq_ * freqs_cis
xq_
和freqs_cis
是两个 PyTorch 张量。- 假设
xq_
是一个复数张量,形状为(batch_size, ..., complex_dim)
,其中complex_dim = 2
(因为 PyTorch 中复数可以表示为两个实数的组合)。 freqs_cis
是一个与xq_
广播兼容的张量,也可能是复数。*
是逐元素乘法。对于复数张量,乘法遵循复数乘法规则。
2. torch.view_as_real(...)
torch.view_as_real(tensor)
将复数张量转换为实数表示。对于形状为(batch_size, ..., complex_dim)
的复数张量,其转换结果为(batch_size, ..., 2)
,其中最后一个维度2
表示实部和虚部。
例如:
python
z = torch.tensor([1 + 2j, 3 + 4j], dtype=torch.complex64)
real_view = torch.view_as_real(z)
# real_view 的形状为 (2, 2)
# real_view 内容为 [[1.0, 2.0], [3.0, 4.0]],分别是实部和虚部
3. .flatten(3)
flatten(d)
方法会将张量的第d
维及其之后的所有维度展平为一个维度。
假设 torch.view_as_real(xq_ * freqs_cis)
的输出形状是 (batch_size, dim1, dim2, 2)
:
.flatten(3)
会将第3
维及之后的维度展平,结果是(batch_size, dim1, dim2 * 2)
。
结合起来的解释
xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3)
的流程:
- 对
xq_
和freqs_cis
做逐元素乘法,得到一个复数张量。 - 将结果转换为实数表示,形状最后增加一维用于存储实部和虚部。
- 将第 3 维及之后的维度展平,以减少维度数。
最终的 xq_out
是一个纯实数张量,其维度取决于输入张量的形状和广播规则。
复数相乘
在 PyTorch 中,xq_ * freqs_cis
的逐元素乘法是基于复数张量的复数乘法规则进行的。复数张量的每个元素由实部和虚部组成,计算时遵循标准复数乘法规则:
复数乘法公式
对于两个复数 ( z_1 = a + bi ) 和 ( z_2 = c + di ),它们的乘积是:
[
z_1 \cdot z_2 = (a \cdot c - b \cdot d) + (a \cdot d + b \cdot c)i
]
其中:
- ( a, b ) 是 ( z_1 ) 的实部和虚部。
- ( c, d ) 是 ( z_2 ) 的实部和虚部。
PyTorch 的复数张量
在 PyTorch 中,复数张量的每个元素是复数,由实部和虚部组成。计算时,PyTorch 会根据复数的定义自动执行上述运算。
例如:
python
xq_ = torch.tensor([1 + 2j, 3 + 4j], dtype=torch.complex64)
freqs_cis = torch.tensor([0.5 + 0.5j, 1 + 1j], dtype=torch.complex64)
result = xq_ * freqs_cis
print(result)
# 输出: tensor([ -0.5 + 1.5j, -1.0 + 7.0j], dtype=torch.complex64)
计算细节
逐元素计算的每一对元素如下:
- 对于第一个元素:
[
(1 + 2i) \cdot (0.5 + 0.5i) = (1 \cdot 0.5 - 2 \cdot 0.5) + (1 \cdot 0.5 + 2 \cdot 0.5)i = -0.5 + 1.5i
] - 对于第二个元素:
[
(3 + 4i) \cdot (1 + 1i) = (3 \cdot 1 - 4 \cdot 1) + (3 \cdot 1 + 4 \cdot 1)i = -1 + 7i
]
广播规则
如果 xq_
和 freqs_cis
的形状不同,PyTorch 会尝试广播它们,使得形状匹配。广播规则包括:
- 右对齐形状,从后向前匹配每个维度。
- 如果维度不一致且其中一个维度为
1
,则扩展为另一个维度的大小。 - 如果两者维度都不匹配且不为
1
,则广播失败。
例如:
python
xq_ = torch.tensor([[1 + 2j, 3 + 4j]], dtype=torch.complex64) # (1, 2)
freqs_cis = torch.tensor([0.5 + 0.5j], dtype=torch.complex64) # (1,)
result = xq_ * freqs_cis # 广播为 (1, 2)
总结
xq_ * freqs_cis
是基于复数逐元素相乘的操作:
- 每个对应元素按复数乘法规则计算。
- 如果形状不同,会应用广播规则,使张量形状匹配。
- 结果是一个新的复数张量,与广播后形状一致。