mindspore快速入门回顾
-
导入mindspore包
-
处理数据集
- 下载mnist数据集
- 进行数据集预处理
- MnistDataset()方法
- train_dataset.get_col_names() 打印列名信息
- 使用create_tuple_iterator 或create_dict_iterator对数据集进行迭代访问
-
网络构建
- mindspore.nn: 构建所有网络的基类
- 用的层有
- Flatten
- Dense
- ReLU
-
模型训练
- 正向计算
- logits:预测结果
- label:正确标签
- loss:预测损失
- 反向传播
- parameters:模型参数
- grandients:loss梯度
- 参数优化:将梯度更新到参数上。
- 步骤:
- 定义正向计算函数
- 使用value_and_grad通过函数变换获得梯度计算函数。
- 定义训练函数,使用set_train设置为训练模式,执行正向计算、反向传播和参数优化。
- 正向计算
-
保存模型
-
加载模型