【AI | pytorch】torch.view_as_real的使用

torch.view_as_real使用

python 复制代码
xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3)

这行代码涉及 PyTorch 张量的操作,逐步解析如下:

1. xq_ * freqs_cis

  • xq_freqs_cis 是两个 PyTorch 张量。
  • 假设 xq_ 是一个复数张量,形状为 (batch_size, ..., complex_dim),其中 complex_dim = 2(因为 PyTorch 中复数可以表示为两个实数的组合)。
  • freqs_cis 是一个与 xq_ 广播兼容的张量,也可能是复数。
  • * 是逐元素乘法。对于复数张量,乘法遵循复数乘法规则。

2. torch.view_as_real(...)

  • torch.view_as_real(tensor) 将复数张量转换为实数表示。对于形状为 (batch_size, ..., complex_dim) 的复数张量,其转换结果为 (batch_size, ..., 2),其中最后一个维度 2 表示实部和虚部。

例如:

python 复制代码
z = torch.tensor([1 + 2j, 3 + 4j], dtype=torch.complex64)
real_view = torch.view_as_real(z)
# real_view 的形状为 (2, 2)
# real_view 内容为 [[1.0, 2.0], [3.0, 4.0]],分别是实部和虚部

3. .flatten(3)

  • flatten(d) 方法会将张量的第 d 维及其之后的所有维度展平为一个维度。

假设 torch.view_as_real(xq_ * freqs_cis) 的输出形状是 (batch_size, dim1, dim2, 2)

  • .flatten(3) 会将第 3 维及之后的维度展平,结果是 (batch_size, dim1, dim2 * 2)

结合起来的解释

xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3) 的流程:

  1. xq_freqs_cis 做逐元素乘法,得到一个复数张量。
  2. 将结果转换为实数表示,形状最后增加一维用于存储实部和虚部。
  3. 将第 3 维及之后的维度展平,以减少维度数。

最终的 xq_out 是一个纯实数张量,其维度取决于输入张量的形状和广播规则。

复数相乘

在 PyTorch 中,xq_ * freqs_cis 的逐元素乘法是基于复数张量的复数乘法规则进行的。复数张量的每个元素由实部和虚部组成,计算时遵循标准复数乘法规则:

复数乘法公式

对于两个复数 ( z_1 = a + bi ) 和 ( z_2 = c + di ),它们的乘积是:

z_1 \\cdot z_2 = (a \\cdot c - b \\cdot d) + (a \\cdot d + b \\cdot c)i

其中:

  • ( a, b ) 是 ( z_1 ) 的实部和虚部。
  • ( c, d ) 是 ( z_2 ) 的实部和虚部。

PyTorch 的复数张量

在 PyTorch 中,复数张量的每个元素是复数,由实部和虚部组成。计算时,PyTorch 会根据复数的定义自动执行上述运算。

例如:

python 复制代码
xq_ = torch.tensor([1 + 2j, 3 + 4j], dtype=torch.complex64)
freqs_cis = torch.tensor([0.5 + 0.5j, 1 + 1j], dtype=torch.complex64)

result = xq_ * freqs_cis
print(result)
# 输出: tensor([ -0.5 + 1.5j, -1.0 + 7.0j], dtype=torch.complex64)

计算细节

逐元素计算的每一对元素如下:

  1. 对于第一个元素:

    (1 + 2i) \\cdot (0.5 + 0.5i) = (1 \\cdot 0.5 - 2 \\cdot 0.5) + (1 \\cdot 0.5 + 2 \\cdot 0.5)i = -0.5 + 1.5i

  2. 对于第二个元素:

    (3 + 4i) \\cdot (1 + 1i) = (3 \\cdot 1 - 4 \\cdot 1) + (3 \\cdot 1 + 4 \\cdot 1)i = -1 + 7i

广播规则

如果 xq_freqs_cis 的形状不同,PyTorch 会尝试广播它们,使得形状匹配。广播规则包括:

  1. 右对齐形状,从后向前匹配每个维度。
  2. 如果维度不一致且其中一个维度为 1,则扩展为另一个维度的大小。
  3. 如果两者维度都不匹配且不为 1,则广播失败。

例如:

python 复制代码
xq_ = torch.tensor([[1 + 2j, 3 + 4j]], dtype=torch.complex64)  # (1, 2)
freqs_cis = torch.tensor([0.5 + 0.5j], dtype=torch.complex64)  # (1,)

result = xq_ * freqs_cis  # 广播为 (1, 2)

总结

xq_ * freqs_cis 是基于复数逐元素相乘的操作:

  1. 每个对应元素按复数乘法规则计算。
  2. 如果形状不同,会应用广播规则,使张量形状匹配。
  3. 结果是一个新的复数张量,与广播后形状一致。
相关推荐
绒绒毛毛雨33 分钟前
On the Plasticity and Stability for Post-Training Large Language Models
人工智能·机器学习·语言模型
ZTLJQ8 小时前
序列化的艺术:Python JSON处理完全解析
开发语言·python·json
H5css�海秀8 小时前
今天是自学大模型的第一天(sanjose)
后端·python·node.js·php
SuniaWang8 小时前
《Spring AI + 大模型全栈实战》学习手册系列 · 专题六:《Vue3 前端开发实战:打造企业级 RAG 问答界面》
java·前端·人工智能·spring boot·后端·spring·架构
阿贵---8 小时前
使用XGBoost赢得Kaggle比赛
jvm·数据库·python
无敌昊哥战神8 小时前
【LeetCode 257】二叉树的所有路径(回溯法/深度优先遍历)- Python/C/C++详细题解
c语言·c++·python·leetcode·深度优先
IDZSY04309 小时前
AI社交平台进阶指南:如何用AI社交提升工作学习效率
人工智能·学习
七七powerful9 小时前
运维养龙虾--AI 驱动的架构图革命:draw.io MCP 让运维画图效率提升 10 倍,使用codebuddy实战
运维·人工智能·draw.io
水星梦月9 小时前
大白话讲解AI/LLM核心概念
人工智能
温九味闻醉10 小时前
关于腾讯广告算法大赛2025项目分析1 - dataset.py
人工智能·算法·机器学习