深度学习2-pyTorch学习-张量基本操作

张量的索引

如果A是一个矩阵,在matlab中我们取出来A矩阵的第m行n列的元素,可以用A(m,n)。取出来第i行的元素可以用A(i,:),取出来第i列的元素,可以用A(:,i)。而对于pytorch中的张量,它的索引略有不同:

复制代码
import torch
tensor = torch.tensor([[1,2,3],[4,5,6]],dtype = torch.float32)
print("原始张量:\n",tensor)

#1. ** 索引和切片操作**
print("获取第一行:",tensor[0])
print("获取第一行第一列的元素:",tensor[0,0])
print("获取第二列的所有元素:",tensor[:,1])
print("获取第一行的另外一种方法:",tensor[0,:]) #pytorch中tensor的索引是从0开始的

打印出来的:

复制代码
原始张量:
 tensor([[1., 2., 3.],
        [4., 5., 6.]])
获取第一行: tensor([1., 2., 3.])
获取第一行第一列的元素: tensor(1.)
获取第二列的所有元素: tensor([2., 5.])
获取第一行的另外一种方法: tensor([1., 2., 3.])

可以发现,pytorch中的张量的索引默认是从0开始的,例如,我们取矩阵的第1行,就是tensor.0,:当然也可以不写冒号,那就是tensor0。其次我们可以发现,pytorch中对张量索引的格式是tensor\[\],采用的是方括号,而不是matlab中的圆括号。

复制代码
print("张量的最大值:",tensor.max())
print("张量的最大值:",torch.max(tensor))
print("张量的最小值:",tensor.min())
print("张量的均值:",tensor.mean())
maxId = torch.argmax(tensor)
print("张量的最大值的索引",maxId)
minId = torch.argmin(tensor)
print("张量的最小值索引:",minId)

max_indices = torch.where(tensor == tensor.max())
print(f"最大值位置 (行, 列): {max_indices}")

张量形状变换

复制代码
#2. ** 形状变换操作 **
print("\n [形状变换]")
reshaped = tensor.view(3,2)  # 改变张量形状为3X2,即3行两列
print("改变形状后的张量:\n",reshaped)
flattenedTensor = tensor.flatten() # 将张量展平成一维
print("展平后的张量:\n",flattenedTensor)

tensor.view(M,N)可以将张量变为M行N列。注意张量形状改变前后的总元素数量必须一样,否则会报错,例如原来的张量是2X3,那么可以改变为3X2,但是不能变为3X4,例如以下程序:

复制代码
print("\n [形状变换]")
reshaped = tensor.view(3,4)  # 改变张量形状
print("改变形状后的张量:\n",reshaped)

输出如下:

复制代码
reshaped = tensor.view(3,4)  # 改变张量形状为3X2,即3行两列
RuntimeError: shape '[3, 4]' is invalid for input of size 6

张量的基本运算

加减乘求和,这些基本运算就是所有运算的基础,操作起来没什么说的

复制代码
#3. ** 张量的数学运算 **
print("张量的数学运算:\n")
tensorAdd = tensor + 10;
print("张量加10:\n",tensorAdd)
tensorMul = tensor * 2 # 张量乘法
print("张量乘2:\n",tensorMul)
tensorDot = tensor.sum()
print("张量元素求和:",tensorDot)

pytorch的张量,默认的保存格式都是float型,例如,上面我们原始的张量的输入是整数,但最终打印出来的数都带着个小数点,这就在提示我们,它将数据作为浮点型进行了保存。

矩阵的乘法

复制代码
#4. ** 与其他张量的操作 **
print("\n[与其他张量操作]")
tensor2 = torch.ones(3,2)
print("另一个张量:\n",tensor2)
tensorDot = torch.matmul(tensor,tensor2)
print("张量的矩阵乘法:\n",tensorDot)

这个没什么好所的,就是矩阵乘法罢了。

条件判断

复制代码
# 5.** 条件判断和筛选 **
print("\n [条件判断和筛选]")
mask = tensor > 3 # 判断一个布尔掩码
print("大于3的元素的布尔掩码:\n", mask)
filterTensor = tensor[tensor > 3] # 筛选出符合条件的元素
print("大于3的元素:\n",filterTensor)
相关推荐
通信小呆呆4 天前
当算法有了“五感”:多模态数据融合如何向人体感官协同学习?
人工智能·学习·算法·机器学习·机器人
程序猿追4 天前
那个右下角的小数字怎么“卡”住我打字——我用 HarmonyOS 自己写了一个字数限制输入框
pytorch·华为·harmonyos
xiao5kou4chang6kai44 天前
MATLAB机器学习、深度学习--从数据预处理到模型训练
深度学习·机器学习·matlab·数据预处理
H__Rick4 天前
自动对焦学习-3
人工智能·学习·计算机视觉
Daisy Lee4 天前
量化学习-第1章-什么是量化金融
学习·金融·datawhale
renhongxia14 天前
世界模型作为AGI落地底层底座的作用
人工智能·深度学习·生成对抗网络·自然语言处理·知识图谱·agi
计算机科研狗@OUC4 天前
(cvpr26) AIMDepth: Asymmetric Image-Event Mamba for Monocular Depth Estimation
人工智能·深度学习·计算机视觉
Alsn864 天前
等待学习-学习目录:Docker 容器安全攻防
学习·安全·docker
YM52e4 天前
买菜计算器小应用 - HarmonyOS ArkUI 开发实战-PC版本
学习·华为·harmonyos·鸿蒙·鸿蒙系统
小雨下雨的雨4 天前
HarmonyOS ArkUI训练营入门-组件掌握系列-Animation 动画效果实现-PC版本
学习·华为·harmonyos·鸿蒙