昇思25天学习打卡营第1天|快速入门

昇思25天学习打卡营第1天|快速入门

基础介绍

本文用mindspore 的 api 快速实现一个简单的深度学习模型

处理数据集

python 复制代码
# Download data from open datasets
from download import download

url = "https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/" \
      "notebook/datasets/MNIST_Data.zip"
path = download(url, "./", kind="zip", replace=True)

下载经典MNIST数据集,目录结构如下

数据集中包含的数据列名为['image', 'label']

使用map操作对图像数据及标签进行变换处理,然后打包为大小为64的batch。

python 复制代码
def datapipe(dataset, batch_size):
    image_transforms = [
        vision.Rescale(1.0 / 255.0, 0),
        vision.Normalize(mean=(0.1307,), std=(0.3081,)),
        vision.HWC2CHW()
    ]
    label_transform = transforms.TypeCast(mindspore.int32)

    dataset = dataset.map(image_transforms, 'image')
    dataset = dataset.map(label_transform, 'label')
    dataset = dataset.batch(batch_size)
    return dataset

# Map vision transforms and batch dataset
train_dataset = datapipe(train_dataset, 64)
test_dataset = datapipe(test_dataset, 64)

对数据集进行迭代访问,查看数据和标签的shape和datatype

网络构建

mindspore.nn类是构建所有网络的基类,自定义网络时可以继承nn.Cell类,并重写__init__方法和construct方法。

python 复制代码
# Define model
class Network(nn.Cell):
    def __init__(self):
        super().__init__()
        self.flatten = nn.Flatten()
        self.dense_relu_sequential = nn.SequentialCell(
            nn.Dense(28*28, 512),
            nn.ReLU(),
            nn.Dense(512, 512),
            nn.ReLU(),
            nn.Dense(512, 10)
        )

    def construct(self, x):
        x = self.flatten(x)
        logits = self.dense_relu_sequential(x)
        return logits

model = Network()

使用nn.SequentialCell来快速组合构造一个神经网络模型。

nn.Dense为全连接层,其使用权重和偏差对输入进行线性变换。

nn.ReLU层给网络中加入非线性的激活函数,帮助神经网络学习各种复杂的特征。

模型训练

模型完整的训练过程(step)需要实现以下三步:

  • 正向计算:模型预测结果(logits),并与正确标签(label)求预测损失(loss)。
  • 反向传播:利用自动微分机制,自动求模型参数(parameters)对于loss的梯度(gradients)。
  • 参数优化:将梯度更新到参数上。

写好训练函数和测试函数,评估模型的性能。

迭代多轮(epoch)数据集,每一轮遍历训练集进行训练,结束后再测试集预测。打印每一轮的loss值和预测准确率(Accuracy),loss在不断下降Accuracy在提高,这就是训练过程。

预测推理

保存训练好的模型并加载模型参数

python 复制代码
# Save checkpoint
mindspore.save_checkpoint(model, "model.ckpt")
print("Saved Model to model.ckpt")

加载后的模型可以直接用于预测推理。

总结

重新开始写点学习记录,笔者作为23应届毕业生也是经历了当下就业环境的洗礼,从游戏开发转后端再到现在即将入职的银行科技岗,计算机本科生还是要多学点知识并不断充实自己,这篇文也算是一个新的起点吧,拥抱昇思,展望未来!

相关推荐
QT 小鲜肉1 小时前
【QT/C++】Qt定时器QTimer类的实现方法详解(超详细)
开发语言·数据库·c++·笔记·qt·学习
Mr.Jessy2 小时前
Web APIs 学习第五天:日期对象与DOM节点
开发语言·前端·javascript·学习·html
存在morning2 小时前
【人工智能学习笔记 三】 AI教学之前端跨栈一:React整体分层架构
笔记·学习·架构
巫婆理发2222 小时前
评估指标+数据不匹配+贝叶斯最优误差(分析方差和偏差)+迁移学习+多任务学习+端到端深度学习
深度学习·学习·迁移学习
霜绛3 小时前
C#知识补充(二)——命名空间、泛型、委托和事件
开发语言·学习·unity·c#
好望角雾眠3 小时前
第四阶段C#通讯开发-6:Socket之UDP
开发语言·笔记·学习·udp·c#
_李小白4 小时前
【OPENGL ES 3.0 学习笔记】第十七天:模型矩阵、视图矩阵与投影矩阵
笔记·学习·矩阵
淮北4944 小时前
windows11配置wsl安装ubuntu20.04
windows·学习·ubuntu·wsl
霜绛4 小时前
C#知识补充(一)——ref和out、成员属性、万物之父和装箱拆箱、抽象类和抽象方法、接口
开发语言·笔记·学习·c#
2301_796512525 小时前
Rust编程学习 - 如何利用代数类型系统做错误处理的另外一大好处是可组合性(composability)
java·学习·rust