昇思25天学习打卡营第1天|快速入门

昇思25天学习打卡营第1天|快速入门

基础介绍

本文用mindspore 的 api 快速实现一个简单的深度学习模型

处理数据集

python 复制代码
# Download data from open datasets
from download import download

url = "https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/" \
      "notebook/datasets/MNIST_Data.zip"
path = download(url, "./", kind="zip", replace=True)

下载经典MNIST数据集,目录结构如下

数据集中包含的数据列名为['image', 'label']

使用map操作对图像数据及标签进行变换处理,然后打包为大小为64的batch。

python 复制代码
def datapipe(dataset, batch_size):
    image_transforms = [
        vision.Rescale(1.0 / 255.0, 0),
        vision.Normalize(mean=(0.1307,), std=(0.3081,)),
        vision.HWC2CHW()
    ]
    label_transform = transforms.TypeCast(mindspore.int32)

    dataset = dataset.map(image_transforms, 'image')
    dataset = dataset.map(label_transform, 'label')
    dataset = dataset.batch(batch_size)
    return dataset

# Map vision transforms and batch dataset
train_dataset = datapipe(train_dataset, 64)
test_dataset = datapipe(test_dataset, 64)

对数据集进行迭代访问,查看数据和标签的shape和datatype

网络构建

mindspore.nn类是构建所有网络的基类,自定义网络时可以继承nn.Cell类,并重写__init__方法和construct方法。

python 复制代码
# Define model
class Network(nn.Cell):
    def __init__(self):
        super().__init__()
        self.flatten = nn.Flatten()
        self.dense_relu_sequential = nn.SequentialCell(
            nn.Dense(28*28, 512),
            nn.ReLU(),
            nn.Dense(512, 512),
            nn.ReLU(),
            nn.Dense(512, 10)
        )

    def construct(self, x):
        x = self.flatten(x)
        logits = self.dense_relu_sequential(x)
        return logits

model = Network()

使用nn.SequentialCell来快速组合构造一个神经网络模型。

nn.Dense为全连接层,其使用权重和偏差对输入进行线性变换。

nn.ReLU层给网络中加入非线性的激活函数,帮助神经网络学习各种复杂的特征。

模型训练

模型完整的训练过程(step)需要实现以下三步:

  • 正向计算:模型预测结果(logits),并与正确标签(label)求预测损失(loss)。
  • 反向传播:利用自动微分机制,自动求模型参数(parameters)对于loss的梯度(gradients)。
  • 参数优化:将梯度更新到参数上。

写好训练函数和测试函数,评估模型的性能。

迭代多轮(epoch)数据集,每一轮遍历训练集进行训练,结束后再测试集预测。打印每一轮的loss值和预测准确率(Accuracy),loss在不断下降Accuracy在提高,这就是训练过程。

预测推理

保存训练好的模型并加载模型参数

python 复制代码
# Save checkpoint
mindspore.save_checkpoint(model, "model.ckpt")
print("Saved Model to model.ckpt")

加载后的模型可以直接用于预测推理。

总结

重新开始写点学习记录,笔者作为23应届毕业生也是经历了当下就业环境的洗礼,从游戏开发转后端再到现在即将入职的银行科技岗,计算机本科生还是要多学点知识并不断充实自己,这篇文也算是一个新的起点吧,拥抱昇思,展望未来!

相关推荐
茯苓gao2 小时前
STM32G4 速度环开环,电流环闭环 IF模式建模
笔记·stm32·单片机·嵌入式硬件·学习
是誰萆微了承諾3 小时前
【golang学习笔记 gin 】1.2 redis 的使用
笔记·学习·golang
DKPT3 小时前
Java内存区域与内存溢出
java·开发语言·jvm·笔记·学习
aaaweiaaaaaa3 小时前
HTML和CSS学习
前端·css·学习·html
看海天一色听风起雨落4 小时前
Python学习之装饰器
开发语言·python·学习
speop5 小时前
llm的一点学习笔记
笔记·学习
非凡ghost6 小时前
FxSound:提升音频体验,让音乐更动听
前端·学习·音视频·生活·软件需求
ue星空6 小时前
月2期学习笔记
学习·游戏·ue5
萧邀人6 小时前
第二课、熟悉Cocos Creator 编辑器界面
学习
m0_571372827 小时前
嵌入式ARM架构学习2——汇编
arm开发·学习