Julia系列15:深度学习框架flux

1. 介绍

Flux对于正则化或嵌入等功能的显式API相对较少。 相反,写下数学形式将起作用 ,并且速度很快。

所有的知识和工具,从LSTM到GPU内核,都是简单的Julia代码。 如果有疑问的话,可以查看官方教程。 如果需要不同的函数块或者是功能模块,我们也可以轻松自己动手实现。

Flux适用于Julia库,包括从数据帧和图像到差分方程求解器等等内容,因此我们也可以轻松构建集成Flux模型的复杂数据处理流水线。

2. gradient用法

2.1 基本用法,传递所有参数

julia> f(x, y) = sum((x .- y).^2);

julia> gradient(f, [2, 1], [2, 0])
([0, 2], [0, -2])

2.2 简化版,使用params传递参数

julia> x = [2, 1];

julia> y = [2, 0];

julia> gs = gradient(params(x, y)) do
         f(x, y)
       end
Grads(...)

julia> gs[x]
2-element Array{Int64,1}:
 0
 2

julia> gs[y]
2-element Array{Int64,1}:
  0
 -2

2.3 迭代版,定义函数的函数

# 定义函数的函数,两层函数分别对应参数和变量
julia> linear(in,out) = x -> randn(out,in)*x.+randn(out) 
julia> l1 = linear(5,3);l2=linear(5,3);
julia> model(x) = l2(σ.(l1(x)))
julia> model(rand(5))
2-element Array{Float64,1}:
  1.7485308860085003
 -0.7488549151521576

2.4 struct版,定义call

struct Affine
  W
  b
end

Affine(in::Integer, out::Integer) =
  Affine(randn(out, in), randn(out))

# Overload call, so the object can be used as a function
(m::Affine)(x) = m.W * x .+ m.b

a = Affine(10, 5)

a(rand(10)) # => 5-element vector

2.5 类似静态图

using Flux

layers = [Dense(10, 5, σ), Dense(5, 2), softmax]

model(x) = foldl((x, m) -> m(x), layers, init = x)

model(rand(10)) # => 2-element vector

或者用另一种方式:

model2 = Chain(
  Dense(10, 5, σ),
  Dense(5, 2),
  softmax)

model2(rand(10)) # => 2-element vector

3. 建立模型

损失函数在Flux.Losses下

添加L2 reg:

penalty() = sum(abs2, m.W) + sum(abs2, m.b)
loss(x, y) = logitcrossentropy(m(x), y) + penalty()

优化器在Flux.Optimise下

using Flux.Optimise: update!

η = 0.1 # Learning Rate
for p in (W, b)
  update!(p, η * grads[p])
end

3.0 基础神经网络

手动书写模型如下

linear(in,out) = x -> randn(out,in)*x.+randn(out) 
l1 = linear(5,3);l2=linear(5,3);
model(x) = l2(σ.(l1(x)))

使用chain将迭代调用写的更好看些,另外用Dense封装普通神经网络:

julia> m = Chain(x -> x^2, x -> x+1);

julia> m(5) == 26
true

julia> m = Chain(Dense(10, 5), Dense(5, 2));

julia> x = rand(10);

julia> m(x) == m[2](m[1](x))
true

3.1 CNN模型

Conv(filter, in => out, σ = identity; init = glorot_uniform,
     stride = 1, pad = 0, dilation = 1)

filter = (2,2)
in = 1
out = 16
Conv((2, 2), 1=>16, relu)

输入数据要求 WHCN (width, height, # channels, batch size)格式。其他的dropout、norm都有封装,不赘述。

使用如下方式设置inference模式和训练模式。

testmode!(m)
trainmode!(m)

3.2 Recurrent模型

在这个模型中,每次计算不仅要给出y,还要给出中间结果h,和x一起作为下一次计算的一部分输入。手动书写模型如下:

Wxh = randn(5, 10)
Whh = randn(5, 5)
b   = randn(5)

function rnn(h, x)
  h = tanh.(Wxh * x .+ Whh * h .+ b)
  return h, h # 这里令y就等于隐状态h
end

x = rand(10) # dummy data
h = rand(5)  # initial hidden state

h, y = rnn(h, x)

调用函数的方法如下:

rnn2 = Flux.RNNCell(10, 5)

x = rand(10) # dummy data
h = rand(5)  # initial hidden state

h, y = rnn2(h, x)

还有一个不透露h的写法:

x = rand(10)
h = rand(5)

m = Flux.Recur(rnn, h)

y = m(x)

或者干脆就叫RNN:

julia> RNN(10, 5)
Recur(RNNCell(10, 5, tanh))
相关推荐
XiaoLeisj44 分钟前
【JavaEE初阶 — 多线程】单例模式 & 指令重排序问题
java·开发语言·java-ee
励志成为嵌入式工程师2 小时前
c语言简单编程练习9
c语言·开发语言·算法·vim
捕鲸叉2 小时前
创建线程时传递参数给线程
开发语言·c++·算法
A charmer2 小时前
【C++】vector 类深度解析:探索动态数组的奥秘
开发语言·c++·算法
Peter_chq2 小时前
【操作系统】基于环形队列的生产消费模型
linux·c语言·开发语言·c++·后端
记录成长java4 小时前
ServletContext,Cookie,HttpSession的使用
java·开发语言·servlet
前端青山4 小时前
Node.js-增强 API 安全性和性能优化
开发语言·前端·javascript·性能优化·前端框架·node.js
睡觉谁叫~~~4 小时前
一文解秘Rust如何与Java互操作
java·开发语言·后端·rust
音徽编程4 小时前
Rust异步运行时框架tokio保姆级教程
开发语言·网络·rust
观音山保我别报错4 小时前
C语言扫雷小游戏
c语言·开发语言·算法