机器学习-pytorch1(持续更新)

上一节我们学习了机器学习的线性模型和非线性模型的机器学习基础知识,这一节主要将公式变为代码

代码编写网站:https://colab.research.google.com/drive

学习课程链接:ML 2022 Spring

1、Load Data(读取数据)

这需要用到pytorch里面的两个函数Dataset和Dataloader

python 复制代码
torch.utils.data.Dataset
torch.utils.data.DataLoader

Dataset:是用来存储数据样本和期望值

python 复制代码
dataset = MyDataset(file)

Dataloader:批量对数据进行分组,启用多处理

python 复制代码
dataloader = DataLoader(dataset, batch_size, shuffle=True)

// 其中对于shuffle的取值,True表示训练,false表示测试

关于Dataset和Dataloader的关系如下:

ML 2022 Spring为图片来源

我们读取完数据,是不是想知道我们的数据长什么样子呢?(我们称数据为Tensors)

首先,它可能是一个一维数据,比如一个音频、一个温度

其次,还可能是一个二维数据,比如一张二值图像

最后,还可能是一个三维数据,比如一个彩色的图像

又有问题了,我们怎么通过编程得到我们图像的大小?

可以使用pytorch里面的shape()函数

我们怎么通过编程创造我们的数据呢?

python 复制代码
eg:
x = torch.tensor([[1,-1],[-1,1]])
x = torch.from_numpy(np.array([[1,-1],[-1,1]]))
全0或全1数据
x = torch.zeros([2,2])    # 2*2的全0数据
x = torch.ones([1,2,5])    # 1*2*5的全1数据

其次,还支持矩阵的运算

python 复制代码
Addition:z = x + y
Subtraction:z = x - y
Power:y = x.pow(2)
Summation:y = x.sum()
Mean:y = x.mean()
维度转换:x = x.transpose(dim0,dim1)
消除维度:x = x.squeeze(dim)
增加维度:x = x.unsqueeze(dim)
组合:w = torch.cat([x,y,z],dim=1)

拥有不同的数据类型:

使用.to()可以切换到不同的设备:

python 复制代码
CPU: x = x.to('cpu')
GPU: x = x.to('cuda')

这里就又涉及到如何检查你的GPU了?可以使用以下语句检查你的计算机是否有GPU:

python 复制代码
torch.cuda.is_available()

如何计算梯度?

// 注意矩阵一定要使用小数点

2、Define Neural Network(训练和测试神经网络)

python 复制代码
torch.nn.Module

线性:

非线性:

Sigmoid Activation:nn.Sigmoid()

ReLU Activation:nn.ReLU()

下面我根据所学的知识构建我自己的神经网络:

3、Loss Function(损失函数)

python 复制代码
x = torch.nn.MSELoss    # 对于回归任务
x = torch.nn.CrossEntropyLoss etc.    # 对于分类任务
loss = x(model_output,expected_value)

4、Optimization Algorithm(优化)

python 复制代码
torch.optim

这是基于梯度的优化算法,不断调整参数,减少误差

比如:随机梯度下降(SGD)

python 复制代码
torch.optim.SGD(model.parameters(), lr, momentum = 0)

* 调用optimizer.zero_grad()重置模型参数的梯度。

*调用loss.backward()反向传播预测loss的梯度。

*调用optimizer.step()调整模型参数。

5、Entire Procedure(整个程序)

python 复制代码
import torch.utils.data as data
dataset = data.Dataset(file)              # 读取数据
tr_set = DataLoader(dataset,batch_size,shuffle=True)  # 对数据集进行分组
model = MyModel().to(device)              # 建立我的模型并且选择我的设备(cpu or gpu)
criterion = nn.MSELoss()                # 建立损失函数
optimizer = torch.optim.SGD(model.parameters(),0.1)   # 建立优化
# 训练
for epoch in range(n_epochs):             # 迭代数据
  model.train()                    # 训练模型
  for x, y in tr_set:               # 迭代数据集
    optimizer.zero_grad()              # 设置梯度为0
    x, y = x.to(device),y.to(device)       # 将数据移动到设备
    pred = model(x)                # 计算输出
    loss = criterion(pred,y)            # 计算损失函数
    loss.backward()                 # 计算反向梯度
    optimizer.model()                # 优化模型
# 验证
model.eval()                      # 将模型设置为评估模式
total_loss = 0          
for x,y in dv_set:                  # 对数据集进行迭代
  x,y = x.to(device),y.to(device)          # 将数据移动到涉笔
  with torch.no_grad():                # 不可迭代的计算
    pred = model(x)                # 计算输出
    loss = criterion(pred,y)           # 计算损失函数
  total_loss += loss.cpu().item()*len(x)      # 累加损失误差
  avg_loss = total_loss / len(dv_set.dataset)   # 计算平均损失
# 测试
model.eval()                       # 将模型设置为评估模式
preds = []
for x in dv_set:                   # 对数据集进行迭代
  x = x.to(device)                  # 将数据移动到涉笔
  with torch.no_grad():                # 不可迭代的计算
    pred = model(x)                # 计算输出
    preds.append(pred.cpu())             # 收集预测

// model.eval() :更改模型的行为

// with torch.no_grad() :防止对验证/测试数据进行意外训练

当我们训练完模型,也完成了测试,为了不使模型丢失,我们需要保存模型,pytorch也为我们提供了保存模型的方法。

保存模型:torch.save(model.state_dict(),path)

下次我们使用已经训练完成的模型,或者想继续训练,我们需要读取模型。

读取模型:ckpt = torch.load(path) model.load_state_dict(ckpt)

// 这只是我根据所听的课自己写的笔记,如果有什么错误欢迎指正!!!

相关推荐
撞南墙者2 分钟前
OpenCV自学系列(1)——简介和GUI特征操作
人工智能·opencv·计算机视觉
OCR_wintone4213 分钟前
易泊车牌识别相机,助力智慧工地建设
人工智能·数码相机·ocr
王哈哈^_^25 分钟前
【数据集】【YOLO】【VOC】目标检测数据集,查找数据集,yolo目标检测算法详细实战训练步骤!
人工智能·深度学习·算法·yolo·目标检测·计算机视觉·pyqt
一者仁心31 分钟前
【AI技术】PaddleSpeech
人工智能
是瑶瑶子啦39 分钟前
【深度学习】论文笔记:空间变换网络(Spatial Transformer Networks)
论文阅读·人工智能·深度学习·视觉检测·空间变换
EasyCVR43 分钟前
萤石设备视频接入平台EasyCVR多品牌摄像机视频平台海康ehome平台(ISUP)接入EasyCVR不在线如何排查?
运维·服务器·网络·人工智能·ffmpeg·音视频
柳鲲鹏1 小时前
OpenCV视频防抖源码及编译脚本
人工智能·opencv·计算机视觉
西柚小萌新1 小时前
8.机器学习--决策树
人工智能·决策树·机器学习
向阳12181 小时前
Bert快速入门
人工智能·python·自然语言处理·bert
jndingxin1 小时前
OpenCV视觉分析之目标跟踪(8)目标跟踪函数CamShift()使用
人工智能·opencv·目标跟踪