Julia系列15:深度学习框架flux

1. 介绍

Flux对于正则化或嵌入等功能的显式API相对较少。 相反,写下数学形式将起作用 ,并且速度很快。

所有的知识和工具,从LSTM到GPU内核,都是简单的Julia代码。 如果有疑问的话,可以查看官方教程。 如果需要不同的函数块或者是功能模块,我们也可以轻松自己动手实现。

Flux适用于Julia库,包括从数据帧和图像到差分方程求解器等等内容,因此我们也可以轻松构建集成Flux模型的复杂数据处理流水线。

2. gradient用法

2.1 基本用法,传递所有参数

复制代码
julia> f(x, y) = sum((x .- y).^2);

julia> gradient(f, [2, 1], [2, 0])
([0, 2], [0, -2])

2.2 简化版,使用params传递参数

复制代码
julia> x = [2, 1];

julia> y = [2, 0];

julia> gs = gradient(params(x, y)) do
         f(x, y)
       end
Grads(...)

julia> gs[x]
2-element Array{Int64,1}:
 0
 2

julia> gs[y]
2-element Array{Int64,1}:
  0
 -2

2.3 迭代版,定义函数的函数

复制代码
# 定义函数的函数,两层函数分别对应参数和变量
julia> linear(in,out) = x -> randn(out,in)*x.+randn(out) 
julia> l1 = linear(5,3);l2=linear(5,3);
julia> model(x) = l2(σ.(l1(x)))
julia> model(rand(5))
2-element Array{Float64,1}:
  1.7485308860085003
 -0.7488549151521576

2.4 struct版,定义call

复制代码
struct Affine
  W
  b
end

Affine(in::Integer, out::Integer) =
  Affine(randn(out, in), randn(out))

# Overload call, so the object can be used as a function
(m::Affine)(x) = m.W * x .+ m.b

a = Affine(10, 5)

a(rand(10)) # => 5-element vector

2.5 类似静态图

复制代码
using Flux

layers = [Dense(10, 5, σ), Dense(5, 2), softmax]

model(x) = foldl((x, m) -> m(x), layers, init = x)

model(rand(10)) # => 2-element vector

或者用另一种方式:

复制代码
model2 = Chain(
  Dense(10, 5, σ),
  Dense(5, 2),
  softmax)

model2(rand(10)) # => 2-element vector

3. 建立模型

损失函数在Flux.Losses下

添加L2 reg:

复制代码
penalty() = sum(abs2, m.W) + sum(abs2, m.b)
loss(x, y) = logitcrossentropy(m(x), y) + penalty()

优化器在Flux.Optimise下

复制代码
using Flux.Optimise: update!

η = 0.1 # Learning Rate
for p in (W, b)
  update!(p, η * grads[p])
end

3.0 基础神经网络

手动书写模型如下

复制代码
linear(in,out) = x -> randn(out,in)*x.+randn(out) 
l1 = linear(5,3);l2=linear(5,3);
model(x) = l2(σ.(l1(x)))

使用chain将迭代调用写的更好看些,另外用Dense封装普通神经网络:

复制代码
julia> m = Chain(x -> x^2, x -> x+1);

julia> m(5) == 26
true

julia> m = Chain(Dense(10, 5), Dense(5, 2));

julia> x = rand(10);

julia> m(x) == m[2](m[1](x))
true

3.1 CNN模型

复制代码
Conv(filter, in => out, σ = identity; init = glorot_uniform,
     stride = 1, pad = 0, dilation = 1)

filter = (2,2)
in = 1
out = 16
Conv((2, 2), 1=>16, relu)

输入数据要求 WHCN (width, height, # channels, batch size)格式。其他的dropout、norm都有封装,不赘述。

使用如下方式设置inference模式和训练模式。

复制代码
testmode!(m)
trainmode!(m)

3.2 Recurrent模型

在这个模型中,每次计算不仅要给出y,还要给出中间结果h,和x一起作为下一次计算的一部分输入。手动书写模型如下:

复制代码
Wxh = randn(5, 10)
Whh = randn(5, 5)
b   = randn(5)

function rnn(h, x)
  h = tanh.(Wxh * x .+ Whh * h .+ b)
  return h, h # 这里令y就等于隐状态h
end

x = rand(10) # dummy data
h = rand(5)  # initial hidden state

h, y = rnn(h, x)

调用函数的方法如下:

复制代码
rnn2 = Flux.RNNCell(10, 5)

x = rand(10) # dummy data
h = rand(5)  # initial hidden state

h, y = rnn2(h, x)

还有一个不透露h的写法:

复制代码
x = rand(10)
h = rand(5)

m = Flux.Recur(rnn, h)

y = m(x)

或者干脆就叫RNN:

复制代码
julia> RNN(10, 5)
Recur(RNNCell(10, 5, tanh))
相关推荐
微风中的麦穗4 小时前
【MATLAB】MATLAB R2025a 详细下载安装图文指南:下一代科学计算与工程仿真平台
开发语言·matlab·开发工具·工程仿真·matlab r2025a·matlab r2025·科学计算与工程仿真
2601_949146534 小时前
C语言语音通知API示例代码:基于标准C的语音接口开发与底层调用实践
c语言·开发语言
开源技术5 小时前
Python Pillow 优化,打开和保存速度最快提高14倍
开发语言·python·pillow
学嵌入式的小杨同学5 小时前
从零打造 Linux 终端 MP3 播放器!用 C 语言实现音乐自由
linux·c语言·开发语言·前端·vscode·ci/cd·vim
mftang6 小时前
Python 字符串拼接成字节详解
开发语言·python
jasligea7 小时前
构建个人智能助手
开发语言·python·自然语言处理
kokunka7 小时前
【源码+注释】纯C++小游戏开发之射击小球游戏
开发语言·c++·游戏
AI大模型测试7 小时前
大龄程序员想转行到AI大模型,好转吗?
人工智能·深度学习·机器学习·ai·语言模型·职场和发展·大模型
ProcessOn官方账号7 小时前
程序员如何与同龄人拉开差距?这5张让你快速提升认知,打开格局!
深度学习·职场和发展·学习方法
童话名剑7 小时前
序列模型与集束搜索(吴恩达深度学习笔记)
人工智能·笔记·深度学习·机器翻译·seq2seq·集束搜索·编码-解码模型