SpikingJelly笔记之延迟编码

文章目录


前言

记录SpikingJelly中延迟编码的使用方法,以及自定义延时函数。


一、延迟编码的原理

首次脉冲时间编码(Time-to-First-Spike,TTFS)

基于时间的编码方式,输入强度越大,发放时间越早

存在最晚发放时间: t m a x = T t_{max} = T tmax=T

1、线性延迟

python 复制代码
encoder = encoding.LatencyEncoder(T, enc_function='linear')

t f = ( t m a x − 1 ) ( 1 − x ) t_{f} = (t_{max}-1)(1-x) tf=(tmax−1)(1−x)

发放延时 发放时刻

2、对数延迟

python 复制代码
encoder = encoding.LatencyEncoder(T, enc_function='log')

t f = ( t m a x − 1 ) − ln ⁡ ( α x + 1 ) t_{f} = (t_{max}-1)-\ln(\alpha x+1) tf=(tmax−1)−ln(αx+1)

α = exp ⁡ ( t m a x − 1 ) − 1 \alpha = \exp(t_{max}-1)-1 α=exp(tmax−1)−1

t m a x t_{max} tmax不宜过大,防止 α \alpha α溢出

发放延时 发放时刻

二、发放延时转化为脉冲序列

python 复制代码
out_spike = F.one_hot(t_f, T).to(x)

(1)根据输入计算发放时间[0, T-1]

(2)独热编码转为脉冲序列

(3)通过参数设置避免出现全0(不发放)

三、SpikingJelly中的延时编码

python 复制代码
from spikingjelly.activation_based import encoding
encoder = encoding.LatencyEncoder(T=T, enc_function='linear')
# encoder = encoding.LatencyEncoder(T=T, enc_function='log')
# 输出脉冲序列,T:时间步长,w:图像宽度,h:图像高度
out_spike = torch.zeros((T, w, h), dtype=torch.bool)
# 按时间步单步运算
for t in range(T):
	out_spike[t] = encoder(x) # x需要归一化[0,1]
encoder.reset() # 有状态编码器,需要复位

四、自定义延时函数

SpikingJelly中提供了线性和对数两种延迟模式

可以根据需要进行修改,自定义延时函数

在此定义一个指数延时函数: y = − 2.5 + 20.0 ∗ exp ⁡ ( − x 0.5 ) y=-2.5+20.0*\exp(-\frac{x}{0.5}) y=−2.5+20.0∗exp(−0.5x)

当输入过小时允许不发放脉冲

python 复制代码
####################指数延迟####################
class ExpLatencyEncoder(encoding.StatefulEncoder):
    def __init__(self, T: int, step_mode='s'):
        super().__init__(T, step_mode)
    def single_step_encode(self, x: torch.Tensor):
        # 修改发放时间-输入方程
        t_f = -2.5 + 20.0 * torch.exp(-x/0.5)
        # 控制输出范围[0, T]
        t_f = torch.clamp(t_f, 0, T).round().long()
        self.spike = F.one_hot(t_f, num_classes=self.T+1).to(x)
        # [*, T+1] -> [T+1, *]
        d_seq = list(range(self.spike.ndim - 1))
        d_seq.insert(0, self.spike.ndim - 1)
        self.spike = self.spike.permute(d_seq)
        # 截取前T时间步的脉冲
        self.spike = self.spike[:T]
发放延时 发放时刻

五、Lena图像的延迟编码

1、原始图像

读取、展示原始图像

python 复制代码
####################延时编码####################
from PIL import Image
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn.functional as F
from spikingjelly.activation_based import encoding
from spikingjelly import visualizing
####################读取图像####################
img = np.array(Image.open('../dataset/lena.bmp')) / 255
x = torch.from_numpy(img)
w, h = x.shape
plt.figure()
plt.imshow(x, cmap='gray')
plt.axis('off')

2、图像编码

取十个神经元展示编码结果

python 复制代码
####################延迟编码####################
T = 10 # 时间步长
encoder = encoding.LatencyEncoder(T=T, enc_function='linear')
# encoder = encoding.LatencyEncoder(T=T, enc_function='log')
# encoder = ExpLatencyEncoder(T=T)
out_spike = torch.zeros((T, w, h), dtype=torch.bool)
for t in range(T):
    out_spike[t] = encoder(x)
encoder.reset()
# 取十个神经元发放情况
figsize, dpi = (6, 4), 100
visualizing.plot_1d_spikes(spikes=out_spike.numpy()[:,:500:50,250],
                            title='Out Spikes',
                            xlabel='Simulating Step',
                            ylabel='Neuron Index',
                            figsize=figsize,
                            dpi=dpi)
plt.show()
线性延迟 对数延迟 指数延迟(自定义)

对数延时可以将过大和过小的输出区分开,而线性延迟表现得更加均匀一些

3、图像还原

越早发放对应的值越大,因为均采用线性还原,还原质量不代表编码效果

python 复制代码
img_code = torch.zeros(w, h)
for i in range(T):
    img_code += out_spike[i] * (T - i - 1)
plt.imshow(img_code, cmap='gray')
plt.axis('off')
线性延迟 对数延迟 指数延迟(自定义)

总结

延迟编码将输入转化为脉冲序列,是一种基于时间编码的方式

较大的输入对应于较早的输出,有且仅有一次发放

延迟编码的输入需要归一化[0,1]

设置延时函数时需要保证最大延时 t m a x t_{max} tmax小于时间步长 T T T

参考:SpikingJelly:编码器

相关推荐
孤单网愈云10 小时前
如何理解tensor中张量的维度
pytorch·python·深度学习
爱喝热水的呀哈喽10 小时前
pde_accuracy阅读【1】
人工智能·pytorch·深度学习
起名字真南10 小时前
丹摩 | 基于PyTorch的CIFAR-10图像分类实现
人工智能·pytorch·分类
土豆炒马铃薯。13 小时前
【深度学习】Pytorch 1.x 安装命令
linux·人工智能·pytorch·深度学习·ubuntu·centos
土豆炒马铃薯。14 小时前
CUDA,PyTorch,GCC 之间的版本关系
linux·c++·人工智能·pytorch·python·深度学习·opencv
陈苏同学1 天前
机器翻译 & 数据集 (NLP基础 - 预处理 → tokenize → 词表 → 截断/填充 → 迭代器) + 代码实现 —— 笔记3.9《动手学深度学习》
人工智能·pytorch·笔记·python·深度学习·自然语言处理·机器翻译
取个名字真难呐1 天前
6、PyTorch中搭建分类网络实例
人工智能·pytorch·分类
猎嘤一号2 天前
个人笔记本安装CUDA并配合Pytorch使用NVIDIA GPU训练神经网络的计算以及CPUvsGPU计算时间的测试代码
人工智能·pytorch·神经网络
z千鑫2 天前
【人工智能】深入理解PyTorch:从0开始完整教程!全文注解
人工智能·pytorch·python·gpt·深度学习·ai编程
爱喝热水的呀哈喽2 天前
torch张量与函数表达式写法
人工智能·pytorch·深度学习