PyTorch快速入门

文章目录

前言

你好,我是醉墨居士,今天分享一下PyTorch的基本使用的快速入门教程,希望能够帮助各位快速掌握PyTorch的使用

简介

PyTorch 是一个开源的深度学习框架,由 Facebook 的人工智能研究团队(FAIR)开发。它在学术界和工业界都被广泛使用,为深度学习的研究和应用提供了强大的支持

软件包导入

本教程中我们使用torch和numpy这两个包

如果是ubuntu系统,可以看这个教程配置环境:https://blog.csdn.net/qq_67733273/article/details/144787375

py 复制代码
import torch
import numpy as np

创建张量

  • form list
py 复制代码
print(torch.tensor([1, 2]), torch.tensor([[1, 2], [3, 4]]))
bash 复制代码
tensor([1, 2]) tensor([[1, 2],
        [3, 4]])
  • form numpy
py 复制代码
print(torch.from_numpy(np.array([1, 2])))
bash 复制代码
tensor([1, 2])
  • 全0填充
py 复制代码
print(torch.zeros(2, 3))
bash 复制代码
tensor([[0., 0., 0.],
        [0., 0., 0.]])
  • 全1填充
py 复制代码
print(torch.ones(2, 3))
bash 复制代码
tensor([[1., 1., 1.],
        [1., 1., 1.]])
  • 自定义数值全部填充
py 复制代码
print(torch.full([2, 3], 6))
bash 复制代码
tensor([[6, 6, 6],
        [6, 6, 6]])
  • 对角矩阵
py 复制代码
print(torch.eye(3, 3))
bash 复制代码
tensor([[1., 0., 0.],
        [0., 1., 0.],
        [0., 0., 1.]])
  • 未初始化
py 复制代码
print(torch.empty(2, 3))
bash 复制代码
tensor([[5.1981e-06, 0.0000e+00, 0.0000e+00],
        [0.0000e+00, 1.5134e-21, 4.2892e-41]])
  • 标准正态分布
py 复制代码
print(torch.randn(2, 3))
bash 复制代码
tensor([[ 0.0711,  0.2728, -0.4199],
        [ 1.0625,  0.9611, -2.0447]])
  • 0 - 1浮点均匀分布
py 复制代码
print(torch.rand(2, 3))
bash 复制代码
tensor([[0.2974, 0.6033, 0.7126],
        [0.1599, 0.5974, 0.8983]])
  • 1 - 9整数均匀分布
py 复制代码
print(torch.randint(1, 10, [2, 3]))
bash 复制代码
tensor([[1, 2, 8],
        [9, 3, 4]])
  • 递增数列, 0 ~ 9,步长是2
py 复制代码
print(torch.arange(0, 10, 2))
bash 复制代码
tensor([0, 2, 4, 6, 8])
  • 等差数列,0 ~ 10, 共计4个数
py 复制代码
print(torch.linspace(0, 10, 4))
bash 复制代码
tensor([ 0.0000,  3.3333,  6.6667, 10.0000])
  • 指定类型创建
py 复制代码
print(torch.FloatTensor([1, 2]), torch.LongTensor([1, 2]))
bash 复制代码
tensor([1., 2.]) tensor([1, 2])

类型操作

  • 获取数据类型
py 复制代码
print(torch.randn(2, 3).dtype)
bash 复制代码
torch.float32
  • 转换数据类型
py 复制代码
print(torch.randn(2, 3).to(torch.float64).dtype)
bash 复制代码
torch.float64
  • 张量形状
py 复制代码
print(torch.randn(2, 3).shape)
bash 复制代码
torch.Size([2, 3])

索引

  • 测试数据

4张图片,3通道,28*28像素

py 复制代码
a = torch.randn(4, 3, 28, 28)
print(a.shape)
复制代码
torch.Size([4, 3, 28, 28])

直接索引

  • 查看第0张图片
py 复制代码
print(a[0].shape)
复制代码
torch.Size([3, 28, 28])
  • 查看第0张图片的第0个通道
py 复制代码
print(a[0, 0].shape)
bash 复制代码
torch.Size([28, 28])
  • 查看第0张图片的第0个通道的第2行
py 复制代码
print(a[0, 0, 2].shape)
bash 复制代码
torch.Size([28])
  • 查看第0张图片的第0个通道的第2行第4列
py 复制代码
print(a[0, 0, 2, 4].shape)
bash 复制代码
torch.Size([])

切片索引

  • 查看0 - 1张图片, 其中 :2 表示 0 - 1
py 复制代码
print(a[:2].shape)
bash 复制代码
torch.Size([2, 3, 28, 28])
  • 查看0 - 1张图片的0 - 1通道
py 复制代码
print(a[:2, :2].shape)
bash 复制代码
torch.Size([2, 2, 28, 28])
  • 查看0 - 1张图片的所有通道的倒数5 - 正数24行, 其中 : 表示所有
py 复制代码
print(a[:2, :, -5:25].shape)
bash 复制代码
torch.Size([2, 3, 2, 28])
  • 有间隔的索引,其中 ::2 表示每次间隔2个取值
py 复制代码
print(a[:, :, :, ::2].shape)
bash 复制代码
torch.Size([4, 3, 28, 14])
  • 查看所有图片的第0 - 1列, 其中 ... 表示省略
py 复制代码
print(a[..., :2].shape)
bash 复制代码
torch.Size([4, 3, 28, 2])
  • 查看所有图片的第0 - 1行
py 复制代码
print(a[..., :2, :].shape)
bash 复制代码
torch.Size([4, 3, 2, 28])
  • 查看第1张图片
py 复制代码
print(a[1, ...].shape)
bash 复制代码
torch.Size([3, 28, 28])

维度变换

  • 测试数据

4张图片, 单通道, 28*28像素

py 复制代码
a = torch.randn(4, 1, 28, 28)
print(a.shape)
bash 复制代码
torch.Size([4, 1, 28, 28])

  • 转换为4, 784维度
py 复制代码
print(a.reshape(4, 784).shape)
bash 复制代码
torch.Size([4, 784])
  • 转换成4, 28, 28维度
py 复制代码
print(a.reshape(4, 28, 28).shape)
bash 复制代码
torch.Size([4, 28, 28])
  • view的基本用法和reshape一致
py 复制代码
print(a.view(4, 784).shape)
bash 复制代码
torch.Size([4, 784])

增加维度

  • 测试数据

2*2的tensor

py 复制代码
a = torch.randn(2, 2)
print(a.shape)
bash 复制代码
torch.Size([2, 2])

  • 插入维度到第0维
py 复制代码
print(a.unsqueeze(0).shape)
bash 复制代码
torch.Size([1, 2, 2])
  • 插入维度在倒数第1维
py 复制代码
print(a.unsqueeze(-1).shape)
bash 复制代码
torch.Size([2, 2, 1])

删除维度

  • 测试数据

122*1的tensor

py 复制代码
a = torch.randn(1, 2, 2, 1)
print(a.shape)
bash 复制代码
torch.Size([1, 2, 2, 1])

  • 删除第0维
py 复制代码
print(a.squeeze(0).shape)
bash 复制代码
torch.Size([2, 2, 1])
  • 删除倒数第1维
py 复制代码
print(a.squeeze(-1).shape)
bash 复制代码
torch.Size([1, 2, 2])
  • 删除所有为1的维度
py 复制代码
print(a.squeeze().shape)
bash 复制代码
torch.Size([2, 2])

维度重复

  • 分别在第1个维度和第2个维度重复2和3次
py 复制代码
print(torch.randn(2, 2).repeat(2, 3).shape)
bash 复制代码
torch.Size([4, 6])

维度交换

  • t转置,只能操作2维tensor
py 复制代码
print(torch.randn(1, 2).t().shape)
bash 复制代码
torch.Size([2, 1])
  • 维度交换,指定交换的维度,只能两两交换
py 复制代码
print(torch.randn(1, 2, 3).transpose(0, 1).shape)
bash 复制代码
  • 维度交换,输入维度的顺序
py 复制代码
print(torch.rand(1, 2, 3).permute(2, 1, 0).shape)
bash 复制代码
torch.Size([2, 1, 3])

broadcast

  • 测试数据

3个张量

py 复制代码
a = torch.randn(2, 3)
b = torch.randn(1, 3)
c = torch.randn(1)
print(a.shape)
print(b.shape)
print(c.shape)
bash 复制代码
torch.Size([2, 3])
torch.Size([1, 3])
torch.Size([1])

  • 自动boradcast
py 复制代码
print((a + b).shape)
print((a + c).shape)
bash 复制代码
torch.Size([2, 3])
torch.Size([2, 3])
  • 手动boradcast
py 复制代码
print(b.expand_as(a).shape)
print(c.expand_as(a).shape)
bash 复制代码
torch.Size([2, 3])
torch.Size([2, 3])

合并张量

  • 拼接,dim=0是指定要拼接的维度
py 复制代码
a = torch.rand(4, 32, 8)
b = torch.rand(5, 32, 8)
print(a.shape)
print(b.shape)
print(torch.cat([a, b], dim=0).shape)
bash 复制代码
torch.Size([4, 32, 8])
torch.Size([5, 32, 8])
torch.Size([9, 32, 8])
  • 组合,创建一个新的维度,用于区分组合后的两个tensor
py 复制代码
a = torch.rand(4, 32, 8)
b = torch.rand(4, 32, 8)
print(a.shape)
print(b.shape)
print(torch.stack([a, b], dim=0).shape)
bash 复制代码
torch.Size([4, 32, 8])
torch.Size([4, 32, 8])
torch.Size([2, 4, 32, 8])

拆分张量

  • 测试数据

1个张量

py 复制代码
a = torch.rand(4, 32, 8)
print(a.shape)
bash 复制代码
torch.Size([4, 32, 8])

  • split拆分,在0维度上拆分,每2个元素1拆
py 复制代码
b, c = a.split(2, dim=0)
print(b.shape)
print(c.shape)
bash 复制代码
torch.Size([2, 32, 8])
torch.Size([2, 32, 8])
  • split拆分,在0维度上拆分,拆分后长度分别为1,2,1
py 复制代码
b, c, d = a.split([1, 2, 1], dim=0)
print(b.shape)
print(c.shape)
print(d.shape)
bash 复制代码
torch.Size([1, 32, 8])
torch.Size([2, 32, 8])
torch.Size([1, 32, 8])
  • chunk拆分,在0维度上拆分,拆成2个
py 复制代码
b, c = a.chunk(2, dim=0)
print(b.shape)
print(c.shape)
bash 复制代码
torch.Size([2, 32, 8])
torch.Size([2, 32, 8])

运算

测试数据

2个张量

py 复制代码
a = torch.FloatTensor([[0, 1.1, 2.2], [3.3, 4.4, 5.5]])
b = torch.FloatTensor([0, 1.1, 2.2])
print(a.shape)
print(b.shape)
bash 复制代码
torch.Size([2, 3])
torch.Size([3])
  • 矩阵乘法
py 复制代码
print(a @ b)
print(a.matmul(b))
bash 复制代码
tensor([ 6.0500, 16.9400])
tensor([ 6.0500, 16.9400])
  • 四则运算,因为两个tensor的维度不同,会进行自动boradcast,然后计算
py 复制代码
print(a + b)
print(a - b)
print(a * b)
print(a / b)
bash 复制代码
tensor([[0.0000, 2.2000, 4.4000],
        [3.3000, 5.5000, 7.7000]])
tensor([[0.0000, 0.0000, 0.0000],
        [3.3000, 3.3000, 3.3000]])
tensor([[ 0.0000,  1.2100,  4.8400],
        [ 0.0000,  4.8400, 12.1000]])
tensor([[   nan, 1.0000, 1.0000],
        [   inf, 4.0000, 2.5000]])
  • 求指数
py 复制代码
print(a**2)
bash 复制代码
tensor([[ 0.0000,  1.2100,  4.8400],
        [10.8900, 19.3600, 30.2500]])
  • 开根号
py 复制代码
print(a**0.5)
bash 复制代码
tensor([[0.0000, 1.0488, 1.4832],
        [1.8166, 2.0976, 2.3452]])
  • 求e的n次方
py 复制代码
print(a.exp())
bash 复制代码
tensor([[  1.0000,   3.0042,   9.0250],
        [ 27.1126,  81.4509, 244.6919]])
  • 以e为底,求对数
py 复制代码
print(a.log())
bash 复制代码
tensor([[  -inf, 0.0953, 0.7885],
        [1.1939, 1.4816, 1.7047]])
  • 以2为底,求对数
py 复制代码
print(a.log2())
bash 复制代码
tensor([[  -inf, 0.1375, 1.1375],
        [1.7225, 2.1375, 2.4594]])
  • 大于
py 复制代码
print(a > b)
bash 复制代码
tensor([[False, False, False],
        [ True,  True,  True]])
  • 大于等于
py 复制代码
print(a >= b)
bash 复制代码
tensor([[True, True, True],
        [True, True, True]])
  • 小于
py 复制代码
print(a < b)
bash 复制代码
tensor([[False, False, False],
        [False, False, False]])
  • 小于等于
py 复制代码
print(a <= b)
bash 复制代码
tensor([[ True,  True,  True],
        [False, False, False]])
  • 等于
py 复制代码
print(a == b)
bash 复制代码
tensor([[ True,  True,  True],
        [False, False, False]])
  • 不等于
py 复制代码
print(a != b)
bash 复制代码
tensor([[False, False, False],
        [ True,  True,  True]])
  • 限制数据的上下限
py 复制代码
print(a.clamp(2, 4))
bash 复制代码
tensor([[2.0000, 2.0000, 2.2000],
        [3.3000, 4.0000, 4.0000]])
  • 向下取整
py 复制代码
print(a.floor())
bash 复制代码
tensor([[0., 1., 2.],
        [3., 4., 5.]])
  • 向上取整
py 复制代码
print(a.ceil())
bash 复制代码
tensor([[0., 2., 3.],
        [4., 5., 6.]])
  • 四舍五入
py 复制代码
print(a.round())
bash 复制代码
tensor([[0., 1., 2.],
        [3., 4., 6.]])
  • 求最小
py 复制代码
print(a.min())
bash 复制代码
tensor(0.)
  • 求最大
py 复制代码
print(a.max())
bash 复制代码
tensor(5.5000)
  • 求平均
py 复制代码
print(a.mean())
bash 复制代码
tensor(2.7500)
  • 求积
py 复制代码
print(a.prod())
bash 复制代码
tensor(0.)
  • 求和
py 复制代码
print(a.sum())
bash 复制代码
tensor(16.5000)
  • 求最大值下标
py 复制代码
print(a.argmax())
bash 复制代码
tensor(5)
  • 求最小值下标
py 复制代码
print(a.argmin())
bash 复制代码
tensor(0)
  • 分维度求最大
py 复制代码
print(a.max(dim=0))
bash 复制代码
torch.return_types.max(
values=tensor([3.3000, 4.4000, 5.5000]),
indices=tensor([1, 1, 1]))
  • 分维度求最大值下标
py 复制代码
print(a.argmax(dim=0))
bash 复制代码
tensor([1, 1, 1])
  • 求1范数
py 复制代码
print(a.norm(1))
print(a.norm(1, dim=0))
bash 复制代码
tensor(16.5000)
tensor([3.3000, 5.5000, 7.7000])
  • 求2范数
py 复制代码
print(a.norm(2))
print(a.norm(2, dim=0))
bash 复制代码
tensor(8.1578)
tensor([3.3000, 4.5354, 5.9237])
  • 求前2个最小值
py 复制代码
print(a.topk(2, dim=1, largest=False))
bash 复制代码
torch.return_types.topk(
values=tensor([[0.0000, 1.1000],
        [3.3000, 4.4000]]),
indices=tensor([[0, 1],
        [0, 1]]))
  • 求第2个小的值
py 复制代码
a.kthvalue(2, dim=1)
bash 复制代码
torch.return_types.kthvalue(
values=tensor([1.1000, 4.4000]),
indices=tensor([1, 1]))

最后

感谢您的阅读,我是醉墨居士,希望对你有所帮助,也希望你能够使用PyTorch为你的未来开创更多可能性

相关推荐
小袁拒绝摆烂6 分钟前
OpenCV-python灰度变化和直方图修正类型
python·opencv·计算机视觉
gogoMark3 小时前
口播视频怎么剪!利用AI提高口播视频剪辑效率并增强”网感”
人工智能·音视频
Dxy12393102163 小时前
Python 条件语句详解
开发语言·python
龙泉寺天下行走3 小时前
Python 翻译词典小程序
python·oracle·小程序
2201_754918413 小时前
OpenCV 特征检测全面解析与实战应用
人工智能·opencv·计算机视觉
践行见远4 小时前
django之视图
python·django·drf
love530love5 小时前
Windows避坑部署CosyVoice多语言大语言模型
人工智能·windows·python·语言模型·自然语言处理·pycharm
勇闯逆流河5 小时前
【数据结构】堆
c语言·数据结构·算法
985小水博一枚呀5 小时前
【AI大模型学习路线】第二阶段之RAG基础与架构——第七章(【项目实战】基于RAG的PDF文档助手)技术方案与架构设计?
人工智能·学习·语言模型·架构·大模型
pystraf5 小时前
LG P9844 [ICPC 2021 Nanjing R] Paimon Segment Tree Solution
数据结构·c++·算法·线段树·洛谷