一 模型介绍
是一个CNN+光流对其+双向时许传播的视频增强模型,适合做
视频超分辨率、视频去噪、视频去模糊、视频压缩伪影修复、一般视频增强
BasicVSR系列的核心思想是;不要只增强单帧,而是利用前后多帧的信息。BasicVSR++论文也明确说,基于recurrent structure 通过双向传播和特征对其来利用整个视频序列的信息,BasicVSR++进一步加入二阶传播和flow-guided deformable alignment,增强了对错位视频帧的时空信息利用。
1.1 为何BasicVSR-like模型合适?
模型 适合程度 原因
BasicVSR-lite 最推荐 结构清楚,CNN为主,适合自己实现
EDVR 也推荐,经典视频增强模型,deformable conv实现稍复杂
FastDVDnet 适合视频去噪,不依赖光流,速度快,结构相对直接
BasicVSR++ 效果更强 但结构比BasicVSR更复杂
RealBasicVSR 真实视频超分,适合真实退化视频,但是训练场领略更复杂
优先级
BasicVSR-lite
EDVR
BasicVSR++ / RealBasicVSR
1.2 BasicVSR-lite的整体结构
输入不是一张图,而是一段连续视频帧
输入低质量视频帧
frame_1, frame_2, frame_3, ...frame_T
张量形状一般是
B, T, C,H,W
B = batch size
T 连续帧数量,比如7或15
C = 3,RGB 通道数量
H = 图像高度
W= 图像宽度
整体网络可以写成
连续低质量帧
->每帧CNN提取特征
->光流估计/特征对齐
->反向时序传播
->正向时序传播
->特征融合
->重建网络
->增强后视频帧
1.3 网络结构详细拆解
1.3.1 输入
假设一次输入7帧
x = frame_1, frame_2, ... frame_7
shape 是
- shape = 8, 7, 3, H, W
如果做4倍超分,输入可能是
低清视频帧,B, 7, 3, 64, 64
高清目标帧 B, 7, 3, 256, 256
如果做去噪,去模糊,压缩伪影修复,输入和输出尺寸通常一样
低质量帧 B,7,3,H,W
高质量帧 B, 7, ,3, H, W
1.3.2 每帧特征提取CNN
frame_t ->Conv->RsBlocks->feature_t
feature_t.shape = B, 64, H, W
这里的CNN可以用
COnv2d ResidualBlock
ReLU / LeakyReLU
这部分和人脸模型ResNet思路类似,输出不是512维向量,而是保留二维特征图
人脸识别
B, 3, 112, 112\]-\>CNN-\>\[B, 512
视频增强
B, 3,H,w\]-\>CNN-\>\[B, 64, H, W
视频增强不能太早flatten, 因为需要恢复图像细节
1.3.3 光流估计/帧间对齐
视频增强最大的问题是:
相邻帧内容相似,物理会运动
BasicVSR类模型通常使用光流网络,比如SPyNet, 来估计相邻帧之间的运动,BasicVSR++补充材料里也提到使用pretrained SPyNet作为flow network
光流可以理解成
第t帧的每个像素,应该往哪里移动,才能对齐到t+1帧
feature_{t-1}
根据optical flow warp
对齐到feature_t
1.3.4 双向时许传播
这是BasicVSR的核心
看当前帧附近的几帧,让信息沿着时间传播。
反向缠传播
从视频最后一帧往前传
frame_T->frame_{T-1}->...frame_1
得到每一帧的backward feature
backward_feature_t
正向传播
再从第一帧往后传
frame_1->frame_2 ...frame_T
得到每一帧的forward feature
forward feature_t
最后第t帧可以利用
当前帧特征
前面帧传来的信息
后面帧传来的信息
enhanced_feature_t = fuse(
current_feature_t,
forward_feature_t,
backward_feature_t
)
1.3.5 重建网络
融合后的特征再经过CNN重建成图像
如果是去噪/去模糊/压缩增强
B,64,H,W\]-\>Conv-\>\[B,3,H,W
如果是视频超分辨率,
B, 64, H, W
->PixelShuffle x2
->PixelShuffle x2
->B, 3, 4H, 4W
1.4 如果用EDVR
EDVR时CVPRW 2019的视频恢复模型,
EDVR沦为提示两个关键模块
PCD Alignment 金字塔,及联,可变形卷积对齐
TSA Fusion 时许和空间注意力融合
1.5 需要去噪FastDVDnet
去噪->降低ISO噪声->减少暗光噪声->视频画面变干净
FastDVDnet是CVPR 2020的视频去噪模型,官方仓库提供Pytorch实现,说明它不适用光流估计的视频去噪算法。
不用光流->结构相对简单->速度快->适合视频去噪入门
二 代码实现
import torch`
`import torch.nn as nn`
`import torch.nn.functional as F`
`def flow_wrap(`
`x,flow,padding_mode="border",`
`align_corners=True`
`):`
`#使用光流对特征图做wrap对齐`
`#参数,x: 要背对齐的特征图,shape=[B,C,H,W]`
`# flow: 光流,shape=[B,2,H,W]`
`#flow]:, 0, :, :[表示x方向位移,横向位移`
`#flow[:,1,:,:] 表示y方向位移,就是纵向位移`
`#padding_mode:`
` grid_sample 越界采样时的填充方式。`
` "border" 表示越界时使用边界像素。`
`# align_corners:`
` grid_sample的坐标对齐方式`
`返回 warped_x: 根据flow对齐后的特征图,shape=[B,C,H,W]`
`#取出输入特征图的batch size,通道数,高,宽`
`b,c,h,w = x.size()`
`#确保flow的数据类型和x一致,避免AMP/FP16 时类型冲突`
`flow = flow.to(dtype=x.dtype)`
`#生成y坐标网络,范围时0到H-1`
`#生成x坐标网络,范围是0到W-1`
`grid_y, grid_x = torch.meshgrid(`
` torch.arange(0, h, device=x.device, dtype=x.dtype),`
` torch.arange(0, w, device=x.device, dtype=x.dtype),`
`indexing="ij"`
`)`
`#grid_x原本shape是[H,W]`
`扩展成[B,H,W] 方便和batch内每张图的flow相加`
`grid_x = grid_x.unsqueeze(0).expand(b, -1, -1)`
`#grid_y 原本shape是[H,W]`
`#扩展成[B,H,W]`
`grid_y, grid_y.unsqueeze(0).expand(b, -1, -1)`
`#当前像素为止x坐标 + 光流横向位移`
`#得到需要从原特征图哪个x为止采样`
`vgrid_x = grid_x + flow[:,0,;,);]`
`#当前像素位置坐标y坐标+光流纵向位移`
`得到需要从原特征图哪个y位置采样`
`vgrid_y = grid_y + flow[]`
`#grid_sample要求坐标范围时[-1, 1]`
`#所以要把像素坐标[0, W-1]转换成[-1, 1]`
`if w > 1:`
`vgrid_x = 2.0 * vgrid_x / (w-1) - 1.0`
`else`
`vgrid_x = torch.zeros_like(vgrid_x)`
`#把像素坐标[0, H-1]转换成[-1, 1]`
`if h > 1:`
` vgrid_y = 2.0 * vgrid_y / (h - 1) - 1.0`
` else:`
` vgrid_y = torch.zeros_like(vgrid_y)`
`#grid_sample要求最后一维时[x,y ]`
`#所以这里吧x坐标和y坐标对跌倒最后一维`
`grid = torch.stack(vgrid_x, vgrid_y), dim=-1`
`#根据grid 从x中采样,得到warp后的特征图`
`warped_x = F.grid_sample(`
`x,grid, mode="bilinear", padding_mode=padding_mode,`
`align_cornors=align_cornors,`
`)`
`#返回对齐后的特征图`
`return warped_x`
`class ResidualBlockNoBN(nn.Module):`
`# 不带BatchNorm 的残差块`
`#视频增强,超分模型里经常不用BatchNorm`
`#因为BatchNorm可能影响图像恢复的细节和数值范围`
`def __init__(self, channels, res_scale=1.0):`
`#channels 输入和输出通道数`
`#res_scale 残差缩放系数`
`#可以让残差分支更稳定`
`#初始化nn.Module父类`
`super().__init__()`
`#第一个3x3卷积,通道数不变`
`self.conv1 = nn.Conv2d(`
` channels,`
` chnanels, `
`kernel_size=3,`
` stride=1,`
` padding=1,`
`)`
`#第二个3x3卷积,通道数不变`
`self.conv2 = nn.Conv2d(`
` channels,`
` channels,`
` kernel_size=3,`
` stride=1,`
` padding=1,`
` )`
`#使用LeakyReLU作为激活函数`
`self.relu = nn.LeakyReLU(`
` negative_slope = 0.1`
` inplace=True`
`)`
`#保存残差缩放系数`
`self.res_scale = res_scale`
`def forward(self, x)`
`#前向传播,输入x shape = [B,C,H,W]`
`#输出 out shpe = [B,C,H,W]`
`#保存原始输入,用于残差链接`
`identity - x`
`#第一个卷积`
`out = self.conv1(x)`
`#激活函数`
`out = self.relu(out)`
`#第二个卷积`
`out = self.conv2(out)`
`#残差链接,输出 = 原输入 + 残差分支`
`out = identity + out * self.res_scale`
`#返回残差块输出`
`return out`
`class ResidualBlockWithInputConv(nn.Module):`
`#信用一个卷积吧输入通道变成mid_channels`
`再接多个残差块`
`#`
`def __init__(`
`self,in_channels,mid_channels,num_block`
`);`
` """`
` 参数:`
` in_channels:`
` 输入通道数。`
` mid_channels:`
` 中间特征通道数。`
` num_blocks:`
` 残差块数量。`
` """`
`#初始化父类`
`super().__init__()`
`#用list存放网络层`
`layers = []`
`#输入卷积,吧in_channels变成mid_channels`
`layers.append(`
` nn.Conv2d(`
` in_channels,`
` mid_channels,`
` kernel_size=3,`
` stride=1,`
` padding=1,`
` )`
`)`
`#激活函数`
`layers.append(`
` nn.LeakyReLU(`
`negative_slope=0.1,inplace=True`
`)`
`)`
`#堆叠多个残差块`
`for _ in range(num_blocks):`
` layers.append(`
` ResidualBlockNoBN(`
` channels=mid_channels,`
`)`
`)`
`#把所有层组成一个Sequential`
` self.main = nn.Sequential(*layers)`
`def forward(self, x):`
` """`
` 前向传播。`
` """`
` # 直接把输入送进 Sequential`
` return self.main(x)`
`class TinyFlowNet(nn.Module):`
`#非常简化的光流网络`
`#不是论文里的BasicVSR里的SpyNet 这是为了吧BasicVSR-like结构跑通`
`#输入 img_ref 参考帧,shape=[B,3,H,W]`
`img_supp`
` 支撑帧,相邻帧`
`shape = [B,3,,H,W]`
`def __init__(self, max_flow=20.0)`
`参数,max_flow 限制预测光流的最大像素位移`
`#初始化父类`
`super().__init__()`
`#保存最大光流范围`
`self.max_flow = max_flo`
`e#输入是两张RGB图拼接, 所有通道数是6`
`self.body = nn.Sequential(`
` #第一层卷积,提取浅层特征`
` nn.Conv2d(6, 32, kernel_size=7, stride=1, padding=3),`
`nn.LeakyReLU(0.1, inplace=True),`
`#下采样一次,扩大感受`
` nn.Conv2d(32, 64, kernel_size=5, stride=2, padding=2),`
` nn.LeakyReLU(0.1, inplace=True),`
` #中间卷积`
` nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1),`
`nn.LeakyReLU(0.1, inplace=True),`
`#再下次阿阳一次,继续扩大感受`
`nn.Conv2d(64, 96, kernel_size=3, stride=2, padding=1),`
`nn.LeakyReLU(0.1, inplace=True),`
`#中间卷积`
`nn.Conv2d(96, 96, kernel_size=3, stride=1, padding=1),`
`nn.LeakyReLU(0.1, inplace=True),`
`#上采样回较高分辨率`
`nn.ConvTranspose2d(96, 64, kernel_size=4, stride=2, padding=1),`
`nn.LeakyReLU(0.1, inplace=True),`
`#再上采样回原图分辨率附近`
`nn.Conv2d(64, 96, kernel_size=3, stride=2, padding=1),`
`nn.LeakyReLU(0.1, inplace=True)`
`#中间卷积`
`nn.Conv2d(96, 96, kernel_size=3, stride=1, padding=1),`
`nn.LeakyReLU(0.1, inplace=True)`
`#上采样回校高分辨率`
`nn.ConvTranspose2d(96, 64, kernel_size=4, stride=2, padding=1),`
`nn.LeakyReLU(0.1, inplace=True)`
`#再上采样回调分辨率附近`
`nn.ConvTranspose2d(64, 32, kernel_size=4,stride=2,padding=1)`
`nn.LeakyReLU(0.1, inplace=True)`
`#输出` `2通道光流,dx和dy`
`nn.Conv2d(32, 2, kernel_size=3, stride=1, padding=1)`
`)`
`def` `forward(self, img_ref, img_supp):`
`##前向传播`
`#记录原始图像的高和宽`
`#shape` `[B, 6, H, W]`
`inp` `= torch.cat([img_ref, img_supp], dim=1)`
`#预测光流`
`flow` `=` `self.body(inp)`
`#如果因为下采样,上采样导致尺寸略有差异,就插值回原尺寸`
`if` `flow.shape[-2:]` `!=` `(h, w):`
`flow` `=` `F.interpolate(`
`flow,`
`size(h,w),`
`mode="bilinear",`
`align_corners=False`
`)`
`#用tanh` `把输出限制到[-1,1]`
`#再乘max_flow` `得到像素激光流范围`
`flow` `= torch.tanh(flow)` `*` `self.max_flow`
`#返回光流`
`return flow`
`class` `BasicVSRLite(nn.Module):`
`#教学版,BasicVSR-lite`
`#整体结构` `输入视频帧序列`
`#CNN提取每帧特征`
`#估计相邻帧光流`
`#反向时间传播T->1`
`正向时间传播1->T`
`当前帧特征` `+` `反向传播特征` `+` `正向传播特征` `融合`
`重建增强帧` `/` `超分帧`
`输入:`
`x` `shpe` `=` `[B, T, 3, H, W]`
`输出`
`scale` `=` `1:`
`out shape=[B, T, 3, H, W]`
`scale=2:`
`out shape` `=` `[B, T, 3, 2H, 2W]`
`scale` `= 4:`
`out shape` `=` `[B, T, 3, 4h, 4w]`
`def` `__init__(`
`self,` `mid_channels=64,`
`num_feature_blocks=5,`
`num_propagation_blocks=7,`
`scale=1,`
`max_flow=20.0,`
`)` `:`
`参数`
`mid_channels:中间特征通道`
`64` `是常见的轻量配置`
`num_feature_blocks:`
`每帧特征提取阶段的残差块数量`
`num_propagation_blocks`
`正向/反向传播阶段的残差块数量`
`num_reconstruction_blocks`
`重建阶段的残差块数量`
`scale:`
`放大倍率`
`scale=1` `表示输入输出同尺寸,同于去噪,去模糊,增强`
`scale=2` `表示2倍超分`
`scale=4` `表示4倍超分`
`max_flow` `TinyFlowNet` `预测光流最大像素位移`
`#初始化分类`
`super().__init__()`
`#只允许1,2,4三种倍率`
`assert` `scale` `in` `(1,2,4),` `"scale must be 1,2, or 4"`
`#保存中间通道数`
`self.mid_channels` `=` `mid_channels`
`#保存超分倍率`
`self.scale=scale`
`#光流网络,用于估计相邻帧之间的运动`
`self.flow_net` `=` `TinyFlowNet(`
`max_flow` `=` `max_flow,`
`)`
`#每帧特征提取网络`
`feature_layers=[]`
`#第一层卷积,RGB图像3通道->mid_channels`
`feature_layers.append(`
`nn.Conv2d(`
`3,` `mid_channels,kernel_size=3,`
`stride=1,padding=1`
`)`
`)`
`#激活函数`
`feature_layers.append(`
`nn.LeakyReLU(`
`negative_slope=0.1,` `inplace=True`
`)`
`)`
`#堆叠多个残差块,用于提取每一帧的空间特征`
`for` `_` `in` `range(num_feature_blocks):`
`feature_layers.append(`
`ResidualBlockNoBN(`
`channels=mid_channels,`
`)`
`)`
`#组成每帧特征提取网络`
`self.feat_extract` `=` `nn.Sequential(*feature_layers)`
`#反向传播网络`
`#输入是当前帧特征` `+` `从未来帧传播过来的特征`
`#所以输入通道数mid_channels` `*` `2`
`self.backward_trunk` `= ResidualBlockWithInputConv(`
`in_channels` `= mid_channels` `* 2,`
`mid_channels` `=` `mid_channels,`
`num_blocks=nm_propagation_blocks`
`)`
`#正向传播网络`
`#输入是当前帧特征` `+` `从过去帧传播过来的特征`
`#所以输入通道数也是mid_channels` `*2`
`self.forward_trunk` `= ResidualBlocksWithInputConv(`
`in_channels` `= mid_channels` `*` `2,`
`mid_channels` `=` `mid_channels,`
`num_blocks` `= num_propagation_blocks,`
`)`
`#重建网络,`
`#输入是当前帧特征` `+` `反向传播特征` `+` `正向传播特征`
`#` `所以输入通道数是mid_channels` `*` `3`
`self.reconstruction` `= ResidualBlocksWithInputConv(`
`in_channels` `= mid_channels` `*` `3,`
`mid_channels` `= mid_channels,`
`num_blocks` `=` `num_reconstruction_blocks,`
`)`
`#激活函数`
`self.lrelu` `= nn.LeakyReLU(`
`negative_slope=0.1,`
`inplace=True,`
`)`
`#pixelShuffle用于超分辨率上采样`
`self.pixel_shuffle = nn.PixelShuffle(`
`upscale_factor=2,`
`)`
`#如果scale` `>=2` `需要一次2倍上采样`
`if` `scale` `>=` `2:`
`self.upconv1` `=` `nn.Conv2d(`
`mid_channels,`
`mid_channels` `*` `4,`
`kernel_size=3,`
`stride=1,`
`padding=1,`
`)`
`#如果scale==4需要两次2倍上采样`
`if` `scale` `== 4:`
`self.upconv2` `= nn.Conv2d(`
`mid_channels,`
`mid_channels` `*4,`
`kernel_size=3,`
`stride=1,`
`padding=1,`
`)`
`#高分辨率空间上的卷积`
`self.conv_hr` `= nn.Conv2d(`
`mid_channels,`
`mid_channels,`
`kernel_size=3,`
`stride=1,`
`padding=1,`
`#最后一层卷积,把特征图变回RGB图像`
`)`
`self.conv_last` `= nn.Conv2d(`
`mid_channels,`
`3,`
`kernel_size=3,`
`stride=1,`
`padding=1,`
`)`
`def` `compute_flows(self, x):`
`#计算相邻帧之间的光流`
`#输入`
`x shape =` `[B, T, 3, H ,W]`
`返回` `flow_forward:`
`用于正向传播`
`flows_forward[:, i-1]表示第i帧` `->` `第i-1帧的光流`
`用它可以把过去帧特征wrap到当前帧`
`shape=[B,T-1, 2, H, W]`
`flows_backward:`
`用于反向传播,flows_backward[:,i]表示第i帧` `第i` `+ 1帧的光流`
`用它可以把未来帧特征warp到当前帧`
`shape` `= [B,T-1,2,H,W]`
`#取出输入视频的维度`
`b,t,c,h,w =` `x.size()`
`#如果只有1帧,就没有相邻帧光流`
`if t` `<=1:`
`empty` `= x.new_zeros(b, 0, ,2, h, w)`
`return empty, empty`
`#存放反向传播需要的光流`
`flows_backward` `=` `[]`
`#对于反向传播,需要从未来帧传播到的当前帧`
`#warp` `future` `feature` `到当前帧时,需要当前帧` `->未来帧` `的光流`
`for` `i in range(t - 1):`
`#计算第i帧到第i` `+` `1帧的光流`
`flow_i_to_next =` `self.flow_net(`
`x[:, i, :, :, :],`
`x[:, i + 1, :, :, :],`
`)`
`#保存光流`
`flows_backward.append(flow_i_to_next)`
`#把list堆叠成tensor`
`#shape` `= [B, T-1, 2, H, W]`
`flows_backward=torch.stack(`
`flows_backward,`
`dim=1,`
`)`
`#存档正向传播需要的光流`
`flows_forward=[]`
`#对于正向传播,需要从过去帧传播到当前帧`
`#warp` `past feature 到当前帧时,需要当前帧-》过去帧的光流`
`for i in range(1, t):`
`#计算第i帧到第i-1帧的光流`
`flow_i_to_prev` `= self.flow_net(`
`x[:, i, :, :, :],`
` x[:, i - 1, :, :, :],`
`)`
`#保存光流`
`flows_forward.append(flow_i_to_prev)`
`#shape = [B, T-1, 2, H, W]`
`flow_forwards` `= torch.stack(`
`flows_forward,`
`dim=1`
`)`
`#返回正向传播光流和反向传播光流`
`return` `flows_forward,` `flows_backward`
`def` `upsample(self, feat):`
`根据scale对重建特征进行上采样`
`输入:` `feat shape` `= [B, C, H, W]`
`输出` `scale=1`
`out shape` `= [B, 3, H, W]`
`scale=2`
`out` `shape=[B, 3, 2H, 2W]`
`scale=4:`
`out` `shape=` `[B, 3, 4H, 4W]`
`#如果是2倍或4倍超分,先做一次2倍PixelShuffl`
`eif self.scale` `== 2:`
`#卷积把通道扩展到4倍`
`feat = self.upconv1(feat)`
`#pixelShuffle把通道转换为空间分辨率`
`feat = self.pixel_shuffle(feat)`
`#激活`
`feat =` `self.lrelu(feat)`
`#如果是4倍超分,需要做两次2倍PixelShuffle`
`elif self.scale ==` `4`
`:#第一次2倍上采样`
`feat` `= self.upconv1(feat)`
`feat = self.pixel_shuffle(feat)`
`feat = self.lrelu(feat)`
`#第二次2倍上采样`
`feat = self.upconv2(feat)`
`feat = self.pixel_shuffle(feat)`
`feat = self.lrelu(feat)`
`#高分辨率卷积`
`feat = self.conv_hr(feat)`
`#激活`
`feat = self.lrelu(fea)t`
`#输出RGB残差图像`
`out = self.conv_last(feat)`
`#返回输出`
`return out`
`def get_base_frame(self, lr_frame):`
`获取残差链接里的base image`
`#对于scale=1`
`base就是原输入帧`
`对于scale=2或scale=4`
`bas是双线性循环放大后的输入帧`
`最终输出`
`enhanced = predicted_residual` `+ base`
`#如果不做超分,直接返回原图`
`if self.scale` `== 1:`
`return` `lr_frame`
`#如果做超分,吧低清晰度输入双线性插值放大`
`base = F.interpolate(`
`lr_frame,`
`scale_factor=self.scale,`
`mode` `= "bilinear",`
`align_corners=False`
`)`
`#返回base` `frame`
`return base`
`def forward(self, x):`
` """`
` 前向传播。`
` 输入:`
` x shape = [B, T, 3, H, W]`
` 输出:`
` out shape = [B, T, 3, H*scale, W*scale]`
` """`
`#检查输入必须是5倍`
`if x.dim()` `!= 5`
`#取出输入视频的维度`
`b,t,c,h,w` `= x.size()`
`检查必须是RGB视频`
`#每帧CNN特征提取`
`#把[B,T,3,H,W]reshape成[B*t, 3, H ,W]`
`#这样可以一次性把所有帧送进CNN`
`x_reshape = x.reshape(b*t, c, h, w)`
`#提取每帧空间特征`
`feats` `= self.feat_extract(x_reshape)`
`#把特征reshape回视频序列形式`
`#shape` `= [B,T,mid_channels, H,W]`
`feats =` `feats.reshape(`
`b,t,self.mid_channels,`
`h,w`
`)`
`#计算相邻帧光流`
`flows_forward用于正向传播`
`flows_backward` `用于反向传播`
`flows_forward, flows_backward` `= self.compute_flows(x)`
`#反向时间传播` `从T-1帧传播到第0帧`
`#用list存放每一帧反向传播特征`
` backward_feats = [None] * t`
`#初始化传播特征为全0`
`#shape =` `[B, mid_channels, H, W]`
`feat_prop = x.new_zeros(b, self.mid_channels,h,w)`
`#从最后一帧住第一帧遍历`
`for i in range(t - 1, -1, -1):`
`#如果不是最后一帧,就需要把未来帧传播特征warp到当前帧`
`if` `i <` `t - 1:`
`#flows_backward[:, i]是第i帧->第i+1帧的光流`
`#用它可以把第i + 1帧的传播特征对齐到第i帧`
` feat_prop = flow_warp` `(`
`feat_prop,`
`flows_backward[:, i,:,:,:],`
` )`
`#当前帧特征`
`curr_feat` `= feats[L,i,:,:,:]`
`#拼接当前帧特征和传播特征`
`#shape` `= [B, mid_channel *2, H, W]`
`feat_input = torch.cat(`
`[curr_feat, feat_prop],`
`dim = 1,`
`)`
`#通过反向传播网络更新传播特征`
`#shape` `= [B, mid_channels * 2, H, W]`
`feat_input = torch.cat(`
`[curr_feat, feat_prop],`
`dim=1,`
`)`
`#通过反向传播网络更新传播特征`
`feat_prop = self.backward_trunk(feat_input)`
`#保存第i帧对应的反响传播特征`
`backward_feats[i]` `= feat_prop`
`#4` `正向时间传播,从第0帧传播到T-1帧`
`#用list存放每一帧的正向传播特征`
`forwards_feats =` `[None]*t`
`#初始化正向传播特征为全0`
`feat_prop = x.new_zeros(b,self.mid_channels, h,w)`
`#从第一帧往后一帧遍历`
`for i in range(t):`
`#如果不是第一帧,就需要把过去帧传播特征warp到当前帧`
`if` `i > 0`
`#dlows` `forward[:, i-1]是第i帧` `第i-1帧的光流`
`#用它可以把第i-1帧的传播特征对其道第i帧`
`feat_prop` `= flow_warp(`
`feat_prop,`
`flow_forward[:,i-1,:,:,:]`
`)`
`#当前帧特征`
`curr_feat = feats[:,i,:,:,:]`
`#拼接当前帧特征和正向传播特征`
`feat_input = torch.cat(`
`[curr_feat, feat_prop],`
`dim=1,`
`)`
`#通过正向传播网更新传播特征`
`feat_prop = self.forward_trunk(feat_input)`
`#保存第i帧对应的正向传播特征`
`forward_feats][i = feat_prop`
`#融合当前帧特征,反向传播特征,正向传播特征,` `并重建输出帧`
`#存放所有输出帧`
`outs=[]`
`#对每一帧分别重建`
`for i in` `range()t:`
`#当前帧的原始空间特征`
`curr_feat = feats[:,i,:,:,:]`
` #当前帧的反向传播特征`
`backward_feat =` `backward_feats[i]`
`#当前帧的正向传播特征`
`forward_feat = forward_feats[i]`
`#三类特征拼接`
`#shape = [B,mid_channels * 3, H,W]`
`feat = torch.cat(`
`[currefeat, backward_feat, forward_feat],`
`dim=1,`
`)`
`#通过重建网络重建特征`
`feat` `= self.reconstruction(feat)`
`#根据scale输出RGB残差图像`
`out = self.upsample(feat)`
`#获取base frame`
`#scale=1` `时就是原输入帧`
`#scale= 2/4` `时是双线性插值放大后的输入帧`
`base = self.get_base_frame(`
`x[:,i,:,:,:]`
`)`
`#残差学习,最终输出` `网络预测残差` `+` `base`
`out = out +` `base`
`#保存当前输出帧`
`outs.append(out)`
`#把list里的每一帧堆叠回视频序列`
`#shape` `= [B,T,3, H *scale, W*scale]`
`outs = torch.stack(outs, dim=1)`
`#返回增强后的视频帧序列`
`return outs`
`if __name__ == "__main__":`
`简单测试代码`
`直接运行` `python basicvsr_lite.py`
`#构造一个BasicVSR-lite模型`
`#scale=1` `表示输入输出同分辨率`
`model = BasicVSRLite(`
`mid_channels=64,`
`num_feature_blocks=5,`
`num_propagation_blocks=7,`
`num_reconstruction_blocks=10,`
`scale=1,`
`)`
`#构造一个假的输入视频batch`
`#B=2,T=7,C=3,H=64,W=64`
`x = torch.randn(2,7,3,64,64)`
`#前向传播`
`y = model(x)`
`#打印输入输出尺寸`
`#训练时一般这样计算损失`
`#假设gt是清晰视频帧,shape和y一样`
`gt = torch.randn_like(y)`
`#视频增强/超分常用L1Loss`
`loss = F.l1_loss(y, gt)`
`
二 模型总结
这个模型是一个教学版BsicVSR-lite视频增强模型。核心目的不会单张图片,而是处理一段连续视频帧,利用前后帧信息增强当前帧。`
`输入一段低质量视频帧`
`提取每一帧的CNN特征`
`估计相邻之间的光流`
`用光流吧前后帧特征对齐`
`做正向和反向时序信息传播`
`融合当前帧,过去帧,未来帧信息。`
`1、这个模型能做什么`
`视频去噪`
`视频去模糊`
`视频压缩伪影修复`
`视频画质增强`
`视频超分辩率`
`由scale控制任务类型`
`scale=1`
`表示输入输出同尺寸,适合葡萄视频增强`
`scale=2`
`表示2倍视频超分`
`scale=4`
`表示4倍视频超分`
`2输入输出格式`
`x.shape =` `[B,T,3,H,W]`
`含义是`
`B` `= batch_size` `一次训练几个视频片段`
`T` `=` `每个视频片段有多少帧`
`3` `= RGB通道数量`
`H` `图像高度`
`W` `图像宽度`
`x` `= torch.randn(2, 7, 3, 64, 64)`
`表示`
`2` `个视频片段`
`每个片段7帧`
`每帧是RGB图像`
`每帧大小64x64`
`如果scale=1` `输出是`
`y.shape =` `[B,T,3,H,W]`
`如果scale=4` `输出是`
`y.shape=[B,T,3,4H,4W]`
`3` `模型整体网络结构`
`class` `BasicVSRLite(nn.Module)`
`里面主要有这些模块`
`feat_extract` `每帧CNN特征提取`
`flow_net` `简化光流估计网络`
`backward_trunk` `反向时间传播网络`
`forward_trunk 正向时间传播网络`
`reconstruction` `特征融合和重建网络`
`upsample` `超分上采样模块`
`residual` `base残差输出链接`
`整体结构可以画成`
`输入视频x:[B,T,3,H,W]`
`每帧CNN特征提取`
`feats` `[B,T,64,H,W]`
`计算相邻帧光流`
`反向传播,T-1->0` `得到未来信息backward_feats`
`正向传播` `0->T-1` `得到过去信息forward_feats`
`每一帧融合`
`当前帧特征` `过去信息` `未来信息`
`重建网络`
`输出增强视频`
`4` `flow_warp是干什么的?`
`def` `flow_warp(x, flow)`
`作用是,根据光流把特征图进行空间对齐`
`视频里物体会运动,`
`第一帧` `人脸在左边`
`第二帧,人脸在中间`
`第三针,人脸在右边`
`直接把这些帧的特征融合,会无法对齐,结果容易模糊`
`需要先用光流估计运动`
`这个额像素从上一帧移动到了哪里`
`这个特征应该往左还是往右移动`
`F.grid_sample(...)`
`重新采样特征图,吧前后特征对齐到当前帧`
`flow_warp` `=` `根据运动信息移动特征图`
`5` `ResidualBlockNoBN是什么`
`这是一个不带BatchNorm的残差块`
`class ResidualBlockNoBN(nn.Module)`
`结构是`
`输入x`
`Conv` `3x3`
`LeakyReLU`
`Conv` `3x3`
`加回输入x`
`输出`
`对应代码`
`identity=x`
`out =` `self.conv1()x`
`out =` `self.relu(out)`
`out` `= self.conv2(out)`
`out =` `identity` `+` `out *` `self.res_scale`
`作用是增强特征表达能力`
`为什么没有BatchNorm 归一化一下?`
`因为视频增强超分,去噪这类图像恢复任务需要保留非常惊喜的像素信息,BatchNorm有时会破坏图像的亮度,颜色,纹理分布,所以很多恢复模型不用BN`
`6` `TinyFlowNet是什么?`
`class TinyFlowNet(nn.Module)`
`这是一个简化版本光流估计网络`
`输入两帧图像`
`img_ref.shape=[B,3,H,W]`
`img_supp.shape` `= [B,3, H,W]`
`先拼接成`
`[B,6,H,W]`
`然后经过一个小CNN,输出`
`flow.shape = [B,2,H,W]`
`其中`
`flow[:, 0,:,)]` `= x方向位移`
`flow]:,1,:,:[` `= y` `方向位移`
`这个TInyFlowNet是教学版,不是官方BasicVSR里的SpyNet,`
`7每帧特征提取feat_extract`
`这一段代码`
`self.feat_extract` `= nn.Sequential(*feature_layers)`
`作用把每一帧RGB图像变成CNN特征图`
`输入一帧`
`[B,3,H,W]`
`输出` `[B,64,H,W]`
`mid_channels =` `64`
`所以每帧被转换成64通道的空间特征`
`注意这里不是输出一个向量,保留二维空间结构`
`图片` `CNN` `512维向量`
`图片` `CNN` `64通道特征图`
`因为视频增强最终需要回复图像,不能把空间信息压扁。`
`8` `compute` `flows计算什么?`
`flows_forward, flows_backward = self.compute_flos(x)`
`这个函数计算相邻帧之间的光流`
`flows_backward`
`flows_forward`
`flows_backward用于反向传播,就是从未来帧往当前帧传信息`
`代码里面计算是这样的`
`flow_i_to_next = self.flow_net` `(`
`x[:, i]`
`x[:, i + 1]`
`)`
`用它可以把低i+1帧的特征warp到第i帧。`
`flows_forward`
`用于正向传播,从过去帧往当前帧传信息`
`flow_i_to_prev` `= self.flow_net(`
`x[:, i]`
` x[:, i + 1]`
`)`
`含义是`
`第i帧到第i+1帧的光流`
`9反向时间传播`
`for i in` `range(t-1, -1, -1)`
`从最后一帧往第一帧处理`
`T-1,` `T-2 ,0`
`目的`
`让当前帧获得未来帧的信息`
`例如处理第3帧时,利用第4,5,6帧传过来的信息`
`初始化feat_prop = 0`
`从最后一帧开始`
`如果不是最后一帧`
`用光流warp` `未来帧传来的feat_prop`
`拼接`
`当前帧特征curr_feat`
`未来传播特征` `feat_prop`
`送入backward_trunk`
`得到新的feat_prop`
`保存为backward_feats[i]`
`核心代码`
`feat_prop =` `flow_warp` `(`
`feat_prop,flows_backward[:,i]`
`)`
`feat_input= torch.cat(`
`]curr_feat, feat_prop[`
`dim=1`
`)`
`feat_prop=self.backward_trunk(feat_input)`
`10` `正向时间传播`
`for i in range(t)`
`例如处理第3帧时,可以利用第0,1,2帧传过来的信息`
`初始化feat_prop` `= 0`
`从第一帧开始`
`如果不是第一帧`
`用光流warp过去帧传来的feat_prop`
`拼接`
`当前帧特征curr_feat`
`过去传播特征` `feat_prop`
`送入forward_tru`
`nk得到新的feat_prop`
`保存为forward_feats[i]`
`核心代码`
`feat_prop = flow_warp(`
` feat_prop,`
` flows_forward[:, i - 1]`
`)`
`feat_input = torch.cat(`
` [curr_feat, feat_prop],`
` dim=1`
`)`
`feat_prop = self.forward_trunk(feat_input)`
`11 最终融合重建`
`1curr_feat当前帧呢自己的特征`
`2` `backward_feat 未来帧传来的信息`
`3` `forward_feat过去帧传来的信息`
`然后拼接`
`feat = torch.cat([curr_feat, backward_feat, forward_feat],`
`dim=1)`
`如果mid_channels=64` `那么`
`curr_feat` `= 6通道`
`backward_feat = 64通道`
`forward_feat = 64通道`
`拼接后` `=` `192通道`
`送入重建网络`
`feat = self.reconstruction(fe)at`
`重建网络把192通道融合成64通道,`
`out = self.upsample(feat)输出RGB图像`
`12 upsample是什么`
`def upsample(self, feat)`
`负责吧特征图转换成最终RGB输出`
`如果scale=1`
`不做放大`
`[B,64,H,W]->[B,3,H,W]`
`如果scale=2`
`Conv -> PixelShuffle x2 -> Conv -> RGB`
`[B,64,H,W] -> [B,3,2H,2W]`
`如果scale=4`
`Conv -> PixelShuffle x2`
`Conv -> PixelShuffle x2`
`Conv -> RGB`
`[B,64,H,W] -> [B,3,4H,4W]`
`PixelShuffle是超分模型里常用的上采样方式,通道维的信息重新排列空间维上。`
`13 为什么最后要out + base`
`base = self.get_base_frame(x[:, i])`
`out = out + base`
`残差学习`
`如果scale` `= 1`
`base` `= 原始输入帧`
`输出=原始输入帧` `+ 模型预测的修正量`
`如果scale=4`
`base = 双线性插值方法后的输入帧`
`输出=放大后的输入帧` `+` `模型预测的高频细节`
`因为模型从不从零生成整张图,`
`哪里需要去噪` `哪里需要变清晰` `哪里需要补充细节` `哪里需要修复伪影`
`14 forward函数完整流程`
`你的forward可以总结成`
`1检查输入维度`
`2 把视频帧展开`
`[B,T,3,H,W]->]B*T,3,H,W[`
`3 每帧CNN特征提取`
`]B*T, 3,H,W[->[B*T,64,H,W]`
`4` `reshape回视频形式`
`[B*T,64,H,W]->[B,T,64,H,W]`
`5` `计算相邻帧光流`
`flows_forward = [B,T-1,2,H,W]`
`flows_backward=[B,T-1,2,H,W]`
`6` `反向时间传播`
`得到backward_feats`
`7 正向时间传播`
`得到forward_feat`
`s8 对每一帧`
`当前帧呢特征` `未来信息` `+过去信息`
`reconstruction`
`upsample`
`加base` `frame`
`输出增强帧`
`9` `把所有输出帧` `stack`
`[B,T,3,H*scale,W*scale]`
`15 模型训练时怎么用`
`普通视频增强`
`model` `= BasicVSRLite(scale=1).cuda()`
`lq = torch.randn(2,7,3,128,128).cuda(`
`)gt = torch.randn(2,7,3,128,128).cuda()`
`pred = model(lq)`
`loss = F.l1_loss(pred, gt)`
`loss.backward()`
`倍视频超分`
`model = BasicVSRLite(scale=4).cuda()`
`lq = torch.randn(2, 7, 3, 64, 64).cuda()`
`gt = torch.randn(2, 7, 3, 256, 256).cuda()`
`pred = model(lq)`
`loss = F.l1_loss(pred, gt)`
`loss.backward()`
`其中,`
`lq = low quality低质量视频帧`
`gt = ground truth` `高质量视频帧`
`16这个模型的学习重点`
`16.1 CNN特征提取`
`self.feat_extract`
`16.2` `光流对齐`
`TinyFlowNet` `+ flow_warp`
`16.3 双向时间传播`
`backwrd_trunk + forward_trunk`
`16.4` `融合重建`
`reconstruction` `+ upsample + out +base`
`当前帧不清楚的地方,可以从前后帧里找信息补回来`
`17 需要注意的地方`
`1. TinyFlowNet` `光流效果不一定好`
`2. 没有使用预训练SpyNet /` `PWC`
`Net美元BasicVSR+的二阶传播`
`4没有EDVR的可变性卷积对齐`
`5` `只用了基础L1Loss视觉锐度可能一般`
`