pytorch升级打怪(二)

张量

简介

张量是一种专门的数据结构,与数组和矩阵非常相似。在PyTorch中,我们使用张量来编码模型的输入和输出,以及模型的参数。

张量与NumPy的ndarrays相似,只是张量可以在GPU或其他硬件加速器上运行。事实上,张量和NumPy数组通常可以共享相同的底层内存,无需复制数据(请参阅带有NumPy的桥接)。张量也针对自动分化进行了优化(我们稍后将在Autograd部分看到更多信息)。如果您熟悉ndarrays,您就可以在家使用Tensor API。如果没有,请跟随!

初始化张量

python 复制代码
import torch
import numpy as np

#直接来自数据
#张量可以直接从数据中创建。数据类型被自动推断。

data = [[1, 2],[3, 4]]
x_data = torch.tensor(data)

#从NumPy数组
#张量可以从NumPy数组创建(反之亦然-请参阅带有NumPy的桥接)。

np_array = np.array(data)
x_np = torch.from_numpy(np_array)

#来自另一个张量:
#新张量保留参数张量的属性(形状、数据类型),除非明确覆盖。

x_ones = torch.ones_like(x_data) # retains the properties of x_data
print(f"Ones Tensor: \n {x_ones} \n")

x_rand = torch.rand_like(x_data, dtype=torch.float) # overrides the datatype of x_data
print(f"Random Tensor: \n {x_rand} \n")
#随机或恒定值:
#shape是张量维度的元组。在下面的函数中,它决定了输出张量的维度。

shape = (2,3,)
rand_tensor = torch.rand(shape)
ones_tensor = torch.ones(shape)
zeros_tensor = torch.zeros(shape)

print(f"Random Tensor: \n {rand_tensor} \n")
print(f"Ones Tensor: \n {ones_tensor} \n")
print(f"Zeros Tensor: \n {zeros_tensor}")

张量的属性

张量属性描述它们的形状、数据类型以及存储它们的设备。

python 复制代码
tensor = torch.rand(3,4)

print(f"Shape of tensor: {tensor.shape}")
print(f"Datatype of tensor: {tensor.dtype}")
print(f"Device tensor is stored on: {tensor.device}")
shell 复制代码
Shape of tensor: torch.Size([3, 4])
Datatype of tensor: torch.float32
Device tensor is stored on: cpu

张量的操作

这里全面描述了100多个张量运算,包括算术、线性代数、矩阵操作(转位、索引、切片)、采样等。

这些操作中的每一个都可以在GPU上运行(通常速度比CPU快)。如果您使用的是Colab,请转到运行时>更改运行时类型>GPU来分配GPU。

默认情况下,张量是在CPU上创建的。我们需要使用.to方法(在检查GPU可用性后)明确地将张量移动到GPU。请记住,在设备之间复制大张量在时间和内存方面可能很昂贵!

python 复制代码
# We move our tensor to the GPU if available
if torch.cuda.is_available():
    tensor = tensor.to("cuda")

标准numpy样索引和切片

python 复制代码
tensor = torch.ones(4, 4)
print(f"First row: {tensor[0]}")
print(f"First column: {tensor[:, 0]}")
print(f"Last column: {tensor[..., -1]}")
tensor[:,1] = 0
print(tensor)
shell 复制代码
First row: tensor([1., 1., 1., 1.])
First column: tensor([1., 1., 1., 1.])
Last column: tensor([1., 1., 1., 1.])
tensor([[1., 0., 1., 1.],
        [1., 0., 1., 1.],
        [1., 0., 1., 1.],
        [1., 0., 1., 1.]])

算术运算

python 复制代码
# This computes the matrix multiplication between two tensors. y1, y2, y3 will have the same value
# ``tensor.T`` returns the transpose of a tensor
y1 = tensor @ tensor.T
y2 = tensor.matmul(tensor.T)

y3 = torch.rand_like(y1)
torch.matmul(tensor, tensor.T, out=y3)


# This computes the element-wise product. z1, z2, z3 will have the same value
z1 = tensor * tensor
z2 = tensor.mul(tensor)

z3 = torch.rand_like(tensor)
torch.mul(tensor, tensor, out=z3)
shell 复制代码
tensor([[1., 0., 1., 1.],
        [1., 0., 1., 1.],
        [1., 0., 1., 1.],
        [1., 0., 1., 1.]])

单元素张量 如果您有一个单元素张量,例如通过将张量的所有值聚合到一个值中,您可以使用item()将其转换为Python数值:

python 复制代码
agg = tensor.sum()
agg_item = agg.item()
print(agg_item, type(agg_item))
shell 复制代码
12.0 <class 'float'>

将结果存储到操作数中的就地操作称为就地操作。它们用一个_后缀表示。例如:x.copy_(y)x.t_()将更改x。

python 复制代码
print(f"{tensor} \n")
tensor.add_(5)
print(tensor)
shell 复制代码
tensor([[1., 0., 1., 1.],
        [1., 0., 1., 1.],
        [1., 0., 1., 1.],
        [1., 0., 1., 1.]])

tensor([[6., 5., 6., 6.],
        [6., 5., 6., 6.],
        [6., 5., 6., 6.],
        [6., 5., 6., 6.]])

备注:就地操作可以节省一些内存,但在计算衍生品时可能会有问题,因为历史会立即丢失。因此,不鼓励使用它们。

与NumPy的桥梁

cpu和NumPy阵列上的张量可以共享其底层内存位置,更改一个将改变另一个

张量到NumPy数组

python 复制代码
t = torch.ones(5)
print(f"t: {t}")
n = t.numpy()
print(f"n: {n}")
shell 复制代码
t: tensor([1., 1., 1., 1., 1.])
n: [1. 1. 1. 1. 1.]

张量的变化反映在NumPy数组中。

python 复制代码
t.add_(1)
print(f"t: {t}")
print(f"n: {n}")
shell 复制代码
t: tensor([2., 2., 2., 2., 2.])
n: [2. 2. 2. 2. 2.]

NumPy阵列到张量

python 复制代码
n = np.ones(5)
t = torch.from_numpy(n)

NumPy数组的变化反映在张量中。

python 复制代码
np.add(n, 1, out=n)
print(f"t: {t}")
print(f"n: {n}")
shell 复制代码
t: tensor([2., 2., 2., 2., 2.], dtype=torch.float64)
n: [2. 2. 2. 2. 2.]
相关推荐
njxiejing16 小时前
Numpy一维、二维、三维数组切片实例
开发语言·python·numpy
兰亭妙微17 小时前
用户体验的真正边界在哪里?对的 “认知负荷” 设计思考
人工智能·ux
13631676419侯17 小时前
智慧物流与供应链追踪
人工智能·物联网
TomCode先生17 小时前
MES 离散制造核心流程详解(含关键动作、角色与异常处理)
人工智能·制造·mes
zd20057217 小时前
AI辅助数据分析和学习了没?
人工智能·学习
johnny23317 小时前
强化学习RL
人工智能
乌恩大侠17 小时前
无线网络规划与优化方式的根本性变革
人工智能·usrp
放羊郎17 小时前
基于萤火虫+Gmapping、分层+A*优化的导航方案
人工智能·slam·建图·激光slam
王哈哈^_^17 小时前
【数据集+完整源码】水稻病害数据集,yolov8水稻病害检测数据集 6715 张,目标检测水稻识别算法实战训推教程
人工智能·算法·yolo·目标检测·计算机视觉·视觉检测·毕业设计
lskisme17 小时前
springboot maven导入本地jar包
开发语言·python·pycharm