我自己的原文哦~ https://blog.51cto.com/whaosoft/13059544
一、PyTorch DDP
正在郁闷呢 jetson nx 的torchvision安装~~ 自带就剩5g 想弄到ssd 项目中的 venv中又 cuda.h没有... 明明已经装好什么都对
算了说今天主题 啊对 还是搬运啊 学习之工具人而已 勿怪
DistributedDataParallel(DDP)是一个支持多机多卡、分布式训练的深度学习工程方法。其能达到略低于卡数的加速比,是目前最流行的多机多卡训练方法。
本文DDP在实际生产中的应用,如在DDP中引入SyncBN,多机多卡环境下的inference加速等。
基本原理与入门:https://zhuanlan.zhihu.com/p/178402798
实现原理与源代码解析:https://zhuanlan.zhihu.com/p/187610959
在过去的两篇文章里,我们已经对DDP的理论、代码进行了充分、详细的介绍,相信大家都已经了然在胸。但是,实践也是很重要的。正所谓理论联系实践,如果只掌握理论而不进行实践,无疑是纸上谈兵。
在这篇文章里,我们通过几个实战例子,来给大家介绍一下DDP在实际生产中的应用。希望能对大家有所帮助!
- 在DDP中引入SyncBN
- DDP下的Gradient Accumulation的进一步加速
- 多机多卡环境下的inference加速
- 保证DDP性能:确保数据的一致性
- 和DDP有关的小技巧
- 控制不同进程的执行顺序
- 避免DDP带来的冗余输出
请欢快地开始阅读吧!
**依赖:**pytorch(gpu)>=1.5,python>=3.6
一. 在DDP中引入SyncBN
什么是Batch Normalization(BN)? 这里就不多加以介绍。附上BN文章(https://arxiv.org/abs/1502.03167)。接下来,让我们来深入了解下BN在多级多卡环境上的完整实现:SyncBN。
什么是SyncBN? SyncBN就是Batch Normalization(BN)。其跟一般所说的普通BN的不同在于工程实现方式:SyncBN能够完美支持多卡训练,而普通BN在多卡模式下实际上就是单卡模式。 我们知道,BN中有moving mean和moving variance这两个buffer,这两个buffer的更新依赖于当前训练轮次的batch数据的计算结果。但是在普通多卡DP模式下,各个模型只能拿到自己的那部分计算结果,所以在DP模式下的普通BN被设计为只利用主卡上的计算结果来计算moving mean和moving variance,之后再广播给其他卡。这样,实际上BN的batch size就只是主卡上的batch size那么大。当模型很大、batch size很小时,这样的BN无疑会限制模型的性能。 为了解决这个问题,PyTorch新引入了一个叫SyncBN的结构,利用DDP的分布式计算接口来实现真正的多卡BN。 SyncBN的原理 SyncBN的原理很简单:SyncBN利用分布式通讯接口在各卡间进行通讯,从而能利用所有数据进行BN计算。为了尽可能地减少跨卡传输量,SyncBN做了一个关键的优化,即只传输各自进程的各自的 小batch mean和 小batch variance,而不是所有数据。具体流程请见下面: 前向传播 在各进程上计算各自的 小batch mean和小batch variance 各自的进程对各自的 小batch mean和小batch variance进行all_gather操作,每个进程都得到s的全局量。 注释:只传递mean和variance,而不是整体数据,可以大大减少通讯量,提高速度。 每个进程分别计算总体mean和总体variance,得到一样的结果 注释:在数学上是可行的,有兴趣的同学可以自己推导一下。 接下来,延续正常的BN计算。 注释:因为从前向传播的计算数据中得到的batch mean和batch variance在各卡间保持一致,所以,running_mean和running_variance就能保持一致,不需要显式地同步了! 后向传播:和正常的一样 贴一下关键代码,有兴趣的同学可以研究下:pytorch源码(https://github.com/pytorch/pytorch/blob/release/1.5/torch/nn/modules/_functions.py#L5) SyncBN与DDP的关系 一句话总结,当前PyTorch SyncBN只在DDP单进程单卡模式中支持。SyncBN用到 all_gather这个分布式计算接口,而使用这个接口需要先初始化DDP环境。 复习一下DDP的伪代码中的准备阶段中的DDP初始化阶段
这里有三个点需要注意: 这里的为可能的SyncBN层做准备,实际上就是检测当前是否是DDP单进程单卡模式,如果不是,会直接停止。 这告诉我们,SyncBN需要在DDP环境初始化后初始化,但是要在DDP模型前就准备好。 为什么当前PyTorch SyncBN只支持DDP单进程单卡模式? 从SyncBN原理中我们可以看到,其强依赖了all_gather计算,而这个分布式接口当前是不支持单进程多卡或者DP模式的。当然,不排除未来也是有可能支持的。 怎么用SyncBN? 怎么样才能在我们的代码引入SyncBN呢?很简单:
# DDP init
dist.init_process_group(backend='nccl')
# 按照原来的方式定义模型,这里的BN都使用普通BN就行了。
model = MyModel()
# 引入SyncBN,这句代码,会将普通BN替换成SyncBN。
model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model).to(device)
# 构造DDP模型
model = DDP(model, device_ids=[local_rank], output_device=local_rank)
又是熟悉的模样,像DDP一样,一句代码就解决了问题。这是怎么做到的呢?
convert_sync_batchnorm
的原理:
torch.nn.SyncBatchNorm.convert_sync_batchnorm
会搜索model里面的每一个module,如果发现这个module是、或者继承了torch.nn.modules.batchnorm._BatchNorm
类,就把它替换成SyncBN。也就是说,如果你的Normalization层是自己定义的特殊类,没有继承过 _BatchNorm
类,那么convert_sync_batchnorm
是不支持的,需要你自己实现一个新的SyncBN!
下面给一下convert_sync_batchnorm
的源码(https://github.com/pytorch/pytorch/blob/v1.5.0/torch/nn/modules/batchnorm.py#L474),可以看到convert的过程中,新的SyncBN复制了原来的BN层的所有参数:
@classmethod
def convert_sync_batchnorm(cls, module, process_group=None):
r"""Helper function to convert all :attr:`BatchNorm*D` layers in the model to
:class:`torch.nn.SyncBatchNorm` layers.
"""
module_output = module
if isinstance(module, torch.nn.modules.batchnorm._BatchNorm):
module_output = torch.nn.SyncBatchNorm(module.num_features,
module.eps, module.momentum,
module.affine,
module.track_running_stats,
process_group)
if module.affine:
with torch.no_grad():
module_output.weight = module.weight
module_output.bias = module.bias
module_output.running_mean = module.running_mean
module_output.running_var = module.running_var
module_output.num_batches_tracked = module.num_batches_tracked
for name, child in module.named_children():
module_output.add_module(name, cls.convert_sync_batchnorm(child, process_group))
del module
return module_output
二. DDP下的Gradient Accumulation的进一步加速什么是Gradient Accmulation?
Gradient Accumulation,即梯度累加,相信大家都有所了解,是一种增大训练时batch size的技术,造福了无数硬件条件窘迫的我等穷人。不了解的同学请看这个知乎链接(https://www.zhihu.com/question/303070254/answer/573037166)。
为什么还能进一步加速?
我们仔细思考一下DDP下的gradient accumulation。
# 单卡模式,即普通情况下的梯度累加
for 每次梯度累加循环
optimizer.zero_grad()
for 每个小step
prediction = model(data)
loss_fn(prediction, label).backward() # 积累梯度,不应用梯度改变
optimizer.step() # 应用梯度改变
我们知道,DDP的gradient all_reduce阶段发生在loss_fn(prediction, label).backward()
。这意味着,在梯度累加的情况下,假设一次梯度累加循环有K个step,每次梯度累加循环会进行K次 all_reduce!但事实上,每次梯度累加循环只会有一次 optimizer.step(),即只应用一次参数修改,这意味着在每一次梯度累加循环中,我们其实只要进行一次gradient all_reduce即可满足要求,有K-1次 all_reduce**被浪费了!**而每次 all_reduce的时间成本是很高的!
如何加速
解决问题的思路在于,对前K-1次step取消其梯度同步。 幸运的是,DDP给我们提供了一个暂时取消梯度同步的context函数 no_sync()
(源代码:https://github.com/pytorch/pytorch/blob/master/torch/nn/parallel/distributed.py#L548)。在这个context下,DDP不会进行梯度同步。
所以,我们可以这样实现加速:
model = DDP(model)
for 每次梯度累加循环
optimizer.zero_grad()
# 前K-1个step,不进行梯度同步,累积梯度。
for K-1个小step:
with model.no_sync():
prediction = model(data)
loss_fn(prediction, label).backward()
# 第K个step,进行梯度同步
prediction = model(data)
loss_fn(prediction, label).backward()
optimizer.step()
给一个优雅写法(同时兼容单卡、DDP模式哦):
from contextlib import nullcontext
# 如果你的python版本小于3.7,请注释掉上面一行,使用下面这个:
# from contextlib import suppress as nullcontext
if local_rank != -1:
model = DDP(model)
optimizer.zero_grad()
for i, (data, label) in enumerate(dataloader):
# 只在DDP模式下,轮数不是K整数倍的时候使用no_sync
my_context = model.no_sync if local_rank != -1 and i % K != 0 else nullcontext
with my_context():
prediction = model(data)
loss_fn(prediction, label).backward()
if i % K == 0:
optimizer.step()
optimizer.zero_grad()
是不是很漂亮!
三. 多机多卡环境下的inference加速
问题
有一些非常现实的需求,相信大家肯定碰到过:
- 一般,训练中每几个epoch我们会跑一下inference、测试一下模型性能。在DDP多卡训练环境下,能不能利用多卡来加速inference速度呢?
- 我有一堆数据要跑一些网络推理,拿到inference结果。DP下多卡加速比太低,能不能利用DDP多卡来加速呢?
解法
这两个问题实际是同一个问题。答案肯定是可以的,但是,没有现成、省力的方法。
测试和训练的不同在于:
- 测试的时候不需要进行梯度反向传播,inference过程中各进程之间不需要通讯。
- 测试的时候,不同模型的inference结果、性能指标的类型多种多样,没有统一的形式。
- 我们很难定义一个统一的框架,像训练时
model=DDP(model)
那样方便地应用DDP多卡加速。
解决问题的思路很简单,就是各个进程中各自进行单卡的inference,然后把结果收集到一起。单卡inference很简单,我们甚至可以直接用DDP包装前的模型。问题其实只有两个:
- 我们要如何把数据split到各个进程中
- 我们要如何把结果合并到一起
如何把数据split到各个进程中:新的data sampler
大家肯定还记得,在训练的时候,我们用的 torch.utils.data.distributed.DistributedSampler
帮助我们把数据不重复地分到各个进程上去。但是,其分的方法是:每段连续的N个数据,拆成一个一个,分给N个进程,所以每个进程拿到的数据不是连续的。这样,不利于我们在inference结束的时候将结果合并到一起。
所以,这里我们需要实现一个新的data sampler。它的功能,是能够连续地划分数据块,不重复地分到各个进程上去。直接给代码:
# 来源:https://github.com/huggingface/transformers/blob/447808c85f0e6d6b0aeeb07214942bf1e578f9d2/src/transformers/trainer_pt_utils.py
class SequentialDistributedSampler(torch.utils.data.sampler.Sampler):
"""
Distributed Sampler that subsamples indicies sequentially,
making it easier to collate all results at the end.
Even though we only use this sampler for eval and predict (no training),
which means that the model params won't have to be synced (i.e. will not hang
for synchronization even if varied number of forward passes), we still add extra
samples to the sampler to make it evenly divisible (like in `DistributedSampler`)
to make it easy to `gather` or `reduce` resulting tensors at the end of the loop.
"""
def __init__(self, dataset, batch_size, rank=None, num_replicas=None):
if num_replicas is None:
if not torch.distributed.is_available():
raise RuntimeError("Requires distributed package to be available")
num_replicas = torch.distributed.get_world_size()
if rank is None:
if not torch.distributed.is_available():
raise RuntimeError("Requires distributed package to be available")
rank = torch.distributed.get_rank()
self.dataset = dataset
self.num_replicas = num_replicas
self.rank = rank
self.batch_size = batch_size
self.num_samples = int(math.ceil(len(self.dataset) * 1.0 / self.batch_size / self.num_replicas)) * self.batch_size
self.total_size = self.num_samples * self.num_replicas
def __iter__(self):
indices = list(range(len(self.dataset)))
# add extra samples to make it evenly divisible
indices += [indices[-1]] * (self.total_size - len(indices))
# subsample
indices = indices[self.rank * self.num_samples : (self.rank + 1) * self.num_samples]
return iter(indices)
def __len__(self):
return self.num_samples
如何把结果合并到一起: all_gather
通过torch.distributed
提供的分布式接口all_gather
,我们可以把各个进程的prediction结果集中到一起。
难点就在这里。因为世界上存在着千奇百怪的神经网络模型,有着千奇百怪的输出,所以,把数据集中到一起不是一件容易的事情。**但是,如果你的网络输出在不同的进程中有着一样的大小,那么这个问题就好解多了。**下面给一个方法,其要求网络的prediction结果在各个进程中的大小是一模一样的:
# 合并结果的函数
# 1. all_gather,将各个进程中的同一份数据合并到一起。
# 和all_reduce不同的是,all_reduce是平均,而这里是合并。
# 2. 要注意的是,函数的最后会裁剪掉后面额外长度的部分,这是之前的SequentialDistributedSampler添加的。
# 3. 这个函数要求,输入tensor在各个进程中的大小是一模一样的。
def distributed_concat(tensor, num_total_examples):
output_tensors = [tensor.clone() for _ in range(torch.distributed.get_world_size())]
torch.distributed.all_gather(output_tensors, tensor)
concat = torch.cat(output_tensors, dim=0)
# truncate the dummy elements added by SequentialDistributedSampler
return concat[:num_total_examples]
完整的流程
结合上面的介绍,我们可以得到下面这样一个完整的流程。
## 构造测试集
# 假定我们的数据集是这个
transform = torchvision.transforms.Compose([
torchvision.transforms.ToTensor(),
torchvision.transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])
my_testset = torchvision.datasets.CIFAR10(root='./data', train=False,
download=True, transform=transform)
# 使用我们的新sampler
test_sampler = SequentialDistributedSampler(my_testset, batch_size=16)
testloader = torch.utils.data.DataLoader(my_testset, batch_size=16, sampler=test_sampler)
# DDP和模型初始化,略。
# ......
# 正式训练和evaluation
for epoch in range(total_epoch_size):
# 训练代码,略
# .......
# 开始测试
with torch.no_grad():
# 1. 得到本进程的prediction
predictions = []
labels = []
for data, label in testloader:
data, label = data.to(local_rank), label.to(local_rank)
predictions.append(model(data))
labels.append(label)
# 进行gather
predictions = distributed_concat(torch.concat(predictions, dim=0),
len(test_sampler.dataset))
labels = distributed_concat(torch.concat(labels, dim=0),
len(test_sampler.dataset))
# 3. 现在我们已经拿到所有数据的predictioin结果,进行evaluate!
my_evaluate_func(predictions, labels)
更简化的解法
- 如果我们的目的只是得到性能数字,那么,我们甚至可以直接在各个进程中计算各自的性能数字,然后再合并到一起。上面给的解法,是为了更通用的情景。一切根据你的需要来定!
- 我们可以单向地把predictions、labels集中到 rank=0的进程,只在其进行evaluation并输出。PyTorch也提供了相应的接口(链接:https://pytorch.org/docs/stable/distributed.html,send和recv)。
四. 保证DDP性能:确保数据的一致性性能期望
从原理上讲,当没有开启SyncBN时,(或者更严格地讲,没有BN层;但一般有的话影响也不大),以下两种方法训练出来的模型应该是性能相似的:
- 进程数为N的DDP训练
- accumulation为N、其他配置完全相同的单卡训练
如果我们发现性能对不上,那么,往往是DDP中的某些设置出了问题。在DDP系列第二篇中,我们介绍过一个check list,可以根据它检查下自己的配置。其中,在造成性能对不齐的原因中,最有可能的是数据方面出现了问题。
DDP训练时,数据的一致性必须被保证:各个进程拿到的数据,要像是accumulation为N、其他配置完全相同的单卡训练中同个accumulation循环中不同iteration拿到的数据。想象一下,如果各个进程拿到的数据是一样的,或者分布上有任何相似的地方,那么,这就会造成训练数据质量的下降,最终导致模型性能下降。
容易错的点:随机数种子
为保证实验的可复现性,一般我们会在代码在开头声明一个固定的随机数种子,从而使得同一个配置下的实验,无论启动多少次,都会拿到同样的结果。
import random
import numpy as np
import torch
def init_seeds(seed=0, cuda_deterministic=True):
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
# Speed-reproducibility tradeoff https://pytorch.org/docs/stable/notes/randomness.html
if cuda_deterministic: # slower, more reproducible
cudnn.deterministic = True
cudnn.benchmark = False
else: # faster, less reproducible
cudnn.deterministic = False
cudnn.benchmark = True
def main():
# 一般都直接用0作为固定的随机数种子。
init_seeds(0)
但是在DDP训练中,如果还是像以前一样,使用0作为随机数种子,不做修改,就会造成以下后果:
- DDP的N个进程都使用同一个随机数种子
- 在生成数据时,如果我们使用了一些随机过程的数据扩充方法,那么,各个进程生成的数据会带有一定的同态性。
- 比如说,YOLOv5会使用mosaic数据增强(从数据集中随机采样3张图像与当前的拼在一起,组成一张里面有4张小图的大图)。这样,因为各卡使用了相同的随机数种子,你会发现,各卡生成的图像中,除了原本的那张小图,其他三张小图都是一模一样的!
- 同态性的数据,降低了训练数据的质量,也就降低了训练效率!最终得到的模型性能,很有可能是比原来更低的。
所以,我们需要给不同的进程分配不同的、固定的随机数种子:
def main():
rank = torch.distributed.get_rank()
# 问题完美解决!
init_seeds(1 + rank)
五. 和DDP有关的小技巧控制不同进程的执行顺序
一般情况下,各个进程是各自执行的,速度有快有慢,只有在gradient all-reduce的时候,快的进程才会等一下慢的进程,也就是进行同步。那么,如果我们需要在其他地方进行同步呢?比如说,在加载数据前,如果数据集不存在,我们要下载数据集:
- 我们只需要在唯一一个进程中开启一次下载
- 我们需要让其他进程等待其下载完成,再去加载数据
怎么解决这个问题呢?torch.distributed
提供了一个barrier()
的接口,利用它我们可以同步各个DDP中的各个进程!当使用barrier函数时,DDP进程会在函数的位置进行等待,知道所有的进程都跑到了 barrier函数的位置,它们才会再次向下执行。
只在某进程执行,无须同步:
这是最简单的,只需要一个简单的判断,用不到barrier()
if rank == 0:
code_only_run_in_rank_0()
简单的同步:
没什么好讲的,只是一个示范
code_before()
# 在这一步同步
torch.distributed.barrier()
code_after()
在某个进程中执行A操作,其他进程等待其执行完成后再执行B操作:
也简单。
if rank == 0:
do_A()
torch.distributed.barrier()
else:
do_B()
torch.distributed.barrier()
在某个进程中优先执行A操作,其他进程等待其执行完成后再执行A操作:
这个值得深入讲一下,因为这个是非常普遍的需求。利用contextlib.contextmanager
,我们可以把这个逻辑给优雅地包装起来!
from contextlib import contextmanager
@contextmanager
def torch_distributed_zero_first(rank: int):
"""Decorator to make all processes in distributed training wait for each local_master to do something.
"""
if rank not in [-1, 0]:
torch.distributed.barrier()
# 这里的用法其实就是协程的一种哦。
yield
if rank == 0:
torch.distributed.barrier()
然后我们就可以这样骚操作:
with torch_distributed_zero_first(rank):
if not check_if_dataset_exist():
download_dataset()
load_dataset()
优雅地解决了需求!
避免DDP带来的冗余输出
问题:
当我们在自己的模型中加入DDP模型时,第一的直观感受肯定是,终端里的输出变成了N倍了。这是因为我们现在有N个进程在同时跑整个程序。这不光是对有洁癖的同学造成困扰,其实对所有人都会造成困扰。因为各个进程的速度并不一样快,在茫茫的输出海洋中,我们难以debug、把控实验状态。
解法:
那么,有什么办法能避免这个现象呢?下面,笔者给一个可行的方法:**logging模块+输出信息等级控制。**即用logging输出代替所有print输出,并给不同进程设置不同的输出等级,只在0号进程保留低等级输出。举一个例子:
import logging
# 给主要进程(rank=0)设置低输出等级,给其他进程设置高输出等级。
logging.basicConfig(level=logging.INFO if rank in [-1, 0] else logging.WARN)
# 普通log,只会打印一次。
logging.info("This is an ordinary log.")
# 危险的warning、error,无论在哪个进程,都会被打印出来,从而方便debug。
logging.error("This is a fatal log!")
simple but powerful!
二、PyTorch~SyncBatchNorm
对于一些模型占用显存很大,导致可以上的 batch size 很小这类任务来说,分布式训练的时候就需要用 SyncBatchNorm 来使得统计量更加的准确。本文对SyncBatchNorm的前向以及反向实现细节进行阐述。
我们知道在分布式数据并行多卡训练的时候,BatchNorm 的计算过程(统计均值和方差)在进程之间是独立的,也就是每个进程只能看到本地 GlobalBatchSize / NumGpu 大小的数据。对于一般的视觉任务比如分类,分布式训练的时候,单卡的 batch size 也足够大了,所以不需要在计算过程中同步 batchnorm 的统计量,因为同步也会让训练效率下降。但是对于一些模型占用显存很大,导致可以上的 batch size 很小这类任务来说,分布式训练的时候就需要用 SyncBatchNorm 来使得统计量更加的准确。
SyncBatchNorm 前向实现
前向第一步,计算本地均值和方差
假设在4张GPU上做分布式数据并行训练,我们来看下各个进程上 SyncBN 的行为:
如上图所示,SyncBN前向实现的第一步是,每个GPU先单独计算各自本地数据 X_i
对应均值和方差(mean_i
和 var_i
) 。
而计算均值和方差的 CUDA kernel 具体实现是实现采用的 Welford
迭代计算算法
https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Welford's_online_algorithm
我们知道传统方法计算均值,是要先把所有数据加起来然后再除以个数,而方差是在平均值的基础上做进一步的计算。
但是这样的计算方式有个问题是,在数据量非常之大的情况下,把所有数相加的结果是一个非常大的值,容易导致精度溢出。
而Welford
迭代计算算法,则只需要对数据集进行单次遍历,然后根据迭代公式计算均值,可以避免传统算法可能导致的精度溢出的问题,且 Welford
算法可以并行化。
https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Parallel_algorithm
假设现在输入张量形状是 (B,C,H,W)
,下面解释输入张量是 NCHW 格式的时候, CUDA kernel 具体开启线程的方式和每个线程具体计算细节。
由于线程的配置是按照固定的公式计算出来的,这里为了解释方便就固定为其中一种情况:
如上图所示,总共起了 C
个 thread block,也就是 grid
大小等于通道数。每个 thread block 负责计算某一个通道的均值和方差。
每个 thread block 的形状是两维,x维度是 512, y 维度是 1,共处理 B * H * W
大小的数据,其中数据组织形式是 x 方向是 H * W
维度,y 方向是 B
维度。
每个thread block 负责处理的数据大小和其中每个线程负责处理的位置,如下图所示:
如上图所示紫色方块表示thread block中的一个thread,紫色箭头指向表示,在kernel执行过程中,该线程所要负责处理的数据。
每个线程在x方向上每处理完一个数据,移动的步长为 blockDim.x=512
,x方向遍历完之后,y方向移动步长为blockDim.y=1
,以此类推。
kernel 执行的第一步就是,所有线程处理完自己所负责的数据,然后同步一下,接着就是合并每个线程计算得到的局部均值和方差。
而我们知道一个 thread block 内的线程,是按全局 id 顺序从0开始每 32 个线程分为一组,也就是一个 warp,然后以warp为单位来执行。
kernel 执行的第二步就是 ,每个 warp 内的线程合并均值和方差,通过 warp 级的同步元语库函数 __shfl_xor_sync
来实现 warp 内线程结果的合并。
这一步做完之后,warp 内每个线程都包含了合并之后的均值和方差,下面解释如何通过 __shfl_xor_sync
来实现 warp 内线程结果合并的:
上图中的每一行的32个黑点表示一个 warp 内的32个线程,上方的id 表示每个线程在warp内的id。
然后我们看合并 mean 和 var 的循环,这里可视化了每个循环内线程之间的交互。
__shfl_xor_sync
简单来理解,只需要关注第 2 和 3 个参数,第二个参数是线程之间要交换的值,第三个参数传 i。
具体作用就是,当前线程的 id 和 这个 i 做异或 xor
位运算,计算得到的结果也是 id ,也就是当前线程要与哪个线程交换值。
当 i = 1
的时候,
对于线程 id 0 和 1, 0 xor 1 = 1
, 1 xor 1 = 0
,则就是线程 0 和 1 交换各自的均值和方差,然后就都持有了合并之后的均值和方差了。
再看线程 id 2 和 3, 2 xor 1 = 3
,3 oxr 1 = 2
,所以 2 和 3 交换。
同理可得第一轮循环,是线程按顺序2个为一组组内合并。
当 i = 2
的时候,
对于线程 id 0 和 2, 0 xor 2 = 2
, 2 xor 2 = 0
,
对于线程 id 1 和 3,1 xor 2 = 3
, 3 xor 2 = 1
所以交换完合并之后,thread 0 ~ 3 就都持有了这4个线程合并之后的均值和方差了。
同理可得,
i = 2
的时候线程按顺序4个为一组,组内根据异或运算计算交换线程对合并均值和方差。
i = 4
的时候,线程按顺序8个为一组,
i = 8
的时候,线程按顺序16个为一组,
当最后一轮 i = 16
循环完了之后,warp 内每个线程就都持有了该 warp 的所有线程合并的均值和方差了。
kernel 执行的最后一步是,上面每个 warp 内结果合并完,会做一次全局的线程同步。之后再将所有 warp 的结果合并就得到该 thread block 所负责计算的通道均值和方差了。
前向第二步,GPU之间同步均值和方差
通过集合通信操作 AllGather
让每个 GPU 上的进程都拿到所有 GPU 上的均值和方差,最后就是每个GPU内计算得到全局的均值和方差,同时更新 running_mean
和 running_var
前向第三步,计算 SyncBN 的输出
最后这一步就一个常规的batchnorm操作,对输入 x 做 normalize 操作得到输出,cuda kernel 就是一个 eltwise 的操作,因为不需要计算均值和方差了。这里就不展开了,有兴趣的读者可以看文末的参考链接,去阅读torch的源码,也可以学习一下对于 NHWC
格式的 cuda kernel 是如何实现的。
SyncBatchNorm 反向实现细节
BatchNorm 反向计算公式
首先复习一下 BatchNorm 反向,输入格式是 (B,C,H,W)
则某个通道(通道索引 c
)对应的 输入 x 、weight 和 bias 梯度计算公式,这里不做推导只列出公式:
前置公式:
输出梯度为 y_grad
weight 对应通道 c 的梯度:
bias 对应通道 c 的梯度:
输入 x 对应通道 c 上某个位置 b, h, w 的梯度:
反向计算流程
每个GPU都计算出本地对应的 weight_grad
,bias_grad
,sum_dy
和 sum_dy_xmu
,具体CUDA kernel 实现思路和前向第一步类似,这里就不展开了,有兴趣可以去阅读源码。
由于分布式数据并行下,权值的梯度会自动做全局同步,所以 SyncBN 就不需要管权值梯度的跨 GPU 的同步。
而对于sum_dy
和 sum_dy_xmu
,则通过集合通信操作 AllReduce
将所有GPU上的结果累加,使得每个GPU上持有全局累加的结果。
最后每个 GPU 根据上面的计算公式计算本地输入x对应的梯度,但是需要注意的是,由于 sum_dy
和 sum_dy_xmu
是跨 GPU 全局累加的结果,所以上面公式中的 rc=B*H*W
要改为 rc=B*H*W*num_gpu
。该 CUDA kernel 的实现,根据上述公式,也是一个 eltiwse 的操作,细节可以去阅读torch源码。
参考资料
- https://hangzhang.org/PyTorch-Encoding/tutorials/syncbn.html
- https://pytorch.org/docs/stable/generated/torch.nn.SyncBatchNorm.html
- https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance
- https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/cuda/Normalization.cuh
- https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/cuda/Normalization.cu
- https://developer.nvidia.com/blog/using-cuda-warp-level-primitives/
- https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#simt-architecture
- https://people.maths.ox.ac.uk/gilesm/cuda/2019/lecture_04.pdf
- https://mpitutorial.com/tutorials/mpi-scatter-gather-and-allgather/
三、PyTorch~NeRF
笔者通过整理分析了NeRF论文和相关参考代码,将为读者朋友讲述利用PyTorch框架,从0到1简单复现一个NeRF(神经辐射场)的实现细节和过程。
在解释代码之前,首先对NeRF(神经辐射场)的原理与含义进行简单回顾。而NeRF论文中是这样解释NeRF算法流程的:
"我们提出了一个当前最优的方法,应用于复杂场景下合成新视图的任务,具体的实现原理是使用一个稀疏的输入视图集合,然后不断优化底层的连续体素场景函数。我们的算法,使用一个全连接(非卷积)的深度网络,表示一个场景,这个深度网络的输入是一个单独的5D坐标(空间位置(x,y,z)和视图方向(xita,sigma)),其对应的输出则是体素密度和视图关联的辐射向量。我们通过查询沿着相机射线的5D坐标合成新的场景视图,以及通过使用经典的体素渲染技术将输出颜色和密度投射到图像中。因为体素渲染具有天然的可变性,所以优化我们的表示方法所需的唯一输入就是一组已知相机位姿的图像。我们介绍如何高效优化神经辐射场照度,以渲染具有复杂几何形状和外观的逼真新颖视图,并展示了由于之前神经渲染和视图合成工作的结果。"
图1|NeRF实现流程
基于前文的原理,本节开始讲述具体的代码实现。首先,导入算法需要的Python库文件。
import os
from typing import Optional,Tuple,List,Union,Callable
import numpy as np
import torch
from torch import nn
import matplotlib.pyplot as plt
from mpl\_toolkits.mplot3d import axes3d
from tqdm import trange
# 设置GPU还是CPU设备
device = torch.device\('cuda' if torch.cuda.is\_available\(\) else 'cpu'\)
1 输入
根据相关论文中的介绍可知,NeRF的输入是一个包含空间位置坐标与视图方向的5D坐标。然而,在PyTorch构建NeRF过程中使用的数据集只是一般的3D到2D图像数据集,包含拍摄相机的内参:位姿和焦距。因此在后面的操作中,我们会把输入数据集转为算法模型需要的输入形式。
在这一流程中使用乐高推土机图像作为简单NeRF算法的数据集,如图2所示:(具体的数据链接请在文末查看)
图2|乐高推土机数据集
这项工作中使用的小型乐高数据集由 106 幅乐高推土机的图像组成,并配有位姿数据和常用焦距数值。与其他数据集一样,这里保留前 100 张图像用于训练,并保留一张测试图像用于验证,具体的加载数据操作如下:
data = np.load\('tiny\_nerf\_data.npz'\) # 加载数据集
images = data\['images'\] # 图像数据
poses = data\['poses'\] # 位姿数据
focal = data\['focal'\] # 焦距数值
print\(f'Images shape: \{images.shape\}'\)
print\(f'Poses shape: \{poses.shape\}'\)
print\(f'Focal length: \{focal\}'\)
height, width = images.shape\[1:3\]
near, far = 2., 6.
n\_training = 100 # 训练数据数量
testimg\_idx = 101 # 测试数据下标
testimg, testpose = images\[testimg\_idx\], poses\[testimg\_idx\]
plt.imshow\(testimg\)
print\('Pose'\)
print\(testpose\)
2 数据处理
一般而言,为了收集这些特点输入数据,算法中需要对输入图像进行反渲染操作。具体来讲就是通过每个像素点在三维空间中绘制投影线,并从中提取样本。
要从图像以外的三维空间采样输入数据点,首先就得从乐高照片集中获取每台相机的初始位姿,然后通过一些矢量数学运算,将这些4x4姿态矩阵转换成「表示原点的三维坐标和表示方向的三维矢量」------这两类信息最终会结合起来描述一个矢量,该矢量用以表征拍摄照片时相机的指向。
下列代码则正是通过绘制箭头来描述这一操作,箭头表示每一帧图像的原点和方向:
# 方向数据
dirs = np.stack\(\[np.sum\(\[0, 0, -1\] \* pose\[:3, :3\], axis=-1\) for pose in poses\]\)
# 原点数据
origins = poses\[:, :3, -1\]
# 绘图的设置
ax = plt.figure\(figsize=\(12, 8\)\).add\_subplot\(projection='3d'\)
\_ = ax.quiver\(
origins\[..., 0\].flatten\(\),
origins\[..., 1\].flatten\(\),
origins\[..., 2\].flatten\(\),
dirs\[..., 0\].flatten\(\),
dirs\[..., 1\].flatten\(\),
dirs\[..., 2\].flatten\(\), length=0.5, normalize=True\)
ax.set\_xlabel\('X'\)
ax.set\_ylabel\('Y'\)
ax.set\_zlabel\('z'\)
plt.show\(\)
最终绘制出来的箭头结果如下图所示:
图3|采样点相机拍摄指向
当有了这些相机位姿数据之后,我们就可以沿着图像的每个像素找到投影线,而每条投影线都是由其原点(x,y,z)和方向联合定义。其中每个像素的原点可能相同,但方向一般是不同的。这些方向射线都略微偏离中心,因此不会存在两条平行方向线,如下图所示:
图4|相机内参示意图
根据图4所述的原理,我们就可以确定每条射线的方向和原点,相关代码如下:
def get\_rays\(
height: int, # 图像高度
width: int, # 图像宽带
focal\_length: float, # 焦距
c2w: torch.Tensor
\) -> Tuple\[torch.Tensor, torch.Tensor\]:
"""
通过每个像素和相机原点,找到射线的原点和方向。
"""
# 应用针孔相机模型收集每个像素的方向
i, j = torch.meshgrid\(
torch.arange\(width, dtype=torch.float32\).to\(c2w\),
torch.arange\(height, dtype=torch.float32\).to\(c2w\),
indexing='ij'\)
i, j = i.transpose\(-1, -2\), j.transpose\(-1, -2\)
# 方向数据
directions = torch.stack\(\[\(i - width \* .5\) / focal\_length,
-\(j - height \* .5\) / focal\_length,
-torch.ones\_like\(i\)
\], dim=-1\)
# 用相机位姿求出方向
rays\_d = torch.sum\(directions\[..., None, :\] \* c2w\[:3, :3\], dim=-1\)
# 默认所有射线原点相同
rays\_o = c2w\[:3, -1\].expand\(rays\_d.shape\)
return rays\_o, rays\_d
得到每个像素对应的射线的方向数据和原点数据之后,就能够获得了NeRF算法中需要的五维数据输入,下面将这些数据调整为算法输入的格式:
# 转为PyTorch的tensor
images = torch.from\_numpy\(data\['images'\]\[:n\_training\]\).to\(device\)
poses = torch.from\_numpy\(data\['poses'\]\).to\(device\)
focal = torch.from\_numpy\(data\['focal'\]\).to\(device\)
testimg = torch.from\_numpy\(data\['images'\]\[testimg\_idx\]\).to\(device\)
testpose = torch.from\_numpy\(data\['poses'\]\[testimg\_idx\]\).to\(device\)
# 针对每个图像获取射线
height, width = images.shape\[1:3\]
with torch.no\_grad\(\):
ray\_origin, ray\_direction = get\_rays\(height, width, focal, testpose\)
print\('Ray Origin'\)
print\(ray\_origin.shape\)
print\(ray\_origin\[height // 2, width // 2, :\]\)
print\(''\)
print\('Ray Direction'\)
print\(ray\_direction.shape\)
print\(ray\_direction\[height // 2, width // 2, :\]\)
print\(''\)
2.1 分层采样
当算法输入模块有了NeRF算法需要的输入数据,也就是包含原点和方向向量组合的线条时,就可以在线条上进行采样。这一过程是采用从粗到细的采样策略,即分层采样策略。
具体来说,分层采样就是将光线分成均匀分布的小块,接着在每个小块内随机抽样。其中扰动的设置决定了是均匀取样的,还是直接简单使用分区中心作为采样点。具体操作代码如下所示:
# 采样函数定义
def sample\_stratified\(
rays\_o: torch.Tensor, # 射线原点
rays\_d: torch.Tensor, # 射线方向
near: float,
far: float,
n\_samples: int, # 采样数量
perturb: Optional\[bool\] = True, # 扰动设置
inverse\_depth: bool = False # 反向深度
\) -> Tuple\[torch.Tensor, torch.Tensor\]:
"""
从规则的bin中沿着射线进行采样。
"""
# 沿着射线抓取采样点
t\_vals = torch.linspace\(0., 1., n\_samples, device=rays\_o.device\)
if not inverse\_depth:
# 由远到近线性采样
z\_vals = near \* \(1.-t\_vals\) + far \* \(t\_vals\)
else:
# 在反向深度中线性采样
z\_vals = 1./\(1./near \* \(1.-t\_vals\) + 1./far \* \(t\_vals\)\)
# 沿着射线从bins中统一采样
if perturb:
mids = .5 \* \(z\_vals\[1:\] + z\_vals\[:-1\]\)
upper = torch.concat\(\[mids, z\_vals\[-1:\]\], dim=-1\)
lower = torch.concat\(\[z\_vals\[:1\], mids\], dim=-1\)
t\_rand = torch.rand\(\[n\_samples\], device=z\_vals.device\)
z\_vals = lower + \(upper - lower\) \* t\_rand
z\_vals = z\_vals.expand\(list\(rays\_o.shape\[:-1\]\) + \[n\_samples\]\)
# 应用相应的缩放参数
pts = rays\_o\[..., None, :\] + rays\_d\[..., None, :\] \* z\_vals\[..., :, None\]
return pts, z\_vals
接着就到了对这些采样点做可视化分析的步骤。如图5中所述,未受扰动的蓝 色点是bin的"中心",而红点对应扰动点的采样。请注意,红点与上方的蓝点略有偏移,但所有点都在远近采样设定值之间。具体代码如下:
y\_vals = torch.zeros\_like\(z\_vals\)
# 调用采样策略函数
\_, z\_vals\_unperturbed = sample\_stratified\(rays\_o, rays\_d, near, far, n\_samples,
perturb=False, inverse\_depth=inverse\_depth\)
# 绘图相关
plt.plot\(z\_vals\_unperturbed\[0\].cpu\(\).numpy\(\), 1 + y\_vals\[0\].cpu\(\).numpy\(\), 'b-o'\)
plt.plot\(z\_vals\[0\].cpu\(\).numpy\(\), y\_vals\[0\].cpu\(\).numpy\(\), 'r-o'\)
plt.ylim\(\[-1, 2\]\)
plt.title\('Stratified Sampling \(blue\) with Perturbation \(red\)'\)
ax = plt.gca\(\)
ax.axes.yaxis.set\_visible\(False\)
plt.grid\(True\)
图5|采样结果示意图
3 位置编码
与Transformer一样,NeRF也使用了位置编码器。因此NeRF就需要借助位置编码器将输入映射到更高的频率空间,以弥补神经网络在学习低频函数时的偏差。
这一环节将会为位置编码器建立一个简单的 torch.nn.Module 模块,相同的编码器可同时用于对输入样本和视图方向的编码操作。注意,这些输入被指定了不同的参数。代码如下所示:
# 位置编码类
class PositionalEncoder\(nn.Module\):
"""
对输入点,做sine或者consine位置编码。
"""
def \_\_init\_\_\(
self,
d\_input: int,
n\_freqs: int,
log\_space: bool = False
\):
super\(\).\_\_init\_\_\(\)
self.d\_input = d\_input
self.n\_freqs = n\_freqs
self.log\_space = log\_space
self.d\_output = d\_input \* \(1 + 2 \* self.n\_freqs\)
self.embed\_fns = \[lambda x: x\]
# 定义线性或者log尺度的频率
if self.log\_space:
freq\_bands = 2.\*\*torch.linspace\(0., self.n\_freqs - 1, self.n\_freqs\)
else:
freq\_bands = torch.linspace\(2.\*\*0., 2.\*\*\(self.n\_freqs - 1\), self.n\_freqs\)
# 替换sin和cos
for freq in freq\_bands:
self.embed\_fns.append\(lambda x, freq=freq: torch.sin\(x \* freq\)\)
self.embed\_fns.append\(lambda x, freq=freq: torch.cos\(x \* freq\)\)
def forward\(
self,
x
\) -> torch.Tensor:
"""
实际使用位置编码的函数。
"""
return torch.concat\(\[fn\(x\) for fn in self.embed\_fns\], dim=-1\)
4 NeRF模型
在此,定义一个NeRF 模型------主要由线性层模块列表构成,而列表中进一步包含非线性激活函数和残差连接。该模型有一个可选的视图方向输入,如果在实例化时提供具体的方向信息,那么会改变模型结构。
(本实现基于原始论文NeRF:Representing Scenes as Neural Radiance Fields for View Synthesis 的第3节,并使用相同的默认设置)
具体代码如下所示:
# 定义NeRF模型
class NeRF\(nn.Module\):
"""
神经辐射场模块。
"""
def \_\_init\_\_\(
self,
d\_input: int = 3,
n\_layers: int = 8,
d\_filter: int = 256,
skip: Tuple\[int\] = \(4,\),
d\_viewdirs: Optional\[int\] = None
\):
super\(\).\_\_init\_\_\(\)
self.d\_input = d\_input # 输入
self.skip = skip # 残差连接
self.act = nn.functional.relu # 激活函数
self.d\_viewdirs = d\_viewdirs # 视图方向
# 创建模型的层结构
self.layers = nn.ModuleList\(
\[nn.Linear\(self.d\_input, d\_filter\)\] +
\[nn.Linear\(d\_filter + self.d\_input, d\_filter\) if i in skip \\
else nn.Linear\(d\_filter, d\_filter\) for i in range\(n\_layers - 1\)\]
\)
# Bottleneck 层
if self.d\_viewdirs is not None:
# 如果使用视图方向,分离alpha和RGB
self.alpha\_out = nn.Linear\(d\_filter, 1\)
self.rgb\_filters = nn.Linear\(d\_filter, d\_filter\)
self.branch = nn.Linear\(d\_filter + self.d\_viewdirs, d\_filter // 2\)
self.output = nn.Linear\(d\_filter // 2, 3\)
else:
# 如果不使用试图方向,则简单输出
self.output = nn.Linear\(d\_filter, 4\)
def forward\(
self,
x: torch.Tensor,
viewdirs: Optional\[torch.Tensor\] = None
\) -> torch.Tensor:
r"""
带有视图方向的前向传播
"""
# 判断是否设置视图方向
if self.d\_viewdirs is None and viewdirs is not None:
raise ValueError\('Cannot input x\_direction if d\_viewdirs was not given.'\)
# 运行bottleneck层之前的网络层
x\_input = x
for i, layer in enumerate\(self.layers\):
x = self.act\(layer\(x\)\)
if i in self.skip:
x = torch.cat\(\[x, x\_input\], dim=-1\)
# 运行 bottleneck
if self.d\_viewdirs is not None:
# Split alpha from network output
alpha = self.alpha\_out\(x\)
# 结果传入到rgb过滤器
x = self.rgb\_filters\(x\)
x = torch.concat\(\[x, viewdirs\], dim=-1\)
x = self.act\(self.branch\(x\)\)
x = self.output\(x\)
# 拼接alpha一起作为输出
x = torch.concat\(\[x, alpha\], dim=-1\)
else:
# 不拼接,简单输出
x = self.output\(x\)
return x
5 体积渲染
上面得到NeRF模型的输出结果之后,仍需将NeRF的输出转换成图像。也就是通过渲染模块对每个像素沿光线方向的所有样本进行加权求和,从而得到该像素的估计颜色值,此外每个RGB样本都会根据其Alpha值进行加权。其中Alpha值越高,表明采样区域不透明的可能性越大,因此沿射线方向越远的点越有可能被遮挡,累加乘积可确保更远处的点受到抑制。具体代码如下:
# 体积渲染
def cumprod\_exclusive\(
tensor: torch.Tensor
\) -> torch.Tensor:
"""
\(Courtesy of https://github.com/krrish94/nerf-pytorch\)
和tf.math.cumprod\(..., exclusive=True\)功能类似
参数:
tensor \(torch.Tensor\): Tensor whose cumprod \(cumulative product, see \`torch.cumprod\`\) along dim=-1
is to be computed.
返回值:
cumprod \(torch.Tensor\): cumprod of Tensor along dim=-1, mimiciking the functionality of
tf.math.cumprod\(..., exclusive=True\) \(see \`tf.math.cumprod\` for details\).
"""
# 首先计算规则的cunprod
cumprod = torch.cumprod\(tensor, -1\)
cumprod = torch.roll\(cumprod, 1, -1\)
# 用1替换首个元素
cumprod\[..., 0\] = 1.
return cumprod
# 输出到图像的函数
def raw2outputs\(
raw: torch.Tensor,
z\_vals: torch.Tensor,
rays\_d: torch.Tensor,
raw\_noise\_std: float = 0.0,
white\_bkgd: bool = False
\) -> Tuple\[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor\]:
"""
将NeRF的输出转换为RGB输出。
"""
# 沿着\`z\_vals\`轴元素之间的差值.
dists = z\_vals\[..., 1:\] - z\_vals\[..., :-1\]
dists = torch.cat\(\[dists, 1e10 \* torch.ones\_like\(dists\[..., :1\]\)\], dim=-1\)
# 将每个距离乘以相应方向射线的法线,转换为现实世界中的距离(考虑非单位方向)。
dists = dists \* torch.norm\(rays\_d\[..., None, :\], dim=-1\)
# 为模型预测密度添加噪音。可用于在训练过程中对网络进行正则化(防止出现浮点伪影)。
noise = 0.
if raw\_noise\_std > 0.:
noise = torch.randn\(raw\[..., 3\].shape\) \* raw\_noise\_std
# Predict density of each sample along each ray. Higher values imply
# higher likelihood of being absorbed at this point. \[n\_rays, n\_samples\]
alpha = 1.0 - torch.exp\(-nn.functional.relu\(raw\[..., 3\] + noise\) \* dists\)
# 预测每条射线上每个样本的密度。数值越大,表示该点被吸收的可能性越大。\[n\_ 射线,n\_样本]
weights = alpha \* cumprod\_exclusive\(1. - alpha + 1e-10\)
# 计算RGB图的权重。
rgb = torch.sigmoid\(raw\[..., :3\]\) # \[n\_rays, n\_samples, 3\]
rgb\_map = torch.sum\(weights\[..., None\] \* rgb, dim=-2\) # \[n\_rays, 3\]
# 估计预测距离的深度图。
depth\_map = torch.sum\(weights \* z\_vals, dim=-1\)
# 稀疏图
disp\_map = 1. / torch.max\(1e-10 \* torch.ones\_like\(depth\_map\),
depth\_map / torch.sum\(weights, -1\)\)
# 沿着每条射线加权。
acc\_map = torch.sum\(weights, dim=-1\)
# 要合成到白色背景上,请使用累积的 alpha 贴图。
if white\_bkgd:
rgb\_map = rgb\_map + \(1. - acc\_map\[..., None\]\)
return rgb\_map, depth\_map, acc\_map, weights
6 分层体积采样
事实上,三维空间中的遮挡物非常稀疏,因此大多数点对渲染图像的贡献不大。所以,对积分有贡献的区域进行超采样会有更好的效果。这里,笔者对第一组样本应用基于归一化的权重来创建整个光线的概率密度函数,然后对该密度函数应用反变换采样来收集第二组样本。具体代码如下:
# 采样概率密度函数
def sample\_pdf\(
bins: torch.Tensor,
weights: torch.Tensor,
n\_samples: int,
perturb: bool = False
\) -> torch.Tensor:
"""
应用反向转换采样到一组加权点。
"""
# 正则化权重得到概率密度函数。
pdf = \(weights + 1e-5\) / torch.sum\(weights + 1e-5, -1, keepdims=True\) # \[n\_rays, weights.shape\[-1\]\]
# 将概率密度函数转为累计分布函数。
cdf = torch.cumsum\(pdf, dim=-1\) # \[n\_rays, weights.shape\[-1\]\]
cdf = torch.concat\(\[torch.zeros\_like\(cdf\[..., :1\]\), cdf\], dim=-1\) # \[n\_rays, weights.shape\[-1\] + 1\]
# 从累计分布函数中提取样本位置。perturb == 0 时为线性。
if not perturb:
u = torch.linspace\(0., 1., n\_samples, device=cdf.device\)
u = u.expand\(list\(cdf.shape\[:-1\]\) + \[n\_samples\]\) # \[n\_rays, n\_samples\]
else:
u = torch.rand\(list\(cdf.shape\[:-1\]\) + \[n\_samples\], device=cdf.device\) # \[n\_rays, n\_samples\]
# 沿累计分布函数找出 u 值所在的索引。
u = u.contiguous\(\) # 返回具有相同值的连续张量。
inds = torch.searchsorted\(cdf, u, right=True\) # \[n\_rays, n\_samples\]
# 夹住超出范围的索引。
below = torch.clamp\(inds - 1, min=0\)
above = torch.clamp\(inds, max=cdf.shape\[-1\] - 1\)
inds\_g = torch.stack\(\[below, above\], dim=-1\) # \[n\_rays, n\_samples, 2\]
# 从累计分布函数和相应的 bin 中心取样。
matched\_shape = list\(inds\_g.shape\[:-1\]\) + \[cdf.shape\[-1\]\]
cdf\_g = torch.gather\(cdf.unsqueeze\(-2\).expand\(matched\_shape\), dim=-1,
index=inds\_g\)
bins\_g = torch.gather\(bins.unsqueeze\(-2\).expand\(matched\_shape\), dim=-1,
index=inds\_g\)
# 将样本转换为射线长度。
denom = \(cdf\_g\[..., 1\] - cdf\_g\[..., 0\]\)
denom = torch.where\(denom \< 1e-5, torch.ones\_like\(denom\), denom\)
t = \(u - cdf\_g\[..., 0\]\) / denom
samples = bins\_g\[..., 0\] + t \* \(bins\_g\[..., 1\] - bins\_g\[..., 0\]\)
return samples # \[n\_rays, n\_samples\]
7 整体的前向传播流程
此时应将上面所有内容整合在一起,通过模型计算一次前向传递。
由于潜在的内存问题,前向传递以"块"为单位进行计算,然后汇总到一个批次中。梯度传播是在整个批次处理完毕后进行的,因此有"块"和"批次"之分。对于内存紧张环境来说,分块处理尤为重要,因为该环境下提供的资源比原始论文中引用的资源更为有限。具体代码如下所示:
def get\_chunks\(
inputs: torch.Tensor,
chunksize: int = 2\*\*15
\) -> List\[torch.Tensor\]:
"""
输入分块。
"""
return \[inputs\[i:i + chunksize\] for i in range\(0, inputs.shape\[0\], chunksize\)\]
def prepare\_chunks\(
points: torch.Tensor,
encoding\_function: Callable\[\[torch.Tensor\], torch.Tensor\],
chunksize: int = 2\*\*15
\) -> List\[torch.Tensor\]:
"""
对点进行编码和分块,为 NeRF 模型做好准备。
"""
points = points.reshape\(\(-1, 3\)\)
points = encoding\_function\(points\)
points = get\_chunks\(points, chunksize=chunksize\)
return points
def prepare\_viewdirs\_chunks\(
points: torch.Tensor,
rays\_d: torch.Tensor,
encoding\_function: Callable\[\[torch.Tensor\], torch.Tensor\],
chunksize: int = 2\*\*15
\) -> List\[torch.Tensor\]:
r"""
对视图方向进行编码和分块,为 NeRF 模型做好准备。
"""
viewdirs = rays\_d / torch.norm\(rays\_d, dim=-1, keepdim=True\)
viewdirs = viewdirs\[:, None, ...\].expand\(points.shape\).reshape\(\(-1, 3\)\)
viewdirs = encoding\_function\(viewdirs\)
viewdirs = get\_chunks\(viewdirs, chunksize=chunksize\)
return viewdirs
def nerf\_forward\(
rays\_o: torch.Tensor,
rays\_d: torch.Tensor,
near: float,
far: float,
encoding\_fn: Callable\[\[torch.Tensor\], torch.Tensor\],
coarse\_model: nn.Module,
kwargs\_sample\_stratified: dict = None,
n\_samples\_hierarchical: int = 0,
kwargs\_sample\_hierarchical: dict = None,
fine\_model = None,
viewdirs\_encoding\_fn: Optional\[Callable\[\[torch.Tensor\], torch.Tensor\]\] = None,
chunksize: int = 2\*\*15
\) -> Tuple\[torch.Tensor, torch.Tensor, torch.Tensor, dict\]:
"""
计算一次前向传播
"""
# 设置参数
if kwargs\_sample\_stratified is None:
kwargs\_sample\_stratified = \{\}
if kwargs\_sample\_hierarchical is None:
kwargs\_sample\_hierarchical = \{\}
# 沿着每条射线的样本查询点。
query\_points, z\_vals = sample\_stratified\(
rays\_o, rays\_d, near, far, \*\*kwargs\_sample\_stratified\)
# 准备批次。
batches = prepare\_chunks\(query\_points, encoding\_fn, chunksize=chunksize\)
if viewdirs\_encoding\_fn is not None:
batches\_viewdirs = prepare\_viewdirs\_chunks\(query\_points, rays\_d,
viewdirs\_encoding\_fn,
chunksize=chunksize\)
else:
batches\_viewdirs = \[None\] \* len\(batches\)
# 稀疏模型流程。
predictions = \[\]
for batch, batch\_viewdirs in zip\(batches, batches\_viewdirs\):
predictions.append\(coarse\_model\(batch, viewdirs=batch\_viewdirs\)\)
raw = torch.cat\(predictions, dim=0\)
raw = raw.reshape\(list\(query\_points.shape\[:2\]\) + \[raw.shape\[-1\]\]\)
# 执行可微分体积渲染,重新合成 RGB 图像。
rgb\_map, depth\_map, acc\_map, weights = raw2outputs\(raw, z\_vals, rays\_d\)
outputs = \{
'z\_vals\_stratified': z\_vals
\}
if n\_samples\_hierarchical > 0:
# Save previous outputs to return.
rgb\_map\_0, depth\_map\_0, acc\_map\_0 = rgb\_map, depth\_map, acc\_map
# 对精细查询点进行分层抽样。
query\_points, z\_vals\_combined, z\_hierarch = sample\_hierarchical\(
rays\_o, rays\_d, z\_vals, weights, n\_samples\_hierarchical,
\*\*kwargs\_sample\_hierarchical\)
# 像以前一样准备输入。
batches = prepare\_chunks\(query\_points, encoding\_fn, chunksize=chunksize\)
if viewdirs\_encoding\_fn is not None:
batches\_viewdirs = prepare\_viewdirs\_chunks\(query\_points, rays\_d,
viewdirs\_encoding\_fn,
chunksize=chunksize\)
else:
batches\_viewdirs = \[None\] \* len\(batches\)
# 通过精细模型向前传递新样本。
fine\_model = fine\_model if fine\_model is not None else coarse\_model
predictions = \[\]
for batch, batch\_viewdirs in zip\(batches, batches\_viewdirs\):
predictions.append\(fine\_model\(batch, viewdirs=batch\_viewdirs\)\)
raw = torch.cat\(predictions, dim=0\)
raw = raw.reshape\(list\(query\_points.shape\[:2\]\) + \[raw.shape\[-1\]\]\)
# 执行可微分体积渲染,重新合成 RGB 图像。
rgb\_map, depth\_map, acc\_map, weights = raw2outputs\(raw, z\_vals\_combined, rays\_d\)
# 存储输出
outputs\['z\_vals\_hierarchical'\] = z\_hierarch
outputs\['rgb\_map\_0'\] = rgb\_map\_0
outputs\['depth\_map\_0'\] = depth\_map\_0
outputs\['acc\_map\_0'\] = acc\_map\_0
# 存储输出
outputs\['rgb\_map'\] = rgb\_map
outputs\['depth\_map'\] = depth\_map
outputs\['acc\_map'\] = acc\_map
outputs\['weights'\] = weights
return outputs
到这一步骤,就几乎拥有了训练模型所需的一切模块。现在为一个简单的训练过程做一些设置,创建超参数和辅助函数,然后来训练模型。
7.1 超参数
所有用于训练的超参数都在此设置,默认值取自原始论文中数据,除非计算上有限制。在计算受限情况下,本次讨论采用的都是合理的默认值。
# 编码器
d\_input = 3 # 输入维度
n\_freqs = 10 # 输入到编码函数中的样本点数量
log\_space = True # 如果设置,频率按对数空间缩放
use\_viewdirs = True # 如果设置,则使用视图方向作为输入
n\_freqs\_views = 4 # 视图编码功能的数量
# 采样策略
n\_samples = 64 # 每条射线的空间样本数
perturb = True # 如果设置,则对采样位置应用噪声
inverse\_depth = False # 如果设置,则按反深度线性采样点
# 模型
d\_filter = 128 # 线性层滤波器的尺寸
n\_layers = 2 # bottleneck层数量
skip = \[\] # 应用输入残差的层级
use\_fine\_model = True # 如果设置,则创建一个精细模型
d\_filter\_fine = 128 # 精细网络线性层滤波器的尺寸
n\_layers\_fine = 6 # 精细网络瓶颈层数
# 分层采样
n\_samples\_hierarchical = 64 # 每条射线的样本数
perturb\_hierarchical = False # 如果设置,则对采样位置应用噪声
# 优化器
lr = 5e-4 # 学习率
# 训练
n\_iters = 10000
batch\_size = 2\*\*14 # 每个梯度步长的射线数量(2 的幂次)
one\_image\_per\_step = True # 每个梯度步骤一个图像(禁用批处理)
chunksize = 2\*\*14 # 根据需要进行修改,以适应 GPU 内存
center\_crop = True # 裁剪图像的中心部分(每幅图像裁剪一次)
center\_crop\_iters = 50 # 经过这么多epoch后,停止裁剪中心
display\_rate = 25 # 每 X 个epoch显示一次测试输出
# 早停
warmup\_iters = 100 # 热身阶段的迭代次数
warmup\_min\_fitness = 10.0 # 在热身\_iters 处继续训练的最小 PSNR 值
n\_restarts = 10 # 训练停滞时重新开始的次数
# 捆绑了各种函数的参数,以便一次性传递。
kwargs\_sample\_stratified = \{
'n\_samples': n\_samples,
'perturb': perturb,
'inverse\_depth': inverse\_depth
\}
kwargs\_sample\_hierarchical = \{
'perturb': perturb
\}
7.2 训练类和函数
这一环节会创建一些用于训练的辅助函数。NeRF很容易出现局部最小值,在这种情况下,训练很快就会停滞并产生空白输出。必要时,会利用EarlyStopping重新启动训练。
# 绘制采样函数
def plot\_samples\(
z\_vals: torch.Tensor,
z\_hierarch: Optional\[torch.Tensor\] = None,
ax: Optional\[np.ndarray\] = None\):
r"""
绘制分层样本和(可选)分级样本。
"""
y\_vals = 1 + np.zeros\_like\(z\_vals\)
if ax is None:
ax = plt.subplot\(\)
ax.plot\(z\_vals, y\_vals, 'b-o'\)
if z\_hierarch is not None:
y\_hierarch = np.zeros\_like\(z\_hierarch\)
ax.plot\(z\_hierarch, y\_hierarch, 'r-o'\)
ax.set\_ylim\(\[-1, 2\]\)
ax.set\_title\('Stratified Samples \(blue\) and Hierarchical Samples \(red\)'\)
ax.axes.yaxis.set\_visible\(False\)
ax.grid\(True\)
return ax
def crop\_center\(
img: torch.Tensor,
frac: float = 0.5
\) -> torch.Tensor:
r"""
从图像中裁剪中心方形。
"""
h\_offset = round\(img.shape\[0\] \* \(frac / 2\)\)
w\_offset = round\(img.shape\[1\] \* \(frac / 2\)\)
return img\[h\_offset:-h\_offset, w\_offset:-w\_offset\]
class EarlyStopping:
r"""
基于适配标准的早期停止辅助器
"""
def \_\_init\_\_\(
self,
patience: int = 30,
margin: float = 1e-4
\):
self.best\_fitness = 0.0
self.best\_iter = 0
self.margin = margin
self.patience = patience or float\('inf'\) # 在epoch停止提高后等待的停止时间
def \_\_call\_\_\(
self,
iter: int,
fitness: float
\):
r"""
检查是否符合停止标准。
"""
if \(fitness - self.best\_fitness\) > self.margin:
self.best\_iter = iter
self.best\_fitness = fitness
delta = iter - self.best\_iter
stop = delta >= self.patience # 超过耐性则停止训练
return stop
def init\_models\(\):
r"""
为 NeRF 训练初始化模型、编码器和优化器。
"""
# 编码器
encoder = PositionalEncoder\(d\_input, n\_freqs, log\_space=log\_space\)
encode = lambda x: encoder\(x\)
# 视图方向编码
if use\_viewdirs:
encoder\_viewdirs = PositionalEncoder\(d\_input, n\_freqs\_views,
log\_space=log\_space\)
encode\_viewdirs = lambda x: encoder\_viewdirs\(x\)
d\_viewdirs = encoder\_viewdirs.d\_output
else:
encode\_viewdirs = None
d\_viewdirs = None
# 模型
model = NeRF\(encoder.d\_output, n\_layers=n\_layers, d\_filter=d\_filter, skip=skip,
d\_viewdirs=d\_viewdirs\)
model.to\(device\)
model\_params = list\(model.parameters\(\)\)
if use\_fine\_model:
fine\_model = NeRF\(encoder.d\_output, n\_layers=n\_layers, d\_filter=d\_filter, skip=skip,
d\_viewdirs=d\_viewdirs\)
fine\_model.to\(device\)
model\_params = model\_params + list\(fine\_model.parameters\(\)\)
else:
fine\_model = None
# 优化器
optimizer = torch.optim.Adam\(model\_params, lr=lr\)
# 早停
warmup\_stopper = EarlyStopping\(patience=50\)
return model, fine\_model, encode, encode\_viewdirs, optimizer, warmup\_stopper
7.3 训练循环
下面就是具体的训练循环过程函数:
def train\(\):
r"""
启动 NeRF 训练。
"""
# 对所有图像进行射线洗牌。
if not one\_image\_per\_step:
height, width = images.shape\[1:3\]
all\_rays = torch.stack\(\[torch.stack\(get\_rays\(height, width, focal, p\), 0\)
for p in poses\[:n\_training\]\], 0\)
rays\_rgb = torch.cat\(\[all\_rays, images\[:, None\]\], 1\)
rays\_rgb = torch.permute\(rays\_rgb, \[0, 2, 3, 1, 4\]\)
rays\_rgb = rays\_rgb.reshape\(\[-1, 3, 3\]\)
rays\_rgb = rays\_rgb.type\(torch.float32\)
rays\_rgb = rays\_rgb\[torch.randperm\(rays\_rgb.shape\[0\]\)\]
i\_batch = 0
train\_psnrs = \[\]
val\_psnrs = \[\]
iternums = \[\]
for i in trange\(n\_iters\):
model.train\(\)
if one\_image\_per\_step:
# 随机选择一张图片作为目标。
target\_img\_idx = np.random.randint\(images.shape\[0\]\)
target\_img = images\[target\_img\_idx\].to\(device\)
if center\_crop and i \< center\_crop\_iters:
target\_img = crop\_center\(target\_img\)
height, width = target\_img.shape\[:2\]
target\_pose = poses\[target\_img\_idx\].to\(device\)
rays\_o, rays\_d = get\_rays\(height, width, focal, target\_pose\)
rays\_o = rays\_o.reshape\(\[-1, 3\]\)
rays\_d = rays\_d.reshape\(\[-1, 3\]\)
else:
# 在所有图像上随机显示。
batch = rays\_rgb\[i\_batch:i\_batch + batch\_size\]
batch = torch.transpose\(batch, 0, 1\)
rays\_o, rays\_d, target\_img = batch
height, width = target\_img.shape\[:2\]
i\_batch += batch\_size
# 一个epoch后洗牌
if i\_batch >= rays\_rgb.shape\[0\]:
rays\_rgb = rays\_rgb\[torch.randperm\(rays\_rgb.shape\[0\]\)\]
i\_batch = 0
target\_img = target\_img.reshape\(\[-1, 3\]\)
# 运行 TinyNeRF 的一次迭代,得到渲染后的 RGB 图像。
outputs = nerf\_forward\(rays\_o, rays\_d,
near, far, encode, model,
kwargs\_sample\_stratified=kwargs\_sample\_stratified,
n\_samples\_hierarchical=n\_samples\_hierarchical,
kwargs\_sample\_hierarchical=kwargs\_sample\_hierarchical,
fine\_model=fine\_model,
viewdirs\_encoding\_fn=encode\_viewdirs,
chunksize=chunksize\)
# 检查任何数字问题。
for k, v in outputs.items\(\):
if torch.isnan\(v\).any\(\):
print\(f"\! \[Numerical Alert\] \{k\} contains NaN."\)
if torch.isinf\(v\).any\(\):
print\(f"\! \[Numerical Alert\] \{k\} contains Inf."\)
# 反向传播
rgb\_predicted = outputs\['rgb\_map'\]
loss = torch.nn.functional.mse\_loss\(rgb\_predicted, target\_img\)
loss.backward\(\)
optimizer.step\(\)
optimizer.zero\_grad\(\)
psnr = -10. \* torch.log10\(loss\)
train\_psnrs.append\(psnr.item\(\)\)
# 以给定的显示速率评估测试值。
if i \% display\_rate == 0:
model.eval\(\)
height, width = testimg.shape\[:2\]
rays\_o, rays\_d = get\_rays\(height, width, focal, testpose\)
rays\_o = rays\_o.reshape\(\[-1, 3\]\)
rays\_d = rays\_d.reshape\(\[-1, 3\]\)
outputs = nerf\_forward\(rays\_o, rays\_d,
near, far, encode, model,
kwargs\_sample\_stratified=kwargs\_sample\_stratified,
n\_samples\_hierarchical=n\_samples\_hierarchical,
kwargs\_sample\_hierarchical=kwargs\_sample\_hierarchical,
fine\_model=fine\_model,
viewdirs\_encoding\_fn=encode\_viewdirs,
chunksize=chunksize\)
rgb\_predicted = outputs\['rgb\_map'\]
loss = torch.nn.functional.mse\_loss\(rgb\_predicted, testimg.reshape\(-1, 3\)\)
print\("Loss:", loss.item\(\)\)
val\_psnr = -10. \* torch.log10\(loss\)
val\_psnrs.append\(val\_psnr.item\(\)\)
iternums.append\(i\)
# 绘制输出示例
fig, ax = plt.subplots\(1, 4, figsize=\(24,4\), gridspec\_kw=\{'width\_ratios': \[1, 1, 1, 3\]\}\)
ax\[0\].imshow\(rgb\_predicted.reshape\(\[height, width, 3\]\).detach\(\).cpu\(\).numpy\(\)\)
ax\[0\].set\_title\(f'Iteration: \{i\}'\)
ax\[1\].imshow\(testimg.detach\(\).cpu\(\).numpy\(\)\)
ax\[1\].set\_title\(f'Target'\)
ax\[2\].plot\(range\(0, i + 1\), train\_psnrs, 'r'\)
ax\[2\].plot\(iternums, val\_psnrs, 'b'\)
ax\[2\].set\_title\('PSNR \(train=red, val=blue'\)
z\_vals\_strat = outputs\['z\_vals\_stratified'\].view\(\(-1, n\_samples\)\)
z\_sample\_strat = z\_vals\_strat\[z\_vals\_strat.shape\[0\] // 2\].detach\(\).cpu\(\).numpy\(\)
if 'z\_vals\_hierarchical' in outputs:
z\_vals\_hierarch = outputs\['z\_vals\_hierarchical'\].view\(\(-1, n\_samples\_hierarchical\)\)
z\_sample\_hierarch = z\_vals\_hierarch\[z\_vals\_hierarch.shape\[0\] // 2\].detach\(\).cpu\(\).numpy\(\)
else:
z\_sample\_hierarch = None
\_ = plot\_samples\(z\_sample\_strat, z\_sample\_hierarch, ax=ax\[3\]\)
ax\[3\].margins\(0\)
plt.show\(\)
# 检查 PSNR 是否存在问题,如果发现问题,则停止运行。
if i == warmup\_iters - 1:
if val\_psnr \< warmup\_min\_fitness:
print\(f'Val PSNR \{val\_psnr\} below warmup\_min\_fitness \{warmup\_min\_fitness\}. Stopping...'\)
return False, train\_psnrs, val\_psnrs
elif i \< warmup\_iters:
if warmup\_stopper is not None and warmup\_stopper\(i, psnr\):
print\(f'Train PSNR flatlined at \{psnr\} for \{warmup\_stopper.patience\} iters. Stopping...'\)
return False, train\_psnrs, val\_psnrs
return True, train\_psnrs, val\_psnrs
最终的结果如下图所示:
6|运行结果示意图
引用:
[1]https://www.matthewtancik.com/nerf
[2]http://cseweb.ucsd.edu/\~viscomp/projects/LF/papers/ECCV20/nerf/tiny_nerf_data.npz
[3]https://towardsdatascience.com/its-nerf-from-nothing-build-a-vanilla-nerf-with-pytorch-7846e4c45666
[4]https://medium.com/@rparikshat1998/nerf-from-scratch-fe21c08b145d
四、Pytorch~SimCLR
使用Pytorch实现对比学习SimCLR 进行自监督预训练
这里将深入研究 SimCLR 框架并探索该算法的关键组件,包括数据增强、对比损失函数以及编码器和投影的head 架构。
SimCLR(Simple Framework for Contrastive Learning of Representations)是一种学习图像表示的自监督技术。与传统的监督学习方法不同,SimCLR 不依赖标记数据来学习有用的表示。它利用对比学习框架来学习一组有用的特征,这些特征可以从未标记的图像中捕获高级语义信息。
SimCLR 已被证明在各种图像分类基准上优于最先进的无监督学习方法。并且它学习到的表示可以很容易地转移到下游任务,例如对象检测、语义分割和小样本学习,只需在较小的标记数据集上进行最少的微调。
SimCLR 主要思想是通过增强模块 T 将图像与同一图像的其他增强版本进行对比,从而学习图像的良好表示。这是通过通过编码器网络 f(.) 映射图像,然后进行投影来完成的。head g(.) 将学习到的特征映射到低维空间。然后在同一图像的两个增强版本的表示之间计算对比损失,以鼓励对同一图像的相似表示和对不同图像的不同表示。
我们这里使用来自 Kaggle 的垃圾分类数据集来进行实验。
增强模块
SimCLR 中最重要的就是转换图像的增强模块。SimCLR 论文的作者建议,强大的数据增强对于无监督学习很有用。因此,我们将遵循论文中推荐的方法。
- 调整大小的随机裁剪
- 50% 概率的随机水平翻转
- 随机颜色失真(颜色抖动概率为 80%,颜色下降概率为 20%)
-
50% 概率为随机高斯模糊
def get_complete_transform(output_shape, kernel_size, s=1.0):
"""
Color distortion transformArgs: s: Strength parameter Returns: A color distortion transform """ rnd_crop = RandomResizedCrop(output_shape) rnd_flip = RandomHorizontalFlip(p=0.5) color_jitter = ColorJitter(0.8*s, 0.8*s, 0.8*s, 0.2*s) rnd_color_jitter = RandomApply([color_jitter], p=0.8) rnd_gray = RandomGrayscale(p=0.2) gaussian_blur = GaussianBlur(kernel_size=kernel_size) rnd_gaussian_blur = RandomApply([gaussian_blur], p=0.5) to_tensor = ToTensor() image_transform = Compose([ to_tensor, rnd_crop, rnd_flip, rnd_color_jitter, rnd_gray, rnd_gaussian_blur, ]) return image_transform
class ContrastiveLearningViewGenerator(object):
"""
Take 2 random crops of 1 image as the query and key.
"""
def init(self, base_transform, n_views=2):
self.base_transform = base_transform
self.n_views = n_viewsdef __call__(self, x): views = [self.base_transform(x) for i in range(self.n_views)] return views
下一步就是定义一个PyTorch 的 Dataset 。
class CustomDataset(Dataset):
def __init__(self, list_images, transform=None):
"""
Args:
list_images (list): List of all the images
transform (callable, optional): Optional transform to be applied on a sample.
"""
self.list_images = list_images
self.transform = transform
def __len__(self):
return len(self.list_images)
def __getitem__(self, idx):
if torch.is_tensor(idx):
idx = idx.tolist()
img_name = self.list_images[idx]
image = io.imread(img_name)
if self.transform:
image = self.transform(image)
return image
作为样例,我们使用比较小的模型 ResNet18 作为主干,所以他的输入是 224x224 图像,我们按照要求设置一些参数并生成dataloader
out_shape = [224, 224]
kernel_size = [21, 21] # 10% of out_shape
# Custom transform
base_transforms = get_complete_transform(output_shape=out_shape, kernel_size=kernel_size, s=1.0)
custom_transform = ContrastiveLearningViewGenerator(base_transform=base_transforms)
garbage_ds = CustomDataset(
list_images=glob.glob("/kaggle/input/garbage-classification/garbage_classification/*/*.jpg"),
transform=custom_transform
)
BATCH_SZ = 128
# Build DataLoader
train_dl = torch.utils.data.DataLoader(
garbage_ds,
batch_size=BATCH_SZ,
shuffle=True,
drop_last=True,
pin_memory=True)
SimCLR
我们已经准备好了数据,开始对模型进行复现。上面的增强模块提供了图像的两个增强视图,它们通过编码器前向传递以获得相应的表示。SimCLR 的目标是通过鼓励模型从两个不同的增强视图中学习对象的一般表示来最大化这些不同学习表示之间的相似性。编码器网络的选择不受限制,可以是任何架构。上面已经说了,为了简单演示,我们使用 ResNet18。编码器模型学习到的表示决定了相似性系数,为了提高这些表示的质量,SimCLR 使用投影头将编码向量投影到更丰富的潜在空间中。这里我们将ResNet18的512维度的特征投影到256的空间中,看着很复杂,其实就是加了一个带relu的mlp。
class Identity(nn.Module):
def __init__(self):
super(Identity, self).__init__()
def forward(self, x):
return x
class SimCLR(nn.Module):
def __init__(self, linear_eval=False):
super().__init__()
self.linear_eval = linear_eval
resnet18 = models.resnet18(pretrained=False)
resnet18.fc = Identity()
self.encoder = resnet18
self.projection = nn.Sequential(
nn.Linear(512, 512),
nn.ReLU(),
nn.Linear(512, 256)
)
def forward(self, x):
if not self.linear_eval:
x = torch.cat(x, dim=0)
encoding = self.encoder(x)
projection = self.projection(encoding)
return projection
对比损失
对比损失函数,也称为归一化温度标度交叉熵损失 (NT-Xent),是 SimCLR 的一个关键组成部分,它鼓励模型学习相同图像的相似表示和不同图像的不同表示。
NT-Xent 损失是使用一对通过编码器网络传递的图像的增强视图来计算的,以获得它们相应的表示。对比损失的目标是鼓励同一图像的两个增强视图的表示相似,同时迫使不同图像的表示不相似。
NT-Xent 将 softmax 函数应用于增强视图表示的成对相似性。softmax 函数应用于小批量内的所有表示对,得到每个图像的相似性概率分布。温度参数temperature 用于在应用 softmax 函数之前缩放成对相似性,这有助于在优化过程中获得更好的梯度。
在获得相似性的概率分布后,通过最大化同一图像的匹配表示的对数似然和最小化不同图像的不匹配表示的对数似然来计算 NT-Xent 损失。
LABELS = torch.cat([torch.arange(BATCH_SZ) for i in range(2)], dim=0)
LABELS = (LABELS.unsqueeze(0) == LABELS.unsqueeze(1)).float() #one-hot representations
LABELS = LABELS.to(DEVICE)
def ntxent_loss(features, temp):
"""
NT-Xent Loss.
Args:
z1: The learned representations from first branch of projection head
z2: The learned representations from second branch of projection head
Returns:
Loss
"""
similarity_matrix = torch.matmul(features, features.T)
mask = torch.eye(LABELS.shape[0], dtype=torch.bool).to(DEVICE)
labels = LABELS[~mask].view(LABELS.shape[0], -1)
similarity_matrix = similarity_matrix[~mask].view(similarity_matrix.shape[0], -1)
positives = similarity_matrix[labels.bool()].view(labels.shape[0], -1)
negatives = similarity_matrix[~labels.bool()].view(similarity_matrix.shape[0], -1)
logits = torch.cat([positives, negatives], dim=1)
labels = torch.zeros(logits.shape[0], dtype=torch.long).to(DEVICE)
logits = logits / temp
return logits, labels
所有的准备都完成了,让我们训练 SimCLR 看看效果!
simclr_model = SimCLR().to(DEVICE)
criterion = nn.CrossEntropyLoss().to(DEVICE)
optimizer = torch.optim.Adam(simclr_model.parameters())
epochs = 10
with tqdm(total=epochs) as pbar:
for epoch in range(epochs):
t0 = time.time()
running_loss = 0.0
for i, views in enumerate(train_dl):
projections = simclr_model([view.to(DEVICE) for view in views])
logits, labels = ntxent_loss(projections, temp=2)
loss = criterion(logits, labels)
optimizer.zero_grad()
loss.backward()
optimizer.step()
# print stats
running_loss += loss.item()
if i%10 == 9: # print every 10 mini-batches
print(f"Epoch: {epoch+1} Batch: {i+1} Loss: {(running_loss/100):.4f}")
running_loss = 0.0
pbar.update(1)
print(f"Time taken: {((time.time()-t0)/60):.3f} mins")
上面代码训练了10轮,假设我们已经完成了预训练过程,可以将预训练的编码器用于我们想要的下游任务。这可以通过下面的代码来完成。
from torchvision.transforms import Resize, CenterCrop
resize = Resize(255)
ccrop = CenterCrop(224)
ttensor = ToTensor()
custom_transform = Compose([
resize,
ccrop,
ttensor,
])
garbage_ds = ImageFolder(
root="/kaggle/input/garbage-classification/garbage_classification/",
transform=custom_transform
)
classes = len(garbage_ds.classes)
BATCH_SZ = 128
train_dl = torch.utils.data.DataLoader(
garbage_ds,
batch_size=BATCH_SZ,
shuffle=True,
drop_last=True,
pin_memory=True,
)
class Identity(nn.Module):
def __init__(self):
super(Identity, self).__init__()
def forward(self, x):
return x
class LinearEvaluation(nn.Module):
def __init__(self, model, classes):
super().__init__()
simclr = model
simclr.linear_eval=True
simclr.projection = Identity()
self.simclr = simclr
for param in self.simclr.parameters():
param.requires_grad = False
self.linear = nn.Linear(512, classes)
def forward(self, x):
encoding = self.simclr(x)
pred = self.linear(encoding)
return pred
eval_model = LinearEvaluation(simclr_model, classes).to(DEVICE)
criterion = nn.CrossEntropyLoss().to(DEVICE)
optimizer = torch.optim.Adam(eval_model.parameters())
preds, labels = [], []
correct, total = 0, 0
with torch.no_grad():
t0 = time.time()
for img, gt in tqdm(train_dl):
image = img.to(DEVICE)
label = gt.to(DEVICE)
pred = eval_model(image)
_, pred = torch.max(pred.data, 1)
total += label.size(0)
correct += (pred == label).float().sum().item()
print(f"Time taken: {((time.time()-t0)/60):.3f} mins")
print(
"Accuracy of the network on the {} Train images: {} %".format(
total, 100 * correct / total
)
)
上面的代码最主要的部分就是读取刚刚训练的simclr模型,然后冻结所有的权重,然后再创建一个分类头self.linear ,进行下游的分类任务
总结
本文介绍了SimCLR框架,并使用它来预训练随机初始化权重的ResNet18。预训练是深度学习中使用的一种强大的技术,用于在大型数据集上训练模型,学习可以转移到其他任务中的有用特征。SimCLR论文认为,批量越大,性能越好。我们的实现只使用128个批大小,只训练10个epoch。所以这不是模型的最佳性能,如果需要性能对比还需要进一步的训练。
下图是论文作者给出的性能结论:
论文地址:https://arxiv.org/abs/2002.05709
五、PyTorch 和LipNet 实现唇语识别
LipReader 的核心是LipNet,这是一种专为唇读任务设计的深度学习模型。下面链接论文首次介绍了LipNet,它是一种处理视频数据并预测字符序列的时空模型。
https://arxiv.org/abs/1611.01599
数据集准备
LipReader 使用视频记录与文本注释配对的数据集。预处理涉及:
-
从每个视频文件中捕获帧(对于第一个发言者,所有视频包含 75 帧)
-
将帧转换为灰度以减少计算量。
-
调整框架大小以仅包含说话者的嘴唇区域。
-
对每帧进行标准化:[frame --- mean(frame)]/ std(frame)
-
拆分注释文件中的字符,同时忽略连续的空格
-
标签编码
-
统一每个批次中序列的长度
def capture_frames(video_path):
capture = cv.VideoCapture(video_path)
frames = []
number_of_frames = int(capture.get(cv.CAP_PROP_FRAME_COUNT))
for _ in range(number_of_frames):
ret, frame = capture.read()
frame = cv.cvtColor(frame, cv.COLOR_BGR2GRAY)
frames.append(frame[190 : 236, 100 : 200])
capture.release()
return np.array(frames)def reduce_frames(frames):
mean = np.mean(frames)
std = np.std(frames)
return torch.from_numpy((frames - mean) / std)def get_video_path(filename) :
filename = filename.decode("utf-8")
return os.path.join(path, f"{filename}.mpg")def load_video(filename):
video_path = get_video_path(filename)
frames = capture_frames(video_path)
return reduce_frames(frames)def get_annotation_path(filename) :
filename = filename.decode("utf-8")
return os.path.join("data", "align\s1", f"{filename}.align")def split_line(annotation_path):
with open(annotation_path, 'r') as f:
lines = f.readlines()
token = []
for line in lines:
words = line.split()
if words[2] != "sil":
token = [*token, " ", words[2]]
return tokendef split_words(line):
tokens = []
for word in line:
for character in word:
tokens.append(character)
return tokensdef load_annot(filename) :
annotation_path = get_annotation_path(filename)
line_split = split_line(annotation_path)
tokens = split_words(line_split)
return encode(tokens)def encode(data):
encoder = LabelEncoder(vocab, reserved_labels=['...'], unknown_index=-1)
return encoder.batch_encode(data)def decode(data):
encoder = LabelEncoder(vocab, reserved_labels=['...'], unknown_index=-1)
return encoder.batch_decode(data)
所用视频是原始论文中使用的视频的子集(出于内存限制,仅使用与第一位发言者相关的视频)。与原始论文不同,没有进行数据增强。数据可在下面链接中找到。
https://drive.google.com/uc?id=1YlvpDLix3S-U8fd-gqRwPcWXAXm8JwjL
模型架构
LipNet 的架构包括:
-
卷积层:这些神经网络层通常用于计算机视觉模型中,使用卷积提取空间特征。LipNet 使用三个时空卷积层,即 3D 卷积层,它们也可以通过时间维度提取特征。
-
GRU 层:门控循环单元 (GRU) 是一种 RNN,它通过添加门和单元来进一步传播信息,从而改进了早期的设计。双向 GRU (Bi-GRU) 双向传播信息流,因此在序列的所有时间步骤中都考虑依赖关系(例如,在我们的示例中,如果模型要预测单词"black":['b', 'l', 'a', 'c', 'k'] ,预测字符"a"取决于前一个时间步骤"b"和"l",也取决于序列的后续步骤"c"、"k")。LipNet 由两个 Bi-GRU 组成。
-
全连接层
-
CTC 损失:联结时间分类 (CTC) 损失在最近的语音识别应用中被广泛使用,因为它消除了训练中数据对齐的需要。在使用 CTC 损失时,必须确保输入序列的长度大于输出序列。
class LipNet(nn.Module):
def __init__(self): super(LipNet, self).__init__() self.conv1 = nn.Conv3d(in_channels= 1, out_channels= 32, kernel_size= (3, 5, 5), bias=False) self.pool1 = nn.MaxPool3d(kernel_size= (1, 2, 2), stride=(1, 2, 2)) self.conv2 = nn.Conv3d(in_channels= 32, out_channels= 64, kernel_size= (3, 5, 5), bias=False) self.pool2 = nn.MaxPool3d(kernel_size= (1, 2, 2), stride=(1, 2, 2)) self.conv3 = nn.Conv3d(in_channels= 64, out_channels= 96, kernel_size= (3, 3, 3), bias=False) self.pool3 = nn.MaxPool3d(kernel_size= (1, 2, 2), stride=(1, 2, 2)) self.gru1 = nn.GRU(input_size = 96 * 3 * 10,hidden_size =256, bidirectinotallow=True, batch_first=True) self.layer_norm1 = nn.LayerNorm(512) self.gru2 = nn.GRU(input_size = 512,hidden_size =256, bidirectinotallow=True, batch_first=True) self.layer_norm2 = nn.LayerNorm(512) self.dense = nn.Linear(in_features=512, out_features=41) def forward(self, x): #Shape of x is B, D, H, W, C x = x.permute(0, 4, 1, 2, 3) x = self.conv1(x) x = F.relu(x) x = self.pool1(x) x = F.relu(x) x = F.relu(self.conv2(x)) x = F.relu(self.pool2(x)) x = self.conv3(x) x = F.relu(x) x = F.relu(self.pool3(x)) b, c, d, h, w = x.size() # Batch, Channels, D, H, W x = x.permute(0, 2, 1, 3, 4) x = x.contiguous().view(b, d, -1) # X shape is (B, D, Channels * H * W) x, hidden = self.gru1(x) x = self.layer_norm1(x) x = F.relu(x) x, hidden = self.gru2(x) x = self.layer_norm2(x) x = F.relu(x) x = self.dense(x) return x
该模型使用 Adam 优化器和学习率调度器进行训练以确保收敛。
def initialize_weights(m):
if isinstance(m, (nn.Conv2d, nn.Linear)):
nn.init.xavier_uniform_(m.weight)
if isinstance(m, nn.GRU):
for name, param in m.named_parameters():
if 'weight' in name:
nn.init.xavier_uniform_(param)
def train_for_one_epoch(model, optimizer, loss_fn):
model.train()
epoch_loss = 0
for i, data in enumerate(training_loader):
frames, labels = data
frames, labels = frames.to(device), labels.to(device)
tokens = model(frames)
B, target_D = labels.size()
B, D, C = tokens.size()
tokens = tokens.permute(1, 0, 2) # shape D, B, C
tokens = F.log_softmax(tokens, dim = 2)
input_lengths = torch.full((B,), D, dtype=torch.long)
target_lengths = torch.full((B,), target_D, dtype=torch.long)
assert(f"input lenghts is greater than the output length {input_lengths > target_lengths}")
loss = loss_fn(tokens, labels, input_lengths, target_lengths)
epoch_loss += loss.item()
optimizer.zero_grad() # Clear accumulated gradients
loss.backward()
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=5.0)
optimizer.step() # Update model weights
return epoch_loss / (i + 1)
def predict(model, x) :
return torch.argmax(model(x), axis = 2)
def validate_model(model, loss_fn):
model.eval()
running_vloss = 0
with torch.no_grad():
for i, vdata in enumerate(val_loader):
vframes, vlabels = vdata
vframes, vlabels = vframes.to(device), vlabels.to(device)
vout = model(vframes)
B, target_D = vlabels.size()
B, D, C = vout.size()
vout = vout.permute(1, 0, 2) # shape D, B, C
vout = F.log_softmax(vout, dim = 2)
vinput_lengths = torch.full((B,), D, dtype=torch.long)
vtarget_lengths = torch.full((B,), target_D, dtype=torch.long)
vloss = loss_fn(vout, vlabels, vinput_lengths, vtarget_lengths).item()
running_vloss += vloss
avg_vloss = running_vloss / (i + 1)
return avg_vloss, vout
def train(model, optimizer, loss_fn, n_epochs, scheduler):
history = {"epochs" : [], "traing_loss": [], "val_loss": []}
best_models = {"epoch" : [], "loss" : [], "predictions" : []}
best_v_loss = 1000
for epoch in range(n_epochs):
avg_loss = train_for_one_epoch(model, optimizer, loss_fn)
avg_vloss, last_pred = validate_model(model, loss_fn)
scheduler.step(avg_vloss)
history["epochs"].append(epoch +1)
history["traing_loss"].append(avg_loss)
history["val_loss"].append(avg_vloss)
print(f"EPOCH {epoch+1} : training loss : {avg_loss}, validation loss : {avg_vloss}")
if avg_vloss < best_v_loss :
best_v_loss = avg_vloss
best_models["epoch"].append(epoch + 1)
best_models["loss"].append(avg_vloss)
best_models["predictions"].append(ctc_decode(torch.argmax(last_pred, axis = 2)))
if (epoch + 1) % 25 ==0:
torch.save(model, f"LipAttention_EPOCH_{epoch+1}.pt")
return history, best_models
经过几个小时的训练,损失在训练和验证阶段都稳步下降,在 100 个 epoch 左右达到最低值。此后,模型开始过拟合。
构建应用程序
LipReader 使用 Streamlit 作为 Web 应用部署。主要功能包括:
- 视频输入选择,允许用户在预定义的数据集上测试模型。
- 模型处理的视频帧的可视化。
- 显示原始和解码的输出,展示应用程序的功能。
该应用程序的界面直观且用户友好,展示了模型的预测和原始注释。
#Import dependencies
import streamlit as st
import numpy as np
import torch
import os
import imageio
from data_preprocessing import load_video, decode, load_annot
from data_preprocessing import ctc_decode
from model import LipNet
from torchnlp.encoders import LabelEncoder
device = torch.device("cpu")
#Layout config
st.set_page_config(layout= "wide")
st.title("Lip Reader App")
if "loss" not in st.session_state:
st.session_state["loss"] = False
#Instantiate the sidebar
with st.sidebar:
st.image("computerVision.jpg")
st.info("""This application is an implemention of the LipNet paper. The goal is to create a computer vision model
(LipNet) that can read lips through videos.
Here is how two versions of the model perform on training and test data. You can find more information in
my article on this project on medium. If you want to check the evolution of the loss through epochs, click
on the button below.
Hope you enjoy it!!""")
def click_loss_button():
st.session_state.loss = True
def unclick_loss_button():
st.session_state.loss = False
col00 , col01 = st.columns(2)
with col00:
loss_butt= st.button("Show loss graph", on_click= click_loss_button)
with col01:
loss_butt= st.button("Hide loss graph", on_click= unclick_loss_button)
if st.session_state.loss :
st.image("../LossGraph.png")
# Save session states
if "split" not in st.session_state:
st.session_state["split"] = "training_files"
if "model" not in st.session_state:
st.session_state["selected_model"] = "LIPNET_100_EPOCHS.pt"
if "raw_button" not in st.session_state:
st.session_state["raw_button"] = False
if "decoded_button" not in st.session_state:
st.session_state["decoded_button"] = False
# Select videos
files = os.listdir(os.path.join("..", "data", "s1"))
split = ["training files", "validation files"]
st.session_state["split"] = st.selectbox("Choose between training and validation split", split)
n = len(files)
train_files = files[: int(0.9 * n)]
val_files = files[int(0.9 * n) : ]
options = {"training files" : train_files, "validation files" : val_files}
selected_option = st.selectbox(f"choose a video from {st.session_state.split}", options[st.session_state.split] )
col1 , col2 = st.columns(2)
if selected_option :
with col1 :
st.text("Selected video")
file_path = os.path.join("..", "data", "s1", selected_option)
os.system(f"ffmpeg -i {file_path} -vcodec libx264 selected_video.mp4 -y")
# Renedering video
video = open("selected_video.mp4", "rb")
video_bytes = video.read()
st.video(video_bytes)
filename = selected_option[ : -4]
labels_path = os.path.join("..","data", "align\s1", f"{filename}.align")
labels = load_annot(labels_path, True)
delim = ""
st.text(f"The true labels are : {delim.join(decode(labels))}")
with col2 :
#select model
st.info("LIPNET_100_EPOCHS performs the best on test data")
models = os.listdir(os.path.join("..", "models"))
st.session_state.selected_model = st.selectbox("Choose a model", models)
st.info("This is what the ML model sees")
video = load_video(file_path, from_path=True)
#To GIF
frames_np = video.numpy()
frames_scaled = ((frames_np - frames_np.min()) / (frames_np.max() - frames_np.min()) * 255).astype(np.uint8)
frames_list = [frame for frame in frames_scaled]
imageio.mimsave("animation.gif", frames_list, fps = 10)
st.image("animation.gif", width = 400)
#Preprocess frames
frames = video.float()
frames = frames.unsqueeze(0)
frames = frames.permute(1, 2, 3, 0)
frames = frames.unsqueeze(0)
#Load model
st.info("This is the output of the model")
model = torch.load(f"../models/{st.session_state.selected_model}", weights_notallow= False).to(device)
y_hat = model(frames)
y_pred = torch.argmax(y_hat, axis = 2)
y_pred = y_pred.squeeze(dim = 0)
#st.text(decode(y_pred))
def click_raw_button():
st.session_state.raw_button = True
def unclick_raw_button():
st.session_state.raw_button = False
col3 , col4 = st.columns(2)
with col3:
raw_butt= st.button("Show raw model outputs", on_click= click_raw_button)
with col4:
raw_butt= st.button("Hide raw model outputs", on_click= unclick_raw_button)
if st.session_state.raw_button:
st.text(y_pred)
#Decode output
vocab = [x for x in "abcdefghijklmnopqrstuvwxyz'?!123456789 "]
encoder = LabelEncoder(vocab, reserved_labels=['...'], unknown_index=-1)
# Inspect the individual tensors being passed
st.info("This is the decoded output ")
y_pred_ctc = ctc_decode(y_pred)
delim = ""
seq = delim.join(decode(y_pred_ctc))
col5 , col6 = st.columns(2)
def click_decoded_button():
st.session_state.decoded_button = True
def unclick_decoded_button():
st.session_state.decoded_button = False
with col5:
decoded_butt = st.button("Show decoded model output", on_click= click_decoded_button)
with col6:
decoded_butt = st.button("Hide decoded model output", on_click= unclick_decoded_button)
if st.session_state.decoded_button :
st.text(seq)
克服挑战
该项目面临多项挑战:
- 数据集大小:处理大型视频数据集需要高效的预处理和数据加载流程。
- 模型优化:确保高精度的同时保持合理的推理速度至关重要。
- 部署:构建响应迅速、可访问的应用程序界面需要仔细集成后端和前端组件。
- CTC 损失:在训练模型时,损失值有些随机,并且通常会很快发散(如果不是在训练开始时,则在第 4 个时期左右)。即使学习率较低且梯度剪裁。我能找到的唯一解决方案(在编程论坛上找到)是将批处理大小设置为 1。
可能的改进
有几种可能的改进方法。首先想到的是:
- 在更大、更多样化的数据集上训练模型。
- 探索基于变压器的架构以获得更好的性能。
- 扩展应用程序以支持多语言唇读。
- 尝试用不同的层(LSTM、注意力模块等)替换 Bi-GRU
代码下载:
https://github.com/Aym98/LipNet
六、Pytorch~einsum
本文将带你感受einsum的"万能",作者通过提供从基础到高级的einsum使用范例,展示了它是怎么做到既简洁又优雅地实现多种张量操作,并轻易解决维度匹配问题。einsum is all you needed!
如果问pytorch中最强大的一个数学函数是什么?
我会说是torch.einsum:爱因斯坦求和函数。它几乎是一个"万能函数":能实现超过一万种功能的函数。
不仅如此,和其它pytorch中的函数一样,torch.einsum是支持求导和反向传播的,并且计算效率非常高。
einsum 提供了一套既简洁又优雅的规则,可实现包括但不限于:内积,外积,矩阵乘法,转置和张量收缩(tensor contraction)等张量操作,熟练掌握 einsum 可以很方便的实现复杂的张量操作,而且不容易出错。
尤其是在一些包括batch维度的高阶张量的相关计算中,若使用普通的矩阵乘法、求和、转置等算子来实现很容易出现维度匹配等问题,但换成einsum则会特别简单。
套用一句深度学习paper标题当中非常时髦的话术,einsum is all you needed !
本文源码路径:
einsum规则原理
顾名思义,einsum这个函数的思想起源于家喻户晓的小爱同学:爱因斯坦。
很久很久以前,小爱同学在捣鼓广义相对论。广义相对论表述各种物理量用的都是张量。比如描述时空有一个四维时空度规张量,描述电磁场有一个电磁张量,描述运动的有能量动量张量。
在理论物理学家中,小爱同学的数学基础不算特别好,在捣鼓这些张量的时候,他遇到了一个比较头疼的问题:公式太长太复杂了。有没有什么办法让这些张量运算公式稍微显得对人类友好一些呢,能不能减少一些那种扭曲的求和符号呢?
小爱发现,求和导致维度收缩,因此求和符号操作的指标总是只出现在公式的一边。例如在我们熟悉的矩阵乘法中
k这个下标被求和了,求和导致了这个维度的消失,所以它只出现在右边而不出现在左边。这种只出现在张量公式的一边的下标被称之为哑指标,反之为自由指标。
小爱同学脑瓜子滴溜一转,反正这种只出现在一边的哑指标一定是被求和求掉的,干脆把对应的求和符号省略得了。
这就是爱因斯坦求和约定:
只出现在公式一边的指标叫做哑指标,针对哑指标的求和符号可以省略。
公式立刻清爽了很多。
公式展现形式中除了省去了求和符号,还省去了乘法符号(代数通识)。
借鉴爱因斯坦求和约定表达张量运算的清爽整洁,numpy、tensorflow和 torch等库中都引入了 einsum这个函数。
上述矩阵乘法可以被einsum这个函数表述成
C = torch.einsum("ik,kj->ij",A,B)
这个函数的规则原理非常简洁,3句话说明白。
- 1,用元素计算公式来表达张量运算。
- 2,只出现在元素计算公式箭头左边的指标叫做哑指标。
- 3,省略元素计算公式中对哑指标的求和符号。
import torch
A = torch.tensor([[1,2],[3,4.0]])
B = torch.tensor([[5,6],[7,8.0]])
C1 = A@B
print(C1)
C2 = torch.einsum("ik,kj->ij",[A,B])
print(C2)
tensor([[19., 22.],
[43., 50.]])
tensor([[19., 22.],
[43., 50.]])
einsum基础范例
einsum这个函数的精髓实际上是第一条:用元素计算公式来表达张量运算。
而绝大部分张量运算都可以用元素计算公式很方便地来表达,这也是它为什么会那么神通广大。
例1,张量转置
#例1,张量转置
A = torch.randn(3,4,5)
#B = torch.permute(A,[0,2,1])
B = torch.einsum("ijk->ikj",A)
print("before:",A.shape)
print("after:",B.shape)
before: torch.Size([3, 4, 5])
after: torch.Size([3, 5, 4])
例2,取对角元
#例2,取对角元
A = torch.randn(5,5)
#B = torch.diagonal(A)
B = torch.einsum("ii->i",A)
print("before:",A.shape)
print("after:",B.shape)
before: torch.Size([5, 5])
after: torch.Size([5])
例3,求和降维
#例3,求和降维
A = torch.randn(4,5)
#B = torch.sum(A,1)
B = torch.einsum("ij->i",A)
print("before:",A.shape)
print("after:",B.shape)
before: torch.Size([4, 5])
after: torch.Size([4])
例4,哈达玛积
#例4,哈达玛积
A = torch.randn(5,5)
B = torch.randn(5,5)
#C=A*B
C = torch.einsum("ij,ij->ij",A,B)
print("before:",A.shape, B.shape)
print("after:",C.shape)
before: torch.Size([5, 5]) torch.Size([5, 5])
after: torch.Size([5, 5])
例5,向量内积
#例5,向量内积
A = torch.randn(10)
B = torch.randn(10)
#C=torch.dot(A,B)
C = torch.einsum("i,i->",A,B)
print("before:",A.shape, B.shape)
print("after:",C.shape)
before: torch.Size([10]) torch.Size([10])
after: torch.Size([])
例6,向量外积
#例6,向量外积
A = torch.randn(10)
B = torch.randn(5)
#C = torch.outer(A,B)
C = torch.einsum("i,j->ij",A,B)
print("before:",A.shape, B.shape)
print("after:",C.shape)
before: torch.Size([10]) torch.Size([5])
after: torch.Size([10, 5])
例7,矩阵乘法
#例7,矩阵乘法
A = torch.randn(5,4)
B = torch.randn(4,6)
#C = torch.matmul(A,B)
C = torch.einsum("ik,kj->ij",A,B)
print("before:",A.shape, B.shape)
print("after:",C.shape)
before: torch.Size([5, 4]) torch.Size([4, 6])
after: torch.Size([5, 6])
例8,张量缩并
#例8,张量缩并
A = torch.randn(3,4,5)
B = torch.randn(4,3,6)
#C = torch.tensordot(A,B,dims=[(0,1),(1,0)])
C = torch.einsum("ijk,jih->kh",A,B)
print("before:",A.shape, B.shape)
print("after:",C.shape)
before: torch.Size([3, 4, 5]) torch.Size([4, 3, 6])
after: torch.Size([5, 6])
einsum高级范例
einsum可用于超过两个张量的计算。
例9,bilinear注意力机制
例如:双线性变换。这是向量内积的一种扩展,一种常用的注意力机制实现方式
不考虑batch维度时,双线性变换的公式如下:
#例9,bilinear注意力机制
#====不考虑batch维度====
q = torch.randn(10) #query_features
k = torch.randn(10) #key_features
W = torch.randn(5,10,10) #out_features,query_features,key_features
b = torch.randn(5) #out_features
#a = q@W@k.t()+b
a = torch.bilinear(q,k,W,b)
print("a.shape:",a.shape)
#=====考虑batch维度====
Q = torch.randn(8,10) #batch_size,query_features
K = torch.randn(8,10) #batch_size,key_features
W = torch.randn(5,10,10) #out_features,query_features,key_features
b = torch.randn(5) #out_features
#A = torch.bilinear(Q,K,W,b)
A = torch.einsum('bq,oqk,bk->bo',Q,W,K) + b
print("A.shape:",A.shape)
a.shape: torch.Size([5])
A.shape: torch.Size([8, 5])
例10,scaled-dot-product注意力机制
我们也可以用einsum来实现更常见的scaled-dot-product 形式的 Attention.
不考虑batch维度时, scaled-dot-product形式的Attention用矩阵乘法公式表示如下:
#例10,scaled-dot-product注意力机制
#====不考虑batch维度====
q = torch.randn(10) #query_features
k = torch.randn(6,10) #key_size, key_features
d_k = k.shape[-1]
a = torch.softmax(q@k.t()/d_k,-1)
print("a.shape=",a.shape )
#====考虑batch维度====
Q = torch.randn(8,10) #batch_size,query_features
K = torch.randn(8,6,10) #batch_size,key_size,key_features
d_k = K.shape[-1]
A = torch.softmax(torch.einsum("in,ijn->ij",Q,K)/d_k,-1)
print("A.shape=",A.shape )
a.shape= torch.Size([6])
A.shape= torch.Size([8, 6])
七、FFN-pytorch
针对视频识别的通用Once-For-All框架
https://arxiv.org/abs/2303.14817
相比于传统视频识别对不同帧数输入的分别训练,我们提供了一种解决方案:在单次训练的情况下,使模型能够在推理的时候根据输入帧数的变化动态调节计算量并表现出更高的准确率,同时显著地减少保存多个模型的参数量。我们提供了一个支持2D, 3D, Transformer网络的视频识别代码库,里面也包含了我们的预训练模型,欢迎大家交流和试用。
■ https://github.com/BeSpontaneous/FFN-pytorch
TL, DR:
**◆ 动机:**视频识别通常会采样多帧图像来代表整个视频。现有的视频识别算法总是对具有不同帧数的输入分别进行训练,这需要重复的训练操作和成倍的存储成本。
**◆ 观察:**如果我们在模型推理的时候使用训练未用到的帧数,模型性能则会显著下降(见下图),这被总结为时域频率偏移现象。
图1 时域频率偏移现象
**◆ 解决:**我们提出了一个通用的框架,名为Frame Flexible Network(FFN),它不仅可以使模型根据输入帧数的不同从而动态地调节计算量,还可以显著减少存储多个模型的内存成本。
◆ 优点:(1)一次性训练(2)明显的性能增益(3)参数量的显著节省(4)计算量的动态调整(5)强大的兼容性。
▌Introduction
越来越多的在线视频推动了视频识别研究的发展。与图像相关的任务不同,我们需要采样多帧的图像来表示整个视频,并且计算成本将与采样的帧数成比例增长。具体而言,现有方法将相同的网络对不同帧的输入分别进行训练,以获得具有不同性能和计算量的多个模型。这给将这些网络应用于边缘设备带来了挑战,因为如果我们存储所有模型,参数将会成倍的增加。
将模型在高帧数进行训练,然后直接在较少帧数的输入上进行推理以调整计算量是一种简单而直接的解决方案。为测试其有效性,我们将其与Separated Training(ST)进行了比较,并从图1中发现推理结果与ST之间存在明显的性能差距,这意味着如果将训练中未曾使用的帧数用于推理,这些方法将表现出明显劣于ST的性能。进一步地,我们在不同深度的深度网络上进行了相同的实验,出现了类似的现象。我们将这种普遍存在的现象称为时域频率偏移。
图2 Nearby Alleviation
为进一步验证时域频率偏移,我们分别在8帧和12帧的输入上训练了两个TSM[1]模型,然后将其在4/8/12/16帧分别进行推理测试。由图2可以看到,推理结果与ST的差异大小会因为帧数不同而有所不同,如果推理帧数接近于训练帧数,性能差距会更小,我们将这个现象称之为Nearby Alleviation。
图3 Normalization Shifting
由于Batch Normalization(BN)会在神经网络的各个层统计feature map的固有属性,如均值,方差,我们分别统计了在4帧和16帧输入上训练的两个TSM模型各个层的feature map的均值和方差。通过图3可知,4帧和16帧输入所对应的feature map的统计值存在一定的差异,也就意味着模型在推理的时候使用其他帧数的输入会导致Normalization Shifting,这是我们认为造成时域频率偏移现象的主要原因。
图4 Frame Flexible Network
▌Method
基于上述观察,我们提出了一个解决时域频率偏移现象的通用框架------Frame Flexible Network(FFN),如图4所示。由于我们的目的是让网络能够在任意帧数推理的时候表现出与Separated Training(ST)相似或者更好的性能,所以我们在训练的时候引入多个视频序列(对同一个视频采样不同帧数得到)并建立对应的子网络,然后在模型推理的时候根据输入帧数的大小来激活对应的子网络,防止Normalization Shifting的产生。
虽然这是一个比较直观的解决方案,但是具体实施上还存在多个问题:
(1) 如何构建子网络?
Slimmable Neural Network[2]曾给出过一种思路,即调整网络的宽度来分配计算资源,同理我们也可以给不同宽度的子网络分配帧数不同的视频序列。这种方法对于卷积结构的网络实现较为容易,但要拓展到Transformer架构的网络可能仍需进一步的适配。考虑到方法的兼容性和拓展性,最后我们没有采用这个方案。另一个直观的思路是直接将原来的网络进行复制来构建子网络,这样可以保证模型的性能与ST相当,但是参数量相比ST保存多个模型并没有减少,也就失去了Once-For-All的意义。
图5 Specific Design
对于任意视频识别网络,我们可以将其拆分为两部分,一部分是用于空间-时序建模的模块,这包括卷积,多头自注意力,多层感知机等,这部分模块占据了网络99%以上的参数;剩下的部分就是各种Normalization的操作,包括BN,LN等等,而这部分的参数量通常小于1%。由之前的分析可知,Normalization Shifting是导致时域频率偏移的主要原因,因此我们可以共享空间-时序建模模块的参数,而对每个子网络分配不同的Normalization操作,这样就可以在解决时域频率偏移问题的同时,维持总网络的参数量跟原来单个网络相当。相比于ST分别训练和保存多个模型,我们也就实现了参数的显著减少。
(2) 如何保证各个子网络的表达能力?
由于我们共享了空间-时序建模模块的参数,随之而来的问题就是如何保证这部分共享参数的表达能力,即在不同帧数输入的情况下均维持良好的性能。受之前自监督训练在视频识别上应用[3][4]的启发,我们认为,一个好的视频表征应该是对帧数变化无关的(Temporal Frequency Invariant),因此我们希望模型能够学到与帧数变化无关的表达(Temporal Frequency Invariant Representation)。如果从自监督的角度来理解,视频的采样过程可以被看作是一种在时序上的随机掩码(数据增强),采样的帧数越多,掩码越少,得到的视频序列越近于视频本身,性能也就越好;而基于对比学习的自监督的一个核心部分就是让模型学习关于各种变换(数据增强)的不变表达(Invariant Representation),即拉近正样本之间的距离,所以我们认为共享参数不仅不会限制模型的表达能力,反而会让其表现出优于ST的性能。
为了进一步增强模型的表达能力,我们提出了Weight Alteration------一个简单的深度可分离卷积层,并将其插入到各个子网络的不同stage当中,这样我们就可以通过一个简单的线性变换使得各个子网络拥有一套属于自己的独特参数。值得注意的是,视频识别经常会用到预训练模型,为了避免新加入的卷积层破坏原有的计算图,我们引入了residual[5]的结构并将卷积层的初始参数设置为较小值。
(3) 如何保证训练中未出现帧数的推理性能?
通过上述的一系列设计,我们基本可以保证模型在L, M和H三个帧数输入的推理性能,但是要如何保证训练中未出现帧数的推理性能仍是一个问题。由Nearby Alleviation可知,如果推理帧数接近于训练帧数,两者的性能差距会比较小,所以我们大概认为在L, M和H周边的输入也能维持较好的性能。基于上述考虑,我们提出了Inference at Any Frame的策略,即给定推理的帧数n,我们会比较n与L, M, H的距离,并选择与其距离最小的子网络进行激活,来得到对应的输出。
▌Experiment
表1 与baseline方法的比较
由表1可知,Mixed Sampling和Proportional Sampling可以在一定程度上缓解时域频率偏移现象,但是同时在16帧推理的准确率也会有一定的损失。Fine-tuning可以让模型在4帧推理时表现出比ST更好的性能,但是在其他帧的准确率会明显下降。Ensemble在8帧推理时有较高的准确率,但代价是三倍的参数量和计算量。相比于上述方法,FFN在任意帧的推理准确率都明显高于ST,并且参数量也是明显少于ST,这让FFN在边缘设备上的部署充满优势。
此外,我们还验证了FFN在不同架构(2D/3D/Transformer),数据集(Sth-Sth V2/Kinetics400/HMDB51)下的泛化能力,结果如图6,图7所示,可以看到FFN具有良好的泛化性和扩展性,并且能够在性能很好的基准模型Uniformer[7]上面仍取得非常显著的提升。
图6 在不同架构上的验证
图7 在不同数据集上的验证
为了验证Inference at Any Frame的策略,我们将FFN在2-20帧的所有偶数帧进行了推理评估,对于4-16帧内部的结果,我们称之为Inbound Results,因为训练所用到的最低帧是4帧,最高帧是16帧;而2/18/20帧的结果则称之为Outbound Results。由图8可知,FFN能够在所有帧的推理结果上全面超越ST,ST需要重复10次训练过程得到10个训练模型,而我们只需训练一次,得到模型的参数量与ST单次训练得到的模型参数量相当。
图8 Inference At Any Frame
最后是消融实验,我们分别对文章中的一系列设计进行了消融分析。比较有意思的一点是,当我们不进行参数共享,而对每个子网络分配其独有的参数时,模型的性能相较于共享参数会有显著的下降,同时也带来了成倍的参数量。这一点其实在之前的章节已经提到,独享的参数会让各个子网络的性能上限接近于ST,而参数共享可以让模型学习到与帧数变化无关的表达(Temporal Frequency Invariant Representation),从而表现出比ST更优的性能。
表2 消融实验
▌Conclusion
这个工作首先切实地解决了视频识别领域中的一个实际问题,对于模型在边缘设备的部署会有比较大的帮助;其次整体的设计比较简洁,FFN也具有较好的泛化性和拓展性。
最后欢迎大家follow我们的工作,我们提供了一个支持2D, 3D, Transformer网络的视频识别代码库,目前FFN支持TSM/TEA/SlowFast/Uniformer,欢迎大家来contribute~
@article{zhang2023frame, title={Frame Flexible Network}, author={Zhang, Yitian and Bai, Yue and Liu, Chang and Wang, Huan and Li, Sheng and Fu, Yun}, journal={arXiv preprint arXiv:2303.14817}, year={2023}
八、Pytorch-lightning
Pytorch-lightning可以非常简洁得构建深度学习代码。但是其实大部分人用不到很多复杂得功能,并且用的时候稍微有一些不灵活。
Pytorch-lightning(以下简称pl)可以非常简洁得构建深度学习代码。但是其实大部分人用不到很多复杂得功能。而pl有时候包装得过于深了,用的时候稍微有一些不灵活。通常来说,在你的模型搭建好之后,大部分的功能都会被封装在一个叫trainer的类里面。一些比较麻烦但是需要的功能通常如下:
- 保存checkpoints
- 输出log信息
- resume training 即重载训练,我们希望可以接着上一次的epoch继续训练
- 记录模型训练的过程(通常使用tensorboard)
- 设置seed,即保证训练过程可以复制
好在这些功能在pl中都已经实现。
由于doc上的很多解释并不是很清楚,而且网上例子也不是特别多。下面分享一点我自己的使用心得。
首先关于设置全局的种子:
from pytorch_lightning import seed_everything
# Set seed
seed = 42
seed_everything(seed)
只需要import如上的seed_everything函数即可。它应该和如下的函数是等价的:
def seed_all(seed_value):
random.seed(seed_value) # Python
np.random.seed(seed_value) # cpu vars
torch.manual_seed(seed_value) # cpu vars
if torch.cuda.is_available():
print ('CUDA is available')
torch.cuda.manual_seed(seed_value)
torch.cuda.manual_seed_all(seed_value) # gpu vars
torch.backends.cudnn.deterministic = True #needed
torch.backends.cudnn.benchmark = False
seed=42
seed_all(seed)
但经过我的测试,好像pl的seed_everything函数应该更全一点。
下面通过一个具体的例子来说明一些使用方法:
先下载、导入必要的包和下载数据集:
!pip install pytorch-lightning
!wget https://download.pytorch.org/tutorial/hymenoptera_data.zip
!unzip -q hymenoptera_data.zip
!rm hymenoptera_data.zip
import pytorch_lightning as pl
import os
import numpy as np
import random
import matplotlib.pyplot as plt
import torch
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
以下代码种加入!的代码是在terminal中运行的。在google colab中运行linux命令需要在之前加!
如果是使用google colab,由于它创建的是一个虚拟机,不能及时保存,所以如果需要保存,挂载自己google云盘也是有必要的。使用如下的代码:
from google.colab import drive
drive.mount('./content/drive')
import os
os.chdir("/content/drive/My Drive/")
先如下定义如下的LightningModule和main函数。
class CoolSystem(pl.LightningModule):
def __init__(self, hparams):
super(CoolSystem, self).__init__()
self.params = hparams
self.data_dir = self.params.data_dir
self.num_classes = self.params.num_classes
########## define the model ##########
arch = torchvision.models.resnet18(pretrained=True)
num_ftrs = arch.fc.in_features
modules = list(arch.children())[:-1] # ResNet18 has 10 children
self.backbone = torch.nn.Sequential(*modules) # [bs, 512, 1, 1]
self.final = torch.nn.Sequential(
torch.nn.Linear(num_ftrs, 128),
torch.nn.ReLU(inplace=True),
torch.nn.Linear(128, self.num_classes),
torch.nn.Softmax(dim=1))
def forward(self, x):
x = self.backbone(x)
x = x.reshape(x.size(0), -1)
x = self.final(x)
return x
def configure_optimizers(self):
# REQUIRED
optimizer = torch.optim.SGD([
{'params': self.backbone.parameters()},
{'params': self.final.parameters(), 'lr': 1e-2}
], lr=1e-3, momentum=0.9)
exp_lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=7, gamma=0.1)
return [optimizer], [exp_lr_scheduler]
def training_step(self, batch, batch_idx):
# REQUIRED
x, y = batch
y_hat = self.forward(x)
loss = F.cross_entropy(y_hat, y)
_, preds = torch.max(y_hat, dim=1)
acc = torch.sum(preds == y.data) / (y.shape[0] * 1.0)
self.log('train_loss', loss)
self.log('train_acc', acc)
return {'loss': loss, 'train_acc': acc}
def validation_step(self, batch, batch_idx):
# OPTIONAL
x, y = batch
y_hat = self.forward(x)
loss = F.cross_entropy(y_hat, y)
_, preds = torch.max(y_hat, 1)
acc = torch.sum(preds == y.data) / (y.shape[0] * 1.0)
self.log('val_loss', loss)
self.log('val_acc', acc)
return {'val_loss': loss, 'val_acc': acc}
def test_step(self, batch, batch_idx):
# OPTIONAL
x, y = batch
y_hat = self.forward(x)
loss = F.cross_entropy(y_hat, y)
_, preds = torch.max(y_hat, 1)
acc = torch.sum(preds == y.data) / (y.shape[0] * 1.0)
return {'test_loss': loss, 'test_acc': acc}
def train_dataloader(self):
# REQUIRED
transform = transforms.Compose([
transforms.RandomResizedCrop(224),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])
train_set = torchvision.datasets.ImageFolder(os.path.join(self.data_dir, 'train'), transform)
train_loader = torch.utils.data.DataLoader(train_set, batch_size=32, shuffle=True, num_workers=4)
return train_loader
def val_dataloader(self):
transform = transforms.Compose([
transforms.Resize(256),
transforms.CenterCrop(224),
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])
val_set = torchvision.datasets.ImageFolder(os.path.join(self.data_dir, 'val'), transform)
val_loader = torch.utils.data.DataLoader(val_set, batch_size=32, shuffle=True, num_workers=4)
return val_loader
def test_dataloader(self):
transform = transforms.Compose([
transforms.Resize(256),
transforms.CenterCrop(224),
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])
val_set = torchvision.datasets.ImageFolder(os.path.join(self.data_dir, 'val'), transform)
val_loader = torch.utils.data.DataLoader(val_set, batch_size=8, shuffle=True, num_workers=4)
return val_loader
def main(hparams):
model = CoolSystem(hparams)
trainer = pl.Trainer(
max_epochs=hparams.epochs,
gpus=1,
accelerator='dp'
)
trainer.fit(model)
下面是run的部分:
from argparse import Namespace
args = {
'num_classes': 2,
'epochs': 5,
'data_dir': "/content/hymenoptera_data",
}
hyperparams = Namespace(**args)
if __name__ == '__main__':
main(hyperparams)
如果希望重载训练的话,可以按如下方式:
# resume training
RESUME = True
if RESUME:
resume_checkpoint_dir = './lightning_logs/version_0/checkpoints/'
checkpoint_path = os.listdir(resume_checkpoint_dir)[0]
resume_checkpoint_path = resume_checkpoint_dir + checkpoint_path
args = {
'num_classes': 2,
'data_dir': "/content/hymenoptera_data"}
hparams = Namespace(**args)
model = CoolSystem(hparams)
trainer = pl.Trainer(gpus=1,
max_epochs=10,
accelerator='dp',
resume_from_checkpoint = resume_checkpoint_path)
trainer.fit(model)
如果我们想要从checkpoint加载模型,并进行使用可以按如下操作来:
import matplotlib.pyplot as plt
import numpy as np
# functions to show an image
def imshow(inp):
inp = inp.numpy().transpose((1, 2, 0))
mean = np.array([0.485, 0.456, 0.406])
std = np.array([0.229, 0.224, 0.225])
inp = std * inp + mean
inp = np.clip(inp, 0, 1)
plt.imshow(inp)
plt.show()
classes = ['ants', 'bees']
checkpoint_dir = 'lightning_logs/version_1/checkpoints/'
checkpoint_path = checkpoint_dir + os.listdir(checkpoint_dir)[0]
checkpoint = torch.load(checkpoint_path)
model_infer = CoolSystem(hparams)
model_infer.load_state_dict(checkpoint['state_dict'])
try_dataloader = model_infer.test_dataloader()
inputs, labels = next(iter(try_dataloader))
# print images and ground truth
imshow(torchvision.utils.make_grid(inputs))
print('GroundTruth: ', ' '.join('%5s' % classes[labels[j]] for j in range(8)))
# inference
outputs = model_infer(inputs)
_, preds = torch.max(outputs, dim=1)
# print (preds)
print (torch.sum(preds == labels.data) / (labels.shape[0] * 1.0))
print('Predicted: ', ' '.join('%5s' % classes[preds[j]] for j in range(8)))
预测结果如上。
如果希望检测训练过程(第一部分+重载训练的部分),如下:
# tensorboard
%load_ext tensorboard
%tensorboard --logdir = ./lightning_logs
训练过程在tensorboard里面记录,version0是第一次的训练,version1是重载后的结果。
完整的code在这里.
九、PyTorch~CUTLASS Ping-Pong GEMM Kernel 简介
详解CUTLASS Ping-Pong GEMM Kernel的设计和性能。
博客来源:https://pytorch.org/blog/cutlass-ping-pong-gemm-kernel/ 这里做了个翻译。这篇PyTorch的blog简要介绍了 CUTLASS 中的 Ping-Pong GEMM kernel 设计,它是专门为 Hopper GPU 架构优化的高性能矩阵乘法实现。通过采用生产者-消费者模式的异步流水线设计,结合 TMA 硬件加速和专门化的 warp 组,实现了对 Tensor Core 的高效利用。文章通过benchmark表明,这种设计相比 cuBLAS 和 Triton 等其他实现具有明显优势,充分展现了在新一代 GPU 架构上如何通过深度异步化来最大化计算吞吐量。同时也把这部分cutlass代码单独抽出来写了一个PyTorch可以用的扩展,见 https://github.com/pytorch-labs/applied-ai/tree/main/kernels/cuda/cutlass_gemm 。
深入解析 CUTLASS Ping-Pong GEMM kernel
图1. FP8 GEMM 吞吐量对比:CUTLASS vs Triton
摘要
在这篇文章中,我们将概述 CUTLASS Ping-Pong GEMM kernel ,并提供相关的 FP8 推理kernel 基准测试。
Ping-Pong 是 Hopper GPU 架构上可用的最快矩阵乘法(GEMM)kernel 架构之一。Ping-Pong 属于 Warp Group Specialized Persistent Kernels 家族,该家族包括 Cooperative 和 Ping-Pong 两种变体。相对于之前的 GPU,Hopper 的强大Tensor Core计算能力需要深度异步软件流水线来实现峰值性能。
Ping-Pong 和 Cooperative kernel 很好地展示了这一范式,其关键设计模式是持久化kernel (用于摊销启动和prologue开销),以及"全面异步"的专门化 warp 组(包含两个消费者和一个生产者),以创建高度重叠的处理流水线,能够持续为Tensor Core提供数据。
当 H100 (Hopper) GPU 发布时,Nvidia 将其称为第一款真正的异步 GPU。这一说法突出了 H100 特定kernel 架构也需要异步化,以充分最大化计算/GEMM 吞吐量。
CUTLASS 3.x 中引入的 pingpong GEMM 通过将kernel 的所有方面都移至"完全异步"处理范式来体现这一点。在这篇博客中,我们将展示 ping-pong kernel 设计的核心特性,并展示其在推理工作负载上与 cublas 和 triton split-k kernel 相比的性能。
Ping-Pong kernel 设计
Ping-Pong(或者从技术上说是'sm90_gemm_tma_warpspecialized_pingpong')采用异步流水线运行,利用 warp 专门化。与传统的同质kernel 不同,"warp 组"承担专门的角色。需要注意的是,一个 warp 组由 4 个 warp(每个 warp 32 个线程)组成,总共 128 个线程。
在早期架构中,通常通过在每个 SM 上运行多个线程块来隐藏延迟。然而,在 Hopper 上,Tensor Core吞吐量如此之高,以至于需要转向更深的流水线。这些更深的流水线会阻碍在每个 SM 上运行多个线程块。因此,持久化线程块现在会在多个Tile 和多个 warp 组之间发出集体main loops 。线程块集群根据总 SM 数量进行分配。
对于 Ping-Pong 来说,每个 warp 组都承担数据生产者或数据消费者的专门角色。
生产者 warp 组专注于通过 TMA 产生数据移动来填充共享内存缓冲区。另外两个 warp 组是专门的消费者,它们处理使用Tensor Core的数学(MMA)部分,然后进行任何后续工作并将结果写回全局内存(epilogue)。
生产者 warp 组与 TMA(张量内存加速器)一起工作,并被刻意保持尽可能轻量级。事实上,在 Ping-Pong 中,它们故意减少寄存器资源以提高占用率。生产者将其最大寄存器数减少 40 个,而消费者将其最大寄存器数增加 232 个,这种效果我们可以在 CUTLASS 源代码和相应的 SASS 中看到:
Ping-Pong 的独特之处在于,每个消费者在不同的 C 输出Tile 上工作。(作为参考,cooperative kernel 在很大程度上等同于 Ping-Pong,但两个消费者组在同一个 C 输出Tile 上工作)。此外,两个消费者 warp 组然后在main loops MMA 和 epilogue 之间分配它们的工作。
这在下图中显示:
图2:Ping-Pong kernel 流水线概览。时间从左向右移动。
通过拥有两个消费者,意味着一个可以使用Tensor Core进行 MMA,而另一个执行 epilogue,然后反之亦然。这最大化了每个 SM 上Tensor Core的"连续使用",这是实现最大吞吐量的关键原因之一。Tensor Core可以持续获得数据以实现(接近)最大计算能力。(参见上图 Fig 2 的底部部分)。
与生产者线程仅专注于数据移动类似,MMA 线程仅发出 MMA 指令以实现峰值发出率。MMA 线程必须发出多个 MMA 指令,并使这些指令在 TMA 等待屏障上保持运行。
下面展示了kernel 代码的一个摘录,以巩固专门化方面:
// Two types of warp group 'roles'
enum class WarpGroupRole {
Producer = 0,
Consumer0 = 1,
Consumer1 = 2
};
//warp group role assignment
auto warp_group_role = WarpGroupRole(canonical_warp_group_idx());
使用生产者和张量内存加速器的数据移动
生产者 warp 专注于数据移动 - 具体来说,它们被保持尽可能轻量级,实际上会将一些寄存器空间让给消费者 warp(只保留 40 个寄存器,而消费者将获得 232 个)。它们的主要任务是在共享内存缓冲区被信号标记为空时,立即发出 TMA(张量内存加速器)命令,将数据从全局内存移动到共享内存。
关于 TMA(张量内存加速器)的更多说明,TMA 是随 H100 引入的一个硬件组件,它异步处理从 HBM(全局内存)到共享内存的内存传输。通过拥有专门的硬件单元进行内存移动,工作线程可以从事其他工作,而不是计算和管理数据移动。TMA 不仅处理数据本身的移动,还计算所需的目标内存地址,可以对数据应用任何转换(归约等),并可以处理布局转换,以"交错"模式将数据传递到共享内存,使其可以在没有任何Bank 冲突的情况下使用。最后,如果需要,它还可以将相同的数据多播到属于同一线程集群的其他 SM。一旦数据传递完成,TMA 将向相关的消费者发出信号,表明数据已准备就绪。
CUTLASS 异步流水线类
生产者和消费者之间的这种信号传递通过新的异步流水线类进行协调,CUTLASS 对其描述如下:
"实现持久化 GEMM 算法需要管理数十种不同类型的异步执行操作,这些操作使用组织为循环列表的多个屏障进行同步。
这种复杂性对于人类程序员来说太难手动管理。
因此,我们开发了 CUTLASS Pipeline Async Class..."
Ping-Pong 异步流水线中的屏障和同步
生产者必须通过'producer_acquire'来"获取"给定的共享内存缓冲区。在开始时,流水线是空的,这意味着生产者线程可以立即获取屏障并开始移动数据。
PipelineState mainloop_pipe_producer_state = cutlass::make_producer_start_state<MainloopPipeline>();
一旦数据移动完成,生产者发出'producer_commit'方法来向消费者线程发出数据准备就绪的信号。然而,对于 Ping-Pong 来说,这实际上是一个空操作指令,因为基于 TMA 的生产者屏障在 TMA 完成写入时会自动更新。
consumer_wait - 等待来自生产者线程的数据(阻塞)。
consumer_release - 向等待的生产者线程发出信号,表明它们已完成消费给定共享内存缓冲区中的数据。换句话说,允许生产者开始用新数据重新填充这个缓冲区。
从那里开始,同步将开始认真进行,生产者将通过阻塞的 producer acquire 等待,直到它们可以获取锁,此时它们的数据移动工作将重复。这将持续到工作完成为止。
提供一个伪代码概述:
//producer
While (work_tile_info.is_valid_tile) {
collective_mainloop.dma() // fetch data with TMA
scheduler.advance_to_next_work()
Work_tile_info = scheduler.get_current_work()
}
// Consumer 1, Consumer 2
While (work_tile_info.is_valid_tile()) {
collective_mainloop.mma()
scheduler.advance_to_next_work()
Work_tile_info = scheduler.get_current_work()
}
以及一个将所有内容与底层硬件结合在一起的鸟瞰图:
图3:Ping-Pong 完整异步流水线概览
补充一下图的细节:
- 主要组件:
- SM 内包含 Consumer MMA Warps 和 Producer DMA Warps
- Tensor Core: 执行矩阵乘法运算
- SMEM (共享内存): 带有异步屏障机制
- RMEM (寄存器内存)
- TMA (张量内存加速器)
- GMEM (全局内存)
- 数据流动路径:
- Producer DMA Warps 通过
cp_async_bulk
指令与 TMA 交互 - TMA 负责在 GMEM 和 SMEM 之间传输数据
- Consumer MMA Warps 通过
wgmma.mma_async
指令从 SMEM 读取数据到 Tensor Core - Tensor Core 计算结果写入 RMEM
- 数据可以多播到其他线程块
- 同步机制:
- Producer 和 Consumer 之间通过 Acquire/Commit 和 Wait/Release 操作进行同步
- SMEM 中的异步屏障用于协调数据访问
- TMA 处理异步数据传输
- 关键特点:
- 整个流程是高度异步的
- 使用专门化的 warp 组实现生产者-消费者模式
- 通过 TMA 实现高效的内存传输
- 支持跨线程块的数据多播
Ping-Pong 计算循环的逐步分解
最后,对 Ping-Pong 处理循环的更详细的逻辑分解:
A - 生产者(DMA)warp 组获取共享内存缓冲区的锁。
B - 这允许它通过单个线程向 tma 芯片发起 tma cp_async.bulk 请求。
C - TMA 计算所需的实际共享内存寻址,并将数据移动到共享内存。作为这个过程的一部分,执行交错操作以便在共享内存中布局数据,以实现最快(无Bank 冲突)的访问。
C1 - 可能的情况下,数据也可以多播到其他 SM,和/或它可能需要等待来自其他 tma 多播的数据以完成加载。(线程块集群现在在多个 SM 之间共享共享内存!)
D - 此时,屏障被更新以向共享内存发出数据到达的信号。
E - 相关的消费者 warp 组现在开始工作,发出多个 wgmma.mma_async 命令,这些命令然后从共享内存读取数据到Tensor Core,作为其 wgmma.mma_async 矩阵乘法操作的一部分。
F - 当Tile 完成时,MMA 累加器值被写入寄存器内存。
G - 消费者 warp 组释放共享内存上的屏障。
H - 生产者 warp 组开始工作,发出下一个 tma 指令以重新填充现在空闲的共享内存缓冲区。
I - 消费者 warp 组同时对累加器应用任何 epilogue 操作,然后将数据从寄存器移动到不同的共享内存缓冲区。
J - 消费者 warp 发出 cp_async 命令,将数据从共享内存移动到全局内存。
这个循环重复进行,直到工作完成。希望这能让你对支持 Ping-Pong 令人印象深刻性能的核心概念有一个工作性的理解。
微基准测试
为了展示 Ping-Pong 的一些性能,下面是一些与我们设计快速推理kernel 相关的比较图表。首先是目前三个最快kernel 的一般基准测试(越低越好):
图4:FP8 GEMM 基准测试时间,数值越低越好(越快)
将其转换为 Ping-Pong 与 cuBLAS 和 Triton 的相对加速图:
图5:Ping-Pong 相对于两个最接近kernel 的速度提升
Ping-Pong kernel 的完整源代码在这里(619 行深度模板化的 CUTLASS 代码,或者用著名的乌龟模因来说就是 - "全都是模板...一直都是!"):
此外,我们已经将 PingPong 实现为 CPP 扩展,使其易于与 PyTorch 集成(同时附带一个简单的测试脚本展示其用法):
最后,为了继续学习,Nvidia 有两个深入探讨 CUTLASS kernel 设计的 GTC 视频:
- Developing Optimal CUDA Kernels on Hopper Tensor Cores | GTC Digital Spring 2023 | NVIDIA On-Demand(https://www.nvidia.com/en-us/on-demand/session/gtcspring23-s51413/)
- CUTLASS: A Performant, Flexible, and Portable Way to Target Hopper Tensor Cores | GTC 24 2024 | NVIDIA On-Demand(https://www.nvidia.com/en-us/on-demand/session/gtc24-s61198/)
未来工作
数据移动通常是任何kernel 实现最高性能的最大障碍,因此对 Hopper 上的 TMA(张量内存加速器)有最优策略理解至关重要。我们之前发布了关于 Triton 中 TMA 使用的工作(https://mp.weixin.qq.com/s/cZRoRq_gzAdA2iaMpZ08VA)。一旦在 Triton 中启用了 warp 专门化等功能,我们计划再次深入研究 Triton kernel (如 FP8 GEMM 和 FlashAttention)如何利用 Ping-Pong 等kernel 设计在 Hopper GPU 上加速。