241130_昇思MindSpore函数式自动微分
函数式自动微分是Mindspore学习框架所特有的,更偏向于数学计算的习惯。这里也是和pytorch差距最大的部分,具体体现在训练部分的代码,MindSpore是把各个梯度计算、损失函数计算
在这幅图中,右边这个就是函数式编程,首先先自己定义一个loss函数,最后一行使用grad(),把loss function传进去,因为传进去的是一个函数,做的是一个函数闭包,所以返回的还是一个函数。
深度学习的计算流程
首先是先正向计算,得到一个logits,然后会计算这个logits和真实的targets的误差loss,然后反向传播backwards,得到梯度,然后再送到优化器里面去更新网络权重。
MindSpore是整图计算,将模型的前向、反向、梯度更新过程全部视为一个完整的计算图,这样就有效提高了执行速度,但是带来的弊端就是代码不好书写,和pytorch差异较大,但到了2.0就得到了有效解决。
MindSpore后来使用面向对象+函数式混合使用,1-2和pytorch一样,后面3-6和原来的函数式编程一样。
1、用类构建神经网络
2、实例化Network对象
3、Network+Loss直接构造正向函数
4、函数变换,获得梯度计算(反向传播)函数
5、构造训练过程函数
6、调用函数进行训练
具体实现看以下代码:
pytorch的实现
看右边pytorch的代码就可以很直观的看出来上面说的计算流程的几步。
正向计算得到logits
python
logits=net(inputs)
计算logits和target的误差loss
python
loss=loss_fn(logits,targets)
反向传播backward
python
loss.backward()
送到优化器里去更新网络权重
python
optimizer.step()
以上是pytorch对应步骤的代码
MindSpore2.x的实现
相对来说MindSpore没有那么直观,但逻辑上都是一样的。
首先,MindSpore这边正向计算需要定义一个函数方法,里面写了loss的计算
python
def forword_fn(inputs,targets):
logits=net(inputs)
loss=loss_fn(logits,targets)
return loss
然后把整个方法作为参数传入value_and_grad方法,做一个函数闭包,然后得到一个同样的grad_fn方法(第三个参数往往是net的所有可训练参数,看上图mindspore这边第三行代码也能看出来)。
python
grad_fn=value_and_grad(forward_fn,None,optim.parameters)
训练的每个step我们也需要定义一个方法,里面把刚才得到的grad_fn方法拿过来,一次得到损失loss和梯度grads,然后直接把梯度传进优化器进行更新权重。
PYTHON
def train_step(inputs,targets):
loss,grads=grad_fn(inputs,targets)
optimizer(grads)
return loss
在实际epoch循环中,我们只需要读入data和其targets,然后直接传入单步训练的train_step就可以了
其实要说函数式微分的话,封装的第一个forward_fn已经是函数式微分了,train_step反而不像函数式微分,就是一个单纯的计算,没有涉及到函数闭包,函数套函数这样的写法,那为什么还要封装呢。
这里就涉及到MindSpore想实现加速的问题,实现整图下发,避免来回传输数据受到的带宽限制,具体实现只需要给train_step上方添加一行修饰器代码
python
@ms.jit
官方的教学示例notebook
接下来我们将定义几个变量x,y,w,b,z
x是输入,y是真实targets,w是权重,b是偏置,z是我们计算得到的label
w和b是我们要优化的东西
x置全1矩阵,y给全0矩阵,w和b随机给个初始值
然后可以构建一个loss计算的fuction
然后就是使用函数式微分,函数套函数,计算梯度
简单来说,就是你如果计算loss的哪个fuction返回多个参数的话,传入计算梯度的方法,计算结果就会出现偏差,这时候我们就要调用接口去实现stop_gradient(也不算手动实现吧。就是输出的时候嵌套一下)
我们在loss里面返回的z,现在看起来也没用,即返回了,又要排除他的影响,那为什么还要返回呢,数据又拿不出来。其实不对,数据是可以拿出来的。
调用grad_fn的时候我们就可以拿到这个数据,相当于z只是到grad_fn中转了一圈,没干什么。
接下来主要就是一个没有封装train_step的操作方法,主要逻辑上文也说过的,这里也不再赘述
候我们就可以拿到这个数据,相当于z只是到grad_fn中转了一圈,没干什么。
[外链图片转存中...(img-3PZXndA7-1732978830210)]
接下来主要就是一个没有封装train_step的操作方法,主要逻辑上文也说过的,这里也不再赘述