旋转位置编码 (2)

旋转位置编码常见的torch函数

python 复制代码
# 导入必要的库
import torch

# 定义一个简单的函数来演示torch.view_as_complex和torch.view_as_real
def demo_view_as_complex_real():
    # 创建一个形状为 (2, 3, 4, 2) 的张量,最后一个维度表示复数的实部和虚部
    x = torch.randn(2, 3, 4, 2)
    
    # 使用torch.view_as_complex将最后两个维度转换为复数
    x_complex = torch.view_as_complex(x)
    print("Original tensor shape:", x.shape)
    print("After view_as_complex shape:", x_complex.shape)
    
    # 使用torch.view_as_real将复数张量转换回实部和虚部
    x_real = torch.view_as_real(x_complex)
    print("After view_as_real shape:", x_real.shape)
    
    # 检查转换是否可逆
    print("Are the original and reconstructed tensors equal?", torch.allclose(x, x_real))

demo_view_as_complex_real()
复制代码
Original tensor shape: torch.Size([2, 3, 4, 2])
After view_as_complex shape: torch.Size([2, 3, 4])
After view_as_real shape: torch.Size([2, 3, 4, 2])
Are the original and reconstructed tensors equal? True
python 复制代码
# 定义一个简单的函数来演示torch.reshape和torch.flatten
def demo_reshape_flatten():
    # 创建一个形状为 (2, 3, 4) 的张量
    x = torch.randn(2, 3, 4)
    
    # 使用torch.reshape改变张量的形状
    x_reshaped = x.reshape(2, -1)  # 将最后两个维度展平
    print("Original tensor shape:", x.shape)
    print("After reshape shape:", x_reshaped.shape)
    
    # 使用torch.flatten展平张量
    x_flattened = x.flatten(start_dim=1)  # 从第1维度开始展平
    print("After flatten shape:", x_flattened.shape)
demo_reshape_flatten()
复制代码
Original tensor shape: torch.Size([2, 3, 4])
After reshape shape: torch.Size([2, 12])
After flatten shape: torch.Size([2, 12])
python 复制代码
xq = torch.randint(0, 10, (3, 4, 6))  # 创建一个形状为 (3, 4, 6) 的张量
print("原始形状:", xq.shape)
# 执行操作
result = xq.float().reshape(*xq.shape[:-1], -1, 2)  # 转换为 float 并重塑形状
print("转换后的形状:", result.shape)
复制代码
原始形状: torch.Size([3, 4, 6])
转换后的形状: torch.Size([3, 4, 3, 2])
python 复制代码
# 定义一个简单的函数来演示torch.view
def demo_view():
    # 创建一个形状为 (2, 3, 4) 的张量
    x = torch.randn(2, 3, 4, 4)
    
    # 使用torch.view改变张量的形状
    x_viewed = x.view(2, -1)  # 将最后两个维度展平
    print("Original tensor shape:", x.shape)
    print("After view shape:", x_viewed.shape)

demo_view()
复制代码
Original tensor shape: torch.Size([2, 3, 4, 4])
After view shape: torch.Size([2, 48])
python 复制代码
# 定义一个简单的函数来演示torch.type_as
def demo_type_as():
    # 创建两个不同类型的张量
    x = torch.randn(2, 3, dtype=torch.float32)
    y = torch.randn(2, 3, dtype=torch.float64)
    
    # 使用torch.type_as将x的类型转换为与y相同
    x_converted = x.type_as(y)
    print("Original x dtype:", x.dtype)
    print("After type_as dtype:", x_converted.dtype)

demo_type_as()
复制代码
Original x dtype: torch.float32
After type_as dtype: torch.float64
python 复制代码
import torch

def unite_shape(pos_cis, x):
    print(f"输入 x 的形状: {x.shape}")
    print(f"输入 pos_cis 的形状: {pos_cis.shape}")

    ndim = x.ndim  # 获取输入张量 x 的维度数
    print(f"x 的维度数 (ndim): {ndim}")

    assert 0 <= 1 < ndim  # 确保 x 至少有两个维度
    print("断言 0 <= 1 < ndim 通过")

    assert pos_cis.shape == (x.shape[1], x.shape[-1])  # 检查 pos_cis 的形状是否匹配
    print("断言 pos_cis.shape == (x.shape[1], x.shape[-1]) 通过")

    shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]  # 构建新的形状
    print(f"构建的新形状 (shape): {shape}")

    result = pos_cis.view(*shape)  # 将 pos_cis 重塑为新的形状
    print(f"重塑后的 pos_cis 形状: {result.shape}")

    return result
# 示例数据
x = torch.randn(10, 20, 30)  # 创建一个形状为 (10, 20, 30) 的张量
pos_cis = torch.randn(20, 30)  # 创建一个形状为 (20, 30) 的张量

# 调用函数
output = unite_shape(pos_cis, x)
复制代码
输入 x 的形状: torch.Size([10, 20, 30])
输入 pos_cis 的形状: torch.Size([20, 30])
x 的维度数 (ndim): 3
断言 0 <= 1 < ndim 通过
断言 pos_cis.shape == (x.shape[1], x.shape[-1]) 通过
构建的新形状 (shape): [1, 20, 30]
重塑后的 pos_cis 形状: torch.Size([1, 20, 30])

代码解释

  1. demo_view_as_complex_real 函数

    • 该函数展示了如何使用 torch.view_as_complex 将张量的最后两个维度转换为复数,并使用 torch.view_as_real 将其转换回实部和虚部。
    • 通过 torch.allclose 检查转换是否可逆。
  2. demo_reshape_flatten 函数

    • 该函数展示了如何使用 torch.reshapetorch.flatten 来改变张量的形状。
    • reshape 可以将张量重新调整为指定的形状,而 flatten 则将张量展平为指定的维度。
  3. xq 张量的操作

    • 该部分代码展示了如何将一个形状为 (3, 4, 6) 的张量转换为 float 类型,并重塑为 (3, 4, 3, 2) 的形状。
  4. demo_view 函数

    • 该函数展示了如何使用 torch.view 来改变张量的形状。view 类似于 reshape,但它要求张量在内存中是连续的。
  5. demo_type_as 函数

    • 该函数展示了如何使用 torch.type_as 将一个张量的数据类型转换为与另一个张量相同。
  6. unite_shape 函数

    • 该函数展示了如何根据输入张量 x 的形状来重塑 pos_cis 张量。
    • 通过 assert 语句确保输入张量的形状符合预期,并使用 view 方法重塑 pos_cis 的形状。

这些代码片段展示了 PyTorch 中常用的张量操作,包括形状变换、数据类型转换以及复数张量的处理。

相关推荐
Json____7 分钟前
使用python的 FastApi框架开发图书管理系统-前后端分离项目分享
开发语言·python·fastapi·图书管理系统·图书·项目练习
安思派Anspire13 分钟前
LangGraph + MCP + Ollama:构建强大代理 AI 的关键(二)
人工智能·后端·python
站大爷IP17 分钟前
Python文件与目录比较全攻略:从基础操作到性能优化
python
ahead~1 小时前
【大模型入门】访问GPT_API实战案例
人工智能·python·gpt·大语言模型llm
我爱一条柴ya2 小时前
【AI大模型】神经网络反向传播:核心原理与完整实现
人工智能·深度学习·神经网络·ai·ai编程
慕婉03072 小时前
深度学习概述
人工智能·深度学习
大模型真好玩2 小时前
准确率飙升!GraphRAG如何利用知识图谱提升RAG答案质量(额外篇)——大规模文本数据下GraphRAG实战
人工智能·python·mcp
19892 小时前
【零基础学AI】第30讲:生成对抗网络(GAN)实战 - 手写数字生成
人工智能·python·深度学习·神经网络·机器学习·生成对抗网络·近邻算法
applebomb2 小时前
没合适的组合wheel包,就自行编译flash_attn吧
python·ubuntu·attention·flash
神经星星2 小时前
新加坡国立大学基于多维度EHR数据实现细粒度患者队列建模,住院时间预测准确率提升16.3%
人工智能·深度学习·机器学习