【AI | pytorch】torch.view_as_real的使用

torch.view_as_real使用

python 复制代码
xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3)

这行代码涉及 PyTorch 张量的操作,逐步解析如下:

1. xq_ * freqs_cis

  • xq_freqs_cis 是两个 PyTorch 张量。
  • 假设 xq_ 是一个复数张量,形状为 (batch_size, ..., complex_dim),其中 complex_dim = 2(因为 PyTorch 中复数可以表示为两个实数的组合)。
  • freqs_cis 是一个与 xq_ 广播兼容的张量,也可能是复数。
  • * 是逐元素乘法。对于复数张量,乘法遵循复数乘法规则。

2. torch.view_as_real(...)

  • torch.view_as_real(tensor) 将复数张量转换为实数表示。对于形状为 (batch_size, ..., complex_dim) 的复数张量,其转换结果为 (batch_size, ..., 2),其中最后一个维度 2 表示实部和虚部。

例如:

python 复制代码
z = torch.tensor([1 + 2j, 3 + 4j], dtype=torch.complex64)
real_view = torch.view_as_real(z)
# real_view 的形状为 (2, 2)
# real_view 内容为 [[1.0, 2.0], [3.0, 4.0]],分别是实部和虚部

3. .flatten(3)

  • flatten(d) 方法会将张量的第 d 维及其之后的所有维度展平为一个维度。

假设 torch.view_as_real(xq_ * freqs_cis) 的输出形状是 (batch_size, dim1, dim2, 2)

  • .flatten(3) 会将第 3 维及之后的维度展平,结果是 (batch_size, dim1, dim2 * 2)

结合起来的解释

xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3) 的流程:

  1. xq_freqs_cis 做逐元素乘法,得到一个复数张量。
  2. 将结果转换为实数表示,形状最后增加一维用于存储实部和虚部。
  3. 将第 3 维及之后的维度展平,以减少维度数。

最终的 xq_out 是一个纯实数张量,其维度取决于输入张量的形状和广播规则。

复数相乘

在 PyTorch 中,xq_ * freqs_cis 的逐元素乘法是基于复数张量的复数乘法规则进行的。复数张量的每个元素由实部和虚部组成,计算时遵循标准复数乘法规则:

复数乘法公式

对于两个复数 ( z_1 = a + bi ) 和 ( z_2 = c + di ),它们的乘积是:

z_1 \\cdot z_2 = (a \\cdot c - b \\cdot d) + (a \\cdot d + b \\cdot c)i

其中:

  • ( a, b ) 是 ( z_1 ) 的实部和虚部。
  • ( c, d ) 是 ( z_2 ) 的实部和虚部。

PyTorch 的复数张量

在 PyTorch 中,复数张量的每个元素是复数,由实部和虚部组成。计算时,PyTorch 会根据复数的定义自动执行上述运算。

例如:

python 复制代码
xq_ = torch.tensor([1 + 2j, 3 + 4j], dtype=torch.complex64)
freqs_cis = torch.tensor([0.5 + 0.5j, 1 + 1j], dtype=torch.complex64)

result = xq_ * freqs_cis
print(result)
# 输出: tensor([ -0.5 + 1.5j, -1.0 + 7.0j], dtype=torch.complex64)

计算细节

逐元素计算的每一对元素如下:

  1. 对于第一个元素:

    (1 + 2i) \\cdot (0.5 + 0.5i) = (1 \\cdot 0.5 - 2 \\cdot 0.5) + (1 \\cdot 0.5 + 2 \\cdot 0.5)i = -0.5 + 1.5i

  2. 对于第二个元素:

    (3 + 4i) \\cdot (1 + 1i) = (3 \\cdot 1 - 4 \\cdot 1) + (3 \\cdot 1 + 4 \\cdot 1)i = -1 + 7i

广播规则

如果 xq_freqs_cis 的形状不同,PyTorch 会尝试广播它们,使得形状匹配。广播规则包括:

  1. 右对齐形状,从后向前匹配每个维度。
  2. 如果维度不一致且其中一个维度为 1,则扩展为另一个维度的大小。
  3. 如果两者维度都不匹配且不为 1,则广播失败。

例如:

python 复制代码
xq_ = torch.tensor([[1 + 2j, 3 + 4j]], dtype=torch.complex64)  # (1, 2)
freqs_cis = torch.tensor([0.5 + 0.5j], dtype=torch.complex64)  # (1,)

result = xq_ * freqs_cis  # 广播为 (1, 2)

总结

xq_ * freqs_cis 是基于复数逐元素相乘的操作:

  1. 每个对应元素按复数乘法规则计算。
  2. 如果形状不同,会应用广播规则,使张量形状匹配。
  3. 结果是一个新的复数张量,与广播后形状一致。
相关推荐
2401_8318249610 分钟前
使用Fabric自动化你的部署流程
jvm·数据库·python
njidf29 分钟前
Python日志记录(Logging)最佳实践
jvm·数据库·python
twc82929 分钟前
大模型生成 QA Pairs 提升 RAG 应用测试效率的实践
服务器·数据库·人工智能·windows·rag·大模型测试
@我漫长的孤独流浪29 分钟前
Python编程核心知识点速览
开发语言·数据库·python
宇擎智脑科技31 分钟前
A2A Python SDK 源码架构解读:一个请求是如何被处理的
人工智能·python·架构·a2a
2401_8512729931 分钟前
实战:用Python分析某电商销售数据
jvm·数据库·python
IT_陈寒32 分钟前
Redis缓存击穿:3个鲜为人知的防御策略,90%开发者都忽略了!
前端·人工智能·后端
vx_biyesheji000133 分钟前
Python 全国城市租房洞察系统 Django框架 Requests爬虫 可视化 房子 房源 大数据 大模型 计算机毕业设计源码(建议收藏)✅
爬虫·python·机器学习·django·flask·课程设计·旅游
code 小楊43 分钟前
yrb 1.5.0 正式发布:Python 极简国内下载加速与全景可视化终端体验!
开发语言·python
电商API&Tina1 小时前
【电商API接口】开发者一站式电商API接入说明
大数据·数据库·人工智能·云计算·json