241130_昇思MindSpore函数式自动微分

241130_昇思MindSpore函数式自动微分

函数式自动微分是Mindspore学习框架所特有的,更偏向于数学计算的习惯。这里也是和pytorch差距最大的部分,具体体现在训练部分的代码,MindSpore是把各个梯度计算、损失函数计算

在这幅图中,右边这个就是函数式编程,首先先自己定义一个loss函数,最后一行使用grad(),把loss function传进去,因为传进去的是一个函数,做的是一个函数闭包,所以返回的还是一个函数。

深度学习的计算流程

首先是先正向计算,得到一个logits,然后会计算这个logits和真实的targets的误差loss,然后反向传播backwards,得到梯度,然后再送到优化器里面去更新网络权重。

MindSpore是整图计算,将模型的前向、反向、梯度更新过程全部视为一个完整的计算图,这样就有效提高了执行速度,但是带来的弊端就是代码不好书写,和pytorch差异较大,但到了2.0就得到了有效解决。

MindSpore后来使用面向对象+函数式混合使用,1-2和pytorch一样,后面3-6和原来的函数式编程一样。

1、用类构建神经网络

2、实例化Network对象

3、Network+Loss直接构造正向函数

4、函数变换,获得梯度计算(反向传播)函数

5、构造训练过程函数

6、调用函数进行训练

具体实现看以下代码:

pytorch的实现

看右边pytorch的代码就可以很直观的看出来上面说的计算流程的几步。

正向计算得到logits

python 复制代码
logits=net(inputs)

计算logits和target的误差loss

python 复制代码
loss=loss_fn(logits,targets)

反向传播backward

python 复制代码
loss.backward()

送到优化器里去更新网络权重

python 复制代码
optimizer.step()

以上是pytorch对应步骤的代码

MindSpore2.x的实现

相对来说MindSpore没有那么直观,但逻辑上都是一样的。

首先,MindSpore这边正向计算需要定义一个函数方法,里面写了loss的计算

python 复制代码
def forword_fn(inputs,targets):
    logits=net(inputs)
    loss=loss_fn(logits,targets)
    return loss

然后把整个方法作为参数传入value_and_grad方法,做一个函数闭包,然后得到一个同样的grad_fn方法(第三个参数往往是net的所有可训练参数,看上图mindspore这边第三行代码也能看出来)。

python 复制代码
grad_fn=value_and_grad(forward_fn,None,optim.parameters)

训练的每个step我们也需要定义一个方法,里面把刚才得到的grad_fn方法拿过来,一次得到损失loss和梯度grads,然后直接把梯度传进优化器进行更新权重。

PYTHON 复制代码
def train_step(inputs,targets):
    loss,grads=grad_fn(inputs,targets)
    optimizer(grads)
    return loss

在实际epoch循环中,我们只需要读入data和其targets,然后直接传入单步训练的train_step就可以了

其实要说函数式微分的话,封装的第一个forward_fn已经是函数式微分了,train_step反而不像函数式微分,就是一个单纯的计算,没有涉及到函数闭包,函数套函数这样的写法,那为什么还要封装呢。

这里就涉及到MindSpore想实现加速的问题,实现整图下发,避免来回传输数据受到的带宽限制,具体实现只需要给train_step上方添加一行修饰器代码

python 复制代码
@ms.jit

官方的教学示例notebook

接下来我们将定义几个变量x,y,w,b,z

x是输入,y是真实targets,w是权重,b是偏置,z是我们计算得到的label

w和b是我们要优化的东西

x置全1矩阵,y给全0矩阵,w和b随机给个初始值

然后可以构建一个loss计算的fuction

然后就是使用函数式微分,函数套函数,计算梯度

简单来说,就是你如果计算loss的哪个fuction返回多个参数的话,传入计算梯度的方法,计算结果就会出现偏差,这时候我们就要调用接口去实现stop_gradient(也不算手动实现吧。就是输出的时候嵌套一下)

我们在loss里面返回的z,现在看起来也没用,即返回了,又要排除他的影响,那为什么还要返回呢,数据又拿不出来。其实不对,数据是可以拿出来的。

调用grad_fn的时候我们就可以拿到这个数据,相当于z只是到grad_fn中转了一圈,没干什么。

接下来主要就是一个没有封装train_step的操作方法,主要逻辑上文也说过的,这里也不再赘述

候我们就可以拿到这个数据,相当于z只是到grad_fn中转了一圈,没干什么。

[外链图片转存中...(img-3PZXndA7-1732978830210)]

接下来主要就是一个没有封装train_step的操作方法,主要逻辑上文也说过的,这里也不再赘述

相关推荐
Ven%24 分钟前
如何让后台运行llamafactory-cli webui 即使关掉了ssh远程连接 也在运行
运维·人工智能·chrome·python·ssh·aigc
Jeo_dmy29 分钟前
(七)人工智能进阶之人脸识别:从刷脸支付到智能安防的奥秘,小白都可以入手的MTCNN+Arcface网络
人工智能·计算机视觉·人脸识别·猪脸识别
睡觉狂魔er2 小时前
自动驾驶控制与规划——Project 5: Lattice Planner
人工智能·机器学习·自动驾驶
xm一点不soso2 小时前
ROS2+OpenCV综合应用--11. AprilTag标签码跟随
人工智能·opencv·计算机视觉
云卓SKYDROID3 小时前
无人机+Ai应用场景!
人工智能·无人机·科普·高科技·云卓科技
是十一月末3 小时前
机器学习之过采样和下采样调整不均衡样本的逻辑回归模型
人工智能·python·算法·机器学习·逻辑回归
小禾家的3 小时前
.NET AI 开发人员库 --AI Dev Gallery简单示例--问答机器人
人工智能·c#·.net
生信碱移3 小时前
万字长文:机器学习的数学基础(易读)
大数据·人工智能·深度学习·线性代数·算法·数学建模·数据分析
KeyPan4 小时前
【机器学习:四、多输入变量的回归问题】
人工智能·数码相机·算法·机器学习·计算机视觉·数据挖掘·回归
码力全開4 小时前
C 语言奇幻之旅 - 第14篇:C 语言高级主题
服务器·c语言·开发语言·人工智能·算法