Pytorch torch.roll函数介绍

torch.roll 是 PyTorch 中的一个函数,用于对输入张量的元素进行循环滚动操作。它可以将张量的元素在指定的维度上移动,超出边界的元素会循环回到另一侧。以下是关于 torch.roll 函数的详细介绍:

函数语法

复制代码
torch.roll(input, shifts, dims=None)

参数说明

  • input:必需参数,为输入的 PyTorch 张量,即需要进行循环滚动操作的张量。
  • shifts :表示元素滚动的位移量。可以是一个整数,此时所有指定维度都按照这个整数进行滚动;也可以是一个与 dims 长度相同的元组或列表,用于为每个指定维度分别指定滚动的位移量。正数表示元素向维度的末尾方向滚动,负数表示向维度的起始方向滚动。
  • dims:可选参数,指定要进行滚动操作的维度。可以是一个整数,表示对单一维度进行滚动;也可以是一个元组或列表,包含多个整数,用于指定对多个维度同时进行滚动。如果不指定该参数,则会将输入张量视为一维张量进行滚动。

返回值

返回一个新的张量,其元素是输入张量在指定维度上循环滚动后的结果。新张量的形状与输入张量相同。

使用示例

一维张量滚动
复制代码
import torch

# 创建一维张量
x = torch.tensor([1, 2, 3, 4, 5])
# 向右滚动 2 个位置
rolled_x = torch.roll(x, shifts=2)
print(rolled_x)  
# 输出: tensor([4, 5, 1, 2, 3])
二维张量在单个维度上滚动
复制代码
import torch

# 创建二维张量
x = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
# 在第 0 维(行)上向下滚动 1 个位置
rolled_x = torch.roll(x, shifts=1, dims=0)
print(rolled_x)
# 输出:
# tensor([[7, 8, 9],
#         [1, 2, 3],
#         [4, 5, 6]])
二维张量在多个维度上滚动
复制代码
import torch

# 创建二维张量
x = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
# 在第 0 维向下滚动 1 个位置,在第 1 维向右滚动 2 个位置
rolled_x = torch.roll(x, shifts=(1, 2), dims=(0, 1))
print(rolled_x)
# 输出:
# tensor([[8, 9, 7],
#         [2, 3, 1],
#         [5, 6, 4]])

总结

torch.roll 函数为在 PyTorch 中对张量元素进行循环滚动提供了方便的操作方式,可用于数据增强、信号处理等多种场景,通过灵活设置 shiftsdims 参数,可以实现不同维度和不同位移量的滚动操作。

相关推荐
LDG_AGI4 分钟前
【推荐系统】深度学习训练框架(九):推荐系统与LLM在Dataset、Tokenizer阶段的异同
人工智能·深度学习·算法·机器学习·推荐算法
智谱开放平台5 分钟前
让 AI 真正懂仓库:如何用 CLAUDE.md 将 Claude Code 的工作效率发挥到极致
人工智能·claude
糯米酒6 分钟前
不想使用docker部署n8n的看过来,你可以这样做
人工智能
roman_日积跬步-终至千里8 分钟前
【模式识别与机器学习(17)】聚类分析教程【2】:高级方法与离群点分析
人工智能·机器学习·支持向量机
后台开发者Ethan9 分钟前
py文件被初始化执行了2次
python
小殊小殊9 分钟前
重磅!DeepSeek发布V3.2系列模型!
论文阅读·人工智能·算法
a3158238069 分钟前
Linux部署Python Django工程和Node工程,使用宝塔面板
linux·服务器·python·django·node·strapi·宝塔面板
B站计算机毕业设计之家11 分钟前
机器学习:python智能电商推荐平台 大数据 spark(Django后端+Vue3前端+协同过滤 毕业设计/实战 源码)✅
大数据·python·spark·django·推荐算法·电商
丝斯201114 分钟前
AI学习笔记整理(19)—— AI核心技术(深度学习3)
人工智能·笔记·学习
自然语15 分钟前
深度学习时代结束了,2025年开始只剩下轮廓
数据结构·人工智能·深度学习·学习·算法