BasicVSR-lite图像画质增强

一 模型介绍

是一个CNN+光流对其+双向时许传播的视频增强模型,适合做

视频超分辨率、视频去噪、视频去模糊、视频压缩伪影修复、一般视频增强

BasicVSR系列的核心思想是;不要只增强单帧,而是利用前后多帧的信息。BasicVSR++论文也明确说,基于recurrent structure 通过双向传播和特征对其来利用整个视频序列的信息,BasicVSR++进一步加入二阶传播和flow-guided deformable alignment,增强了对错位视频帧的时空信息利用。

1.1 为何BasicVSR-like模型合适?

模型 适合程度 原因

BasicVSR-lite 最推荐 结构清楚,CNN为主,适合自己实现

EDVR 也推荐,经典视频增强模型,deformable conv实现稍复杂

FastDVDnet 适合视频去噪,不依赖光流,速度快,结构相对直接

BasicVSR++ 效果更强 但结构比BasicVSR更复杂

RealBasicVSR 真实视频超分,适合真实退化视频,但是训练场领略更复杂

优先级

BasicVSR-lite

EDVR

BasicVSR++ / RealBasicVSR

1.2 BasicVSR-lite的整体结构

输入不是一张图,而是一段连续视频帧

输入低质量视频帧

frame_1, frame_2, frame_3, ...frame_T

张量形状一般是

B, T, C,H,W

B = batch size

T 连续帧数量,比如7或15

C = 3,RGB 通道数量

H = 图像高度

W= 图像宽度

整体网络可以写成

连续低质量帧

->每帧CNN提取特征

->光流估计/特征对齐

->反向时序传播

->正向时序传播

->特征融合

->重建网络

->增强后视频帧

1.3 网络结构详细拆解

1.3.1 输入

假设一次输入7帧

x = frame_1, frame_2, ... frame_7

shape 是

  1. shape = 8, 7, 3, H, W

如果做4倍超分,输入可能是

低清视频帧,B, 7, 3, 64, 64

高清目标帧 B, 7, 3, 256, 256

如果做去噪,去模糊,压缩伪影修复,输入和输出尺寸通常一样

低质量帧 B,7,3,H,W

高质量帧 B, 7, ,3, H, W

1.3.2 每帧特征提取CNN

frame_t ->Conv->RsBlocks->feature_t

feature_t.shape = B, 64, H, W

这里的CNN可以用

COnv2d ResidualBlock

ReLU / LeakyReLU

这部分和人脸模型ResNet思路类似,输出不是512维向量,而是保留二维特征图

人脸识别

B, 3, 112, 112\]-\>CNN-\>\[B, 512

视频增强

B, 3,H,w\]-\>CNN-\>\[B, 64, H, W

视频增强不能太早flatten, 因为需要恢复图像细节

1.3.3 光流估计/帧间对齐

视频增强最大的问题是:

相邻帧内容相似,物理会运动

BasicVSR类模型通常使用光流网络,比如SPyNet, 来估计相邻帧之间的运动,BasicVSR++补充材料里也提到使用pretrained SPyNet作为flow network

光流可以理解成

第t帧的每个像素,应该往哪里移动,才能对齐到t+1帧

feature_{t-1}

根据optical flow warp

对齐到feature_t

1.3.4 双向时许传播

这是BasicVSR的核心

看当前帧附近的几帧,让信息沿着时间传播。

反向缠传播

从视频最后一帧往前传

frame_T->frame_{T-1}->...frame_1

得到每一帧的backward feature

backward_feature_t

正向传播

再从第一帧往后传

frame_1->frame_2 ...frame_T

得到每一帧的forward feature

forward feature_t

最后第t帧可以利用

当前帧特征

前面帧传来的信息

后面帧传来的信息

enhanced_feature_t = fuse(

current_feature_t,

forward_feature_t,

backward_feature_t

)

1.3.5 重建网络

融合后的特征再经过CNN重建成图像

如果是去噪/去模糊/压缩增强

B,64,H,W\]-\>Conv-\>\[B,3,H,W

如果是视频超分辨率,

B, 64, H, W

->PixelShuffle x2

->PixelShuffle x2

->B, 3, 4H, 4W

1.4 如果用EDVR

EDVR时CVPRW 2019的视频恢复模型,

EDVR沦为提示两个关键模块

PCD Alignment 金字塔,及联,可变形卷积对齐

TSA Fusion 时许和空间注意力融合

1.5 需要去噪FastDVDnet

去噪->降低ISO噪声->减少暗光噪声->视频画面变干净

FastDVDnet是CVPR 2020的视频去噪模型,官方仓库提供Pytorch实现,说明它不适用光流估计的视频去噪算法。

不用光流->结构相对简单->速度快->适合视频去噪入门

二 代码实现

复制代码
import torch`
`import torch.nn as nn`
`import torch.nn.functional as F`

`def flow_wrap(`
`x,flow,padding_mode="border",`
`align_corners=True`
`):`
`#使用光流对特征图做wrap对齐`
`#参数,x: 要背对齐的特征图,shape=[B,C,H,W]`
`# flow: 光流,shape=[B,2,H,W]`
`#flow]:, 0, :, :[表示x方向位移,横向位移`
`#flow[:,1,:,:] 表示y方向位移,就是纵向位移`
`#padding_mode:`
`            grid_sample 越界采样时的填充方式。`
`            "border" 表示越界时使用边界像素。`
`# align_corners:`
`  grid_sample的坐标对齐方式`
`返回 warped_x: 根据flow对齐后的特征图,shape=[B,C,H,W]`
`#取出输入特征图的batch size,通道数,高,宽`
`b,c,h,w = x.size()`
`#确保flow的数据类型和x一致,避免AMP/FP16 时类型冲突`
`flow = flow.to(dtype=x.dtype)`
`#生成y坐标网络,范围时0到H-1`
`#生成x坐标网络,范围是0到W-1`
`grid_y, grid_x = torch.meshgrid(`
`    torch.arange(0, h, device=x.device, dtype=x.dtype),`
`    torch.arange(0, w, device=x.device, dtype=x.dtype),`
    `indexing="ij"`
`)`
`#grid_x原本shape是[H,W]`
`扩展成[B,H,W] 方便和batch内每张图的flow相加`
`grid_x = grid_x.unsqueeze(0).expand(b, -1, -1)`
`#grid_y 原本shape是[H,W]`
`#扩展成[B,H,W]`
`grid_y, grid_y.unsqueeze(0).expand(b, -1, -1)`
`#当前像素为止x坐标 + 光流横向位移`
`#得到需要从原特征图哪个x为止采样`
`vgrid_x = grid_x + flow[:,0,;,);]`
`#当前像素位置坐标y坐标+光流纵向位移`
`得到需要从原特征图哪个y位置采样`
`vgrid_y = grid_y + flow[]`

`#grid_sample要求坐标范围时[-1, 1]`
`#所以要把像素坐标[0, W-1]转换成[-1, 1]`
`if w > 1:`
    `vgrid_x = 2.0 * vgrid_x / (w-1) - 1.0`
`else`
    `vgrid_x = torch.zeros_like(vgrid_x)`
    
`#把像素坐标[0, H-1]转换成[-1, 1]`
`if h > 1:`
`        vgrid_y = 2.0 * vgrid_y / (h - 1) - 1.0`
`    else:`
`        vgrid_y = torch.zeros_like(vgrid_y)`
`#grid_sample要求最后一维时[x,y ]`
`#所以这里吧x坐标和y坐标对跌倒最后一维`
`grid = torch.stack(vgrid_x, vgrid_y), dim=-1`

`#根据grid 从x中采样,得到warp后的特征图`
`warped_x = F.grid_sample(`
`x,grid, mode="bilinear", padding_mode=padding_mode,`
`align_cornors=align_cornors,`
`)`
`#返回对齐后的特征图`
`return warped_x`

`class ResidualBlockNoBN(nn.Module):`
`# 不带BatchNorm 的残差块`
`#视频增强,超分模型里经常不用BatchNorm`
`#因为BatchNorm可能影响图像恢复的细节和数值范围`
`def __init__(self, channels, res_scale=1.0):`
`#channels 输入和输出通道数`
`#res_scale 残差缩放系数`
`#可以让残差分支更稳定`
`#初始化nn.Module父类`
`super().__init__()`
`#第一个3x3卷积,通道数不变`
`self.conv1 = nn.Conv2d(`
` channels,`
` chnanels, `
 `kernel_size=3,`
` stride=1,`
` padding=1,`
`)`
`#第二个3x3卷积,通道数不变`
`self.conv2 = nn.Conv2d(`
`            channels,`
`            channels,`
`            kernel_size=3,`
`            stride=1,`
`            padding=1,`
`        )`
`#使用LeakyReLU作为激活函数`
`self.relu = nn.LeakyReLU(`
`  negative_slope = 0.1`
` inplace=True`
`)`
`#保存残差缩放系数`
`self.res_scale = res_scale`

`def forward(self, x)`
`#前向传播,输入x shape = [B,C,H,W]`
`#输出 out shpe = [B,C,H,W]`
`#保存原始输入,用于残差链接`
`identity - x`
`#第一个卷积`
`out = self.conv1(x)`
`#激活函数`
`out = self.relu(out)`
`#第二个卷积`
`out = self.conv2(out)`
`#残差链接,输出 = 原输入 + 残差分支`
`out = identity + out * self.res_scale`
`#返回残差块输出`
`return out`

`class ResidualBlockWithInputConv(nn.Module):`
`#信用一个卷积吧输入通道变成mid_channels`
`再接多个残差块`
`#`
`def __init__(`
`self,in_channels,mid_channels,num_block`
`);`
` """`
`        参数:`
`            in_channels:`
`                输入通道数。`

`            mid_channels:`
`                中间特征通道数。`

`            num_blocks:`
`                残差块数量。`
`        """`
    `#初始化父类`
    
    `super().__init__()`
    `#用list存放网络层`
    `layers = []`
    `#输入卷积,吧in_channels变成mid_channels`
    `layers.append(`
`    nn.Conv2d(`
`                in_channels,`
`                mid_channels,`
`                kernel_size=3,`
`                stride=1,`
`                padding=1,`
`            )`
    `)`
    `#激活函数`
    `layers.append(`
    ` nn.LeakyReLU(`
     `negative_slope=0.1,inplace=True`
     `)`
    `)`
    `#堆叠多个残差块`
    `for _ in range(num_blocks):`
    `  layers.append(`
      `    ResidualBlockNoBN(`
          `   channels=mid_channels,`
          `)`
      `)`
`#把所有层组成一个Sequential`
`  self.main = nn.Sequential(*layers)`
`def forward(self, x):`
`        """`
`        前向传播。`
`        """`

`        # 直接把输入送进 Sequential`
`        return self.main(x)`
`class TinyFlowNet(nn.Module):`
`#非常简化的光流网络`
`#不是论文里的BasicVSR里的SpyNet 这是为了吧BasicVSR-like结构跑通`
`#输入 img_ref 参考帧,shape=[B,3,H,W]`
`img_supp`
` 支撑帧,相邻帧`
 `shape = [B,3,,H,W]`
 
 `def __init__(self, max_flow=20.0)`
 `参数,max_flow 限制预测光流的最大像素位移`
 `#初始化父类`
 `super().__init__()`
 `#保存最大光流范围`
 `self.max_flow = max_flo`
 `e#输入是两张RGB图拼接, 所有通道数是6`
 `self.body = nn.Sequential(`
 `   #第一层卷积,提取浅层特征`
` nn.Conv2d(6, 32, kernel_size=7, stride=1, padding=3),`
 `nn.LeakyReLU(0.1, inplace=True),`
 `#下采样一次,扩大感受`
` nn.Conv2d(32, 64, kernel_size=5, stride=2, padding=2),`
` nn.LeakyReLU(0.1, inplace=True),`
` #中间卷积`
` nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1),`
`nn.LeakyReLU(0.1, inplace=True),`
`#再下次阿阳一次,继续扩大感受`
`nn.Conv2d(64, 96, kernel_size=3, stride=2, padding=1),`
`nn.LeakyReLU(0.1, inplace=True),`
`#中间卷积`
`nn.Conv2d(96, 96, kernel_size=3, stride=1, padding=1),`
`nn.LeakyReLU(0.1, inplace=True),`
`#上采样回较高分辨率`
`nn.ConvTranspose2d(96, 64, kernel_size=4, stride=2, padding=1),`
`nn.LeakyReLU(0.1, inplace=True),`
`#再上采样回原图分辨率附近`
`nn.Conv2d(64, 96, kernel_size=3, stride=2, padding=1),`
`nn.LeakyReLU(0.1, inplace=True)`
`#中间卷积`
`nn.Conv2d(96, 96, kernel_size=3, stride=1, padding=1),`
`nn.LeakyReLU(0.1, inplace=True)`
`#上采样回校高分辨率`
`nn.ConvTranspose2d(96, 64, kernel_size=4, stride=2, padding=1),`
`nn.LeakyReLU(0.1, inplace=True)`
`#再上采样回调分辨率附近`
`nn.ConvTranspose2d(64, 32, kernel_size=4,stride=2,padding=1)`
`nn.LeakyReLU(0.1, inplace=True)`
`#输出` `2通道光流,dx和dy`
`nn.Conv2d(32, 2, kernel_size=3, stride=1, padding=1)`
`)`        

`def` `forward(self, img_ref, img_supp):`
`##前向传播`
`#记录原始图像的高和宽`
`#shape` `[B, 6, H, W]`
`inp` `= torch.cat([img_ref, img_supp], dim=1)`
`#预测光流`
`flow` `=` `self.body(inp)`
`#如果因为下采样,上采样导致尺寸略有差异,就插值回原尺寸`
`if` `flow.shape[-2:]` `!=` `(h, w):`
   `flow` `=` `F.interpolate(`
      `flow,`
      `size(h,w),`
      `mode="bilinear",`
      `align_corners=False`
   `)`
`#用tanh` `把输出限制到[-1,1]`
`#再乘max_flow` `得到像素激光流范围`
`flow` `= torch.tanh(flow)` `*` `self.max_flow`
`#返回光流`
`return flow`

`class` `BasicVSRLite(nn.Module):`
`#教学版,BasicVSR-lite`
`#整体结构` `输入视频帧序列`
`#CNN提取每帧特征`
`#估计相邻帧光流`
`#反向时间传播T->1`
`正向时间传播1->T`
`当前帧特征` `+` `反向传播特征` `+` `正向传播特征` `融合`
`重建增强帧` `/` `超分帧`
`输入:`
`x` `shpe` `=` `[B, T, 3, H, W]`
`输出`
`scale` `=` `1:`
 `out shape=[B, T, 3, H, W]`
 `scale=2:`
  `out shape` `=` `[B, T, 3, 2H, 2W]`
  `scale` `= 4:`
  `out shape` `=` `[B, T, 3, 4h, 4w]`
`def` `__init__(`
  `self,` `mid_channels=64,`
  `num_feature_blocks=5,`
  `num_propagation_blocks=7,`
`scale=1,`
`max_flow=20.0,`
`)` `:`
`参数`
`mid_channels:中间特征通道`
  `64` `是常见的轻量配置`
`num_feature_blocks:`
 `每帧特征提取阶段的残差块数量`
`num_propagation_blocks`
 `正向/反向传播阶段的残差块数量`
 `num_reconstruction_blocks`
  `重建阶段的残差块数量`
`scale:`
 `放大倍率`
 `scale=1` `表示输入输出同尺寸,同于去噪,去模糊,增强`
 `scale=2` `表示2倍超分`
 `scale=4` `表示4倍超分`
 `max_flow` `TinyFlowNet` `预测光流最大像素位移`
 
 `#初始化分类`
 `super().__init__()`
 `#只允许1,2,4三种倍率`
 `assert` `scale` `in` `(1,2,4),` `"scale must be 1,2, or 4"`
 `#保存中间通道数`
 `self.mid_channels` `=` `mid_channels`
 `#保存超分倍率`
 `self.scale=scale`
 `#光流网络,用于估计相邻帧之间的运动`
 `self.flow_net` `=` `TinyFlowNet(`
    `max_flow` `=` `max_flow,`
 `)`
 `#每帧特征提取网络`
 `feature_layers=[]`
 `#第一层卷积,RGB图像3通道->mid_channels`
 `feature_layers.append(`
  `nn.Conv2d(`
  `3,` `mid_channels,kernel_size=3,`
   `stride=1,padding=1`
  `)`
 `)`
`#激活函数`
`feature_layers.append(`
  `nn.LeakyReLU(`
      `negative_slope=0.1,` `inplace=True`
  `)`
`)`
`#堆叠多个残差块,用于提取每一帧的空间特征`
`for` `_` `in` `range(num_feature_blocks):`
 `feature_layers.append(`
     `ResidualBlockNoBN(`
        `channels=mid_channels,`
     `)`
 `)`
`#组成每帧特征提取网络`
`self.feat_extract` `=` `nn.Sequential(*feature_layers)`
`#反向传播网络`
`#输入是当前帧特征` `+` `从未来帧传播过来的特征`
`#所以输入通道数mid_channels` `*` `2`
`self.backward_trunk` `= ResidualBlockWithInputConv(`
    `in_channels` `= mid_channels` `* 2,`
    `mid_channels` `=` `mid_channels,`
    `num_blocks=nm_propagation_blocks`
`)`
`#正向传播网络`
`#输入是当前帧特征` `+` `从过去帧传播过来的特征`
`#所以输入通道数也是mid_channels` `*2`
`self.forward_trunk` `= ResidualBlocksWithInputConv(`
   `in_channels` `= mid_channels` `*` `2,`
   `mid_channels` `=` `mid_channels,`
   `num_blocks` `= num_propagation_blocks,`
`)`
`#重建网络,`
`#输入是当前帧特征` `+` `反向传播特征` `+` `正向传播特征`
`#` `所以输入通道数是mid_channels` `*` `3`
`self.reconstruction` `= ResidualBlocksWithInputConv(`
  `in_channels` `= mid_channels` `*` `3,`
`mid_channels` `= mid_channels,`
`num_blocks` `=` `num_reconstruction_blocks,`
`)`
`#激活函数`
`self.lrelu` `= nn.LeakyReLU(`
   `negative_slope=0.1,`
   `inplace=True,`
`)`
`#pixelShuffle用于超分辨率上采样`
`self.pixel_shuffle = nn.PixelShuffle(`
  `upscale_factor=2,`
`)`
`#如果scale` `>=2` `需要一次2倍上采样`
`if` `scale` `>=` `2:`
  `self.upconv1` `=` `nn.Conv2d(`
     `mid_channels,`
     `mid_channels` `*` `4,`
     `kernel_size=3,`
     `stride=1,`
     `padding=1,`
  `)`
`#如果scale==4需要两次2倍上采样`
`if` `scale` `== 4:`
 `self.upconv2` `= nn.Conv2d(`
    `mid_channels,`
    `mid_channels` `*4,`
    `kernel_size=3,`
    `stride=1,`
    `padding=1,`
 `)`  
`#高分辨率空间上的卷积`
`self.conv_hr` `= nn.Conv2d(`
  `mid_channels,`
  `mid_channels,`
  `kernel_size=3,`
  `stride=1,`
  `padding=1,`
  `#最后一层卷积,把特征图变回RGB图像`
`)`
`self.conv_last` `= nn.Conv2d(`
  `mid_channels,`
 `3,`
 `kernel_size=3,`
 `stride=1,`
 `padding=1,`
`)`
`def` `compute_flows(self, x):`
`#计算相邻帧之间的光流`
`#输入`
`x shape =` `[B, T, 3, H ,W]`
`返回` `flow_forward:`
`用于正向传播`
`flows_forward[:, i-1]表示第i帧` `->` `第i-1帧的光流`
`用它可以把过去帧特征wrap到当前帧`
`shape=[B,T-1, 2, H, W]`
`flows_backward:`
 `用于反向传播,flows_backward[:,i]表示第i帧` `第i` `+ 1帧的光流`
 `用它可以把未来帧特征warp到当前帧`
 `shape` `= [B,T-1,2,H,W]`
 `#取出输入视频的维度`
 `b,t,c,h,w =` `x.size()`
 `#如果只有1帧,就没有相邻帧光流`
 `if t` `<=1:`
  `empty` `= x.new_zeros(b, 0, ,2, h, w)`
  `return empty, empty`

`#存放反向传播需要的光流`
`flows_backward` `=` `[]`
`#对于反向传播,需要从未来帧传播到的当前帧`
`#warp` `future` `feature` `到当前帧时,需要当前帧` `->未来帧` `的光流`
`for` `i in range(t - 1):`
`#计算第i帧到第i` `+` `1帧的光流`
`flow_i_to_next =` `self.flow_net(`
`x[:, i, :, :, :],`
`x[:, i + 1, :, :, :],`
`)`
`#保存光流`
`flows_backward.append(flow_i_to_next)`
`#把list堆叠成tensor`
`#shape` `= [B, T-1, 2, H, W]`
`flows_backward=torch.stack(`
   `flows_backward,`
   `dim=1,`
`)`
`#存档正向传播需要的光流`
`flows_forward=[]`
`#对于正向传播,需要从过去帧传播到当前帧`
`#warp` `past feature 到当前帧时,需要当前帧-》过去帧的光流`
`for i in range(1, t):`
`#计算第i帧到第i-1帧的光流`
 `flow_i_to_prev` `= self.flow_net(`
 `x[:, i, :, :, :],`
` x[:, i - 1, :, :, :],`
 `)`
 `#保存光流`
 `flows_forward.append(flow_i_to_prev)`
 `#shape = [B, T-1, 2, H, W]`
 `flow_forwards` `= torch.stack(`
 `flows_forward,`
 `dim=1`
 `)`
 `#返回正向传播光流和反向传播光流`
 `return` `flows_forward,` `flows_backward`
 
 `def` `upsample(self, feat):`
 `根据scale对重建特征进行上采样`
 `输入:` `feat shape` `= [B, C, H, W]`
 `输出` `scale=1`
  `out shape` `= [B, 3, H, W]`
  `scale=2`
  `out` `shape=[B, 3, 2H, 2W]`
  `scale=4:`
  `out` `shape=` `[B, 3, 4H, 4W]`
  `#如果是2倍或4倍超分,先做一次2倍PixelShuffl`
  `eif self.scale` `== 2:`
   `#卷积把通道扩展到4倍`
   `feat = self.upconv1(feat)`
   
   `#pixelShuffle把通道转换为空间分辨率`
   `feat = self.pixel_shuffle(feat)`
   `#激活`
   `feat =` `self.lrelu(feat)`
   `#如果是4倍超分,需要做两次2倍PixelShuffle`
   `elif self.scale ==` `4`
   `:#第一次2倍上采样`
   `feat` `= self.upconv1(feat)`
   `feat = self.pixel_shuffle(feat)`
   `feat = self.lrelu(feat)`

`#第二次2倍上采样`
`feat = self.upconv2(feat)`
`feat = self.pixel_shuffle(feat)`
`feat = self.lrelu(feat)`
`#高分辨率卷积`
`feat = self.conv_hr(feat)`
`#激活`
`feat = self.lrelu(fea)t`
`#输出RGB残差图像`
`out = self.conv_last(feat)`
`#返回输出`
`return out`

`def get_base_frame(self, lr_frame):`
`获取残差链接里的base image`
`#对于scale=1`
`base就是原输入帧`
`对于scale=2或scale=4`
`bas是双线性循环放大后的输入帧`
`最终输出`
`enhanced = predicted_residual` `+ base`
`#如果不做超分,直接返回原图`
`if self.scale` `== 1:`
  `return` `lr_frame`
  `#如果做超分,吧低清晰度输入双线性插值放大`
`base = F.interpolate(`
`lr_frame,`
`scale_factor=self.scale,`
`mode` `= "bilinear",`
`align_corners=False`
`)`
`#返回base` `frame`
`return base`

`def forward(self, x):`
` """`
`        前向传播。`

`        输入:`
`            x shape = [B, T, 3, H, W]`

`        输出:`
`            out shape = [B, T, 3, H*scale, W*scale]`
`        """`
`#检查输入必须是5倍`
`if x.dim()` `!= 5`
`#取出输入视频的维度`
`b,t,c,h,w` `= x.size()`
`检查必须是RGB视频`

`#每帧CNN特征提取`
`#把[B,T,3,H,W]reshape成[B*t, 3, H ,W]`
`#这样可以一次性把所有帧送进CNN`
`x_reshape = x.reshape(b*t, c, h, w)`
`#提取每帧空间特征`
`feats` `= self.feat_extract(x_reshape)`
`#把特征reshape回视频序列形式`
`#shape` `= [B,T,mid_channels, H,W]`
`feats =` `feats.reshape(`
`b,t,self.mid_channels,`
`h,w`
`)`
`#计算相邻帧光流`
`flows_forward用于正向传播`
`flows_backward` `用于反向传播`
`flows_forward, flows_backward` `= self.compute_flows(x)`
`#反向时间传播` `从T-1帧传播到第0帧`
`#用list存放每一帧反向传播特征`
` backward_feats = [None] * t`
 `#初始化传播特征为全0`
 `#shape =` `[B, mid_channels, H, W]`
 `feat_prop = x.new_zeros(b, self.mid_channels,h,w)`
`#从最后一帧住第一帧遍历`
`for i in range(t - 1, -1, -1):`
 `#如果不是最后一帧,就需要把未来帧传播特征warp到当前帧`
  `if` `i <` `t - 1:`
    `#flows_backward[:, i]是第i帧->第i+1帧的光流`
    `#用它可以把第i + 1帧的传播特征对齐到第i帧`
 ` feat_prop = flow_warp` `(`
      `feat_prop,`
      `flows_backward[:, i,:,:,:],`
 ` )`
`#当前帧特征`
`curr_feat` `= feats[L,i,:,:,:]`
`#拼接当前帧特征和传播特征`
`#shape` `= [B, mid_channel *2, H, W]`
`feat_input = torch.cat(`
   `[curr_feat, feat_prop],`
   `dim = 1,`
`)`
`#通过反向传播网络更新传播特征`
`#shape` `= [B, mid_channels * 2, H, W]`
`feat_input = torch.cat(`
   `[curr_feat, feat_prop],`
   `dim=1,`
`)`
`#通过反向传播网络更新传播特征`
`feat_prop = self.backward_trunk(feat_input)`
`#保存第i帧对应的反响传播特征`
`backward_feats[i]` `= feat_prop`
`#4` `正向时间传播,从第0帧传播到T-1帧`
`#用list存放每一帧的正向传播特征`
`forwards_feats =` `[None]*t`
`#初始化正向传播特征为全0`
`feat_prop = x.new_zeros(b,self.mid_channels, h,w)`

`#从第一帧往后一帧遍历`
`for i in range(t):`
`#如果不是第一帧,就需要把过去帧传播特征warp到当前帧`
   `if` `i > 0`
   `#dlows` `forward[:, i-1]是第i帧` `第i-1帧的光流`
   `#用它可以把第i-1帧的传播特征对其道第i帧`
   `feat_prop` `= flow_warp(`
       `feat_prop,`
       `flow_forward[:,i-1,:,:,:]`   
   `)`
   `#当前帧特征`
   `curr_feat = feats[:,i,:,:,:]`
   `#拼接当前帧特征和正向传播特征`
   `feat_input = torch.cat(`
    `[curr_feat, feat_prop],`
    `dim=1,`
   `)`
   `#通过正向传播网更新传播特征`
   `feat_prop = self.forward_trunk(feat_input)`
   `#保存第i帧对应的正向传播特征`
   `forward_feats][i = feat_prop`
   `#融合当前帧特征,反向传播特征,正向传播特征,` `并重建输出帧`
   `#存放所有输出帧`
   `outs=[]`
   `#对每一帧分别重建`
   `for i in` `range()t:`
   `#当前帧的原始空间特征`
    `curr_feat = feats[:,i,:,:,:]`
`   #当前帧的反向传播特征`
   `backward_feat =` `backward_feats[i]`
   `#当前帧的正向传播特征`
   `forward_feat = forward_feats[i]`
   `#三类特征拼接`
   `#shape = [B,mid_channels * 3, H,W]`
   `feat = torch.cat(`
     `[currefeat, backward_feat, forward_feat],`
     `dim=1,`
   `)`
   `#通过重建网络重建特征`
   `feat` `= self.reconstruction(feat)`
   `#根据scale输出RGB残差图像`
   `out = self.upsample(feat)`
   `#获取base frame`
   `#scale=1` `时就是原输入帧`
   `#scale= 2/4` `时是双线性插值放大后的输入帧`
   `base = self.get_base_frame(`
     `x[:,i,:,:,:]`
   `)`
   `#残差学习,最终输出` `网络预测残差` `+` `base`
   `out = out +` `base`
   `#保存当前输出帧`
   `outs.append(out)`
   `#把list里的每一帧堆叠回视频序列`
   `#shape` `= [B,T,3, H *scale, W*scale]`
   `outs = torch.stack(outs, dim=1)`
   `#返回增强后的视频帧序列`
   `return outs`
   
   `if __name__ == "__main__":`
   `简单测试代码`
   `直接运行` `python basicvsr_lite.py`
   `#构造一个BasicVSR-lite模型`
   `#scale=1` `表示输入输出同分辨率`
   `model = BasicVSRLite(`
    `mid_channels=64,`
    `num_feature_blocks=5,`
    `num_propagation_blocks=7,`
    `num_reconstruction_blocks=10,`
    `scale=1,`
   `)`
   `#构造一个假的输入视频batch`
   `#B=2,T=7,C=3,H=64,W=64`
   `x = torch.randn(2,7,3,64,64)`
`#前向传播`
`y = model(x)`
`#打印输入输出尺寸`
`#训练时一般这样计算损失`
`#假设gt是清晰视频帧,shape和y一样`
`gt = torch.randn_like(y)`
`#视频增强/超分常用L1Loss`
`loss = F.l1_loss(y, gt)`
`

二 模型总结

复制代码
这个模型是一个教学版BsicVSR-lite视频增强模型。核心目的不会单张图片,而是处理一段连续视频帧,利用前后帧信息增强当前帧。`

`输入一段低质量视频帧`
`提取每一帧的CNN特征`
`估计相邻之间的光流`
`用光流吧前后帧特征对齐`
`做正向和反向时序信息传播`
`融合当前帧,过去帧,未来帧信息。`
`1、这个模型能做什么`
`视频去噪`
`视频去模糊`
`视频压缩伪影修复`
`视频画质增强`
`视频超分辩率`
`由scale控制任务类型`
`scale=1`
`表示输入输出同尺寸,适合葡萄视频增强`
`scale=2`
`表示2倍视频超分`
`scale=4`
`表示4倍视频超分`

`2输入输出格式`
`x.shape =` `[B,T,3,H,W]`
`含义是`
`B` `= batch_size` `一次训练几个视频片段`
`T` `=` `每个视频片段有多少帧`
`3` `= RGB通道数量`
`H` `图像高度`
`W` `图像宽度`
`x` `= torch.randn(2, 7, 3, 64, 64)`
`表示`
`2` `个视频片段`
`每个片段7帧`
`每帧是RGB图像`
`每帧大小64x64`
`如果scale=1` `输出是`
`y.shape =` `[B,T,3,H,W]`
`如果scale=4` `输出是`
`y.shape=[B,T,3,4H,4W]`
`3` `模型整体网络结构`
`class` `BasicVSRLite(nn.Module)`
`里面主要有这些模块`
`feat_extract` `每帧CNN特征提取`
`flow_net` `简化光流估计网络`
`backward_trunk` `反向时间传播网络`
`forward_trunk 正向时间传播网络`
`reconstruction` `特征融合和重建网络`
`upsample` `超分上采样模块`
`residual` `base残差输出链接`

`整体结构可以画成`
`输入视频x:[B,T,3,H,W]`
`每帧CNN特征提取`
`feats` `[B,T,64,H,W]`
`计算相邻帧光流`
`反向传播,T-1->0` `得到未来信息backward_feats`
`正向传播` `0->T-1` `得到过去信息forward_feats`
`每一帧融合`
`当前帧特征` `过去信息` `未来信息`
`重建网络`
`输出增强视频`

`4` `flow_warp是干什么的?`
`def` `flow_warp(x, flow)`
`作用是,根据光流把特征图进行空间对齐`
`视频里物体会运动,`
`第一帧` `人脸在左边`
`第二帧,人脸在中间`
`第三针,人脸在右边`
`直接把这些帧的特征融合,会无法对齐,结果容易模糊`
`需要先用光流估计运动`
`这个额像素从上一帧移动到了哪里`
`这个特征应该往左还是往右移动`
`F.grid_sample(...)`
`重新采样特征图,吧前后特征对齐到当前帧`
`flow_warp` `=` `根据运动信息移动特征图`
`5` `ResidualBlockNoBN是什么`
`这是一个不带BatchNorm的残差块`
`class ResidualBlockNoBN(nn.Module)`
`结构是`
`输入x`
`Conv` `3x3`
`LeakyReLU`
`Conv` `3x3`
`加回输入x`
`输出`
`对应代码`
`identity=x`
`out =` `self.conv1()x`
`out =` `self.relu(out)`
`out` `= self.conv2(out)`
`out =` `identity` `+` `out *` `self.res_scale`
`作用是增强特征表达能力`
`为什么没有BatchNorm 归一化一下?`
`因为视频增强超分,去噪这类图像恢复任务需要保留非常惊喜的像素信息,BatchNorm有时会破坏图像的亮度,颜色,纹理分布,所以很多恢复模型不用BN`
`6` `TinyFlowNet是什么?`
`class TinyFlowNet(nn.Module)`
`这是一个简化版本光流估计网络`
`输入两帧图像`
`img_ref.shape=[B,3,H,W]`
`img_supp.shape` `= [B,3, H,W]`
`先拼接成`
`[B,6,H,W]`
`然后经过一个小CNN,输出`
`flow.shape = [B,2,H,W]`
`其中`
`flow[:, 0,:,)]` `= x方向位移`
`flow]:,1,:,:[` `= y` `方向位移`
`这个TInyFlowNet是教学版,不是官方BasicVSR里的SpyNet,`
`7每帧特征提取feat_extract`
`这一段代码`
`self.feat_extract` `= nn.Sequential(*feature_layers)`
`作用把每一帧RGB图像变成CNN特征图`
`输入一帧`
`[B,3,H,W]`
`输出` `[B,64,H,W]`
`mid_channels =` `64`
`所以每帧被转换成64通道的空间特征`
`注意这里不是输出一个向量,保留二维空间结构`

`图片` `CNN` `512维向量`
`图片` `CNN` `64通道特征图`
`因为视频增强最终需要回复图像,不能把空间信息压扁。`
`8` `compute` `flows计算什么?`
`flows_forward, flows_backward = self.compute_flos(x)`
`这个函数计算相邻帧之间的光流`
`flows_backward`
`flows_forward`

`flows_backward用于反向传播,就是从未来帧往当前帧传信息`
`代码里面计算是这样的`
`flow_i_to_next = self.flow_net` `(`
   `x[:, i]`
   `x[:, i + 1]`
`)`
`用它可以把低i+1帧的特征warp到第i帧。`
`flows_forward`
`用于正向传播,从过去帧往当前帧传信息`
`flow_i_to_prev` `= self.flow_net(`
`x[:, i]`
`   x[:, i + 1]`
`)`
`含义是`
`第i帧到第i+1帧的光流`
`9反向时间传播`
`for i in` `range(t-1, -1, -1)`
`从最后一帧往第一帧处理`
`T-1,` `T-2 ,0`
`目的`
`让当前帧获得未来帧的信息`
`例如处理第3帧时,利用第4,5,6帧传过来的信息`
`初始化feat_prop = 0`
`从最后一帧开始`
 `如果不是最后一帧`
   `用光流warp` `未来帧传来的feat_prop`
   `拼接`
     `当前帧特征curr_feat`
   `未来传播特征` `feat_prop`
   `送入backward_trunk`
   `得到新的feat_prop`
   `保存为backward_feats[i]`
`核心代码`
`feat_prop =` `flow_warp` `(`
 `feat_prop,flows_backward[:,i]`
`)`
`feat_input= torch.cat(`
`]curr_feat, feat_prop[`
`dim=1`
`)`
`feat_prop=self.backward_trunk(feat_input)`
`10` `正向时间传播`
`for i in range(t)`
`例如处理第3帧时,可以利用第0,1,2帧传过来的信息`
`初始化feat_prop` `= 0`
`从第一帧开始`
 `如果不是第一帧`
   `用光流warp过去帧传来的feat_prop`
   `拼接`
    `当前帧特征curr_feat`
     `过去传播特征` `feat_prop`
      `送入forward_tru`
      `nk得到新的feat_prop`
  `保存为forward_feats[i]`
  `核心代码`
  `feat_prop = flow_warp(`
`    feat_prop,`
`    flows_forward[:, i - 1]`
`)`

`feat_input = torch.cat(`
`    [curr_feat, feat_prop],`
`    dim=1`
`)`

`feat_prop = self.forward_trunk(feat_input)`
`11 最终融合重建`
`1curr_feat当前帧呢自己的特征`
`2` `backward_feat 未来帧传来的信息`
`3` `forward_feat过去帧传来的信息`
`然后拼接`
`feat = torch.cat([curr_feat, backward_feat, forward_feat],`
`dim=1)`
`如果mid_channels=64` `那么`
`curr_feat` `= 6通道`
`backward_feat = 64通道`
`forward_feat = 64通道`
`拼接后` `=` `192通道`
`送入重建网络`
`feat = self.reconstruction(fe)at`
`重建网络把192通道融合成64通道,`
`out = self.upsample(feat)输出RGB图像`
`12 upsample是什么`
`def upsample(self, feat)`
`负责吧特征图转换成最终RGB输出`
`如果scale=1`
`不做放大`
`[B,64,H,W]->[B,3,H,W]`
`如果scale=2`
`Conv -> PixelShuffle x2 -> Conv -> RGB`
`[B,64,H,W] -> [B,3,2H,2W]`
`如果scale=4`
`Conv -> PixelShuffle x2`
`Conv -> PixelShuffle x2`
`Conv -> RGB`

`[B,64,H,W] -> [B,3,4H,4W]`
`PixelShuffle是超分模型里常用的上采样方式,通道维的信息重新排列空间维上。`
`13 为什么最后要out + base`
`base = self.get_base_frame(x[:, i])`
`out = out + base`
`残差学习`
`如果scale` `= 1`
`base` `= 原始输入帧`
`输出=原始输入帧` `+ 模型预测的修正量`
`如果scale=4`
`base = 双线性插值方法后的输入帧`
`输出=放大后的输入帧` `+` `模型预测的高频细节`
`因为模型从不从零生成整张图,`
`哪里需要去噪` `哪里需要变清晰` `哪里需要补充细节`  `哪里需要修复伪影`
`14 forward函数完整流程`
`你的forward可以总结成`
`1检查输入维度`
`2 把视频帧展开`
 `[B,T,3,H,W]->]B*T,3,H,W[`
`3 每帧CNN特征提取`
`]B*T, 3,H,W[->[B*T,64,H,W]`
`4` `reshape回视频形式`
`[B*T,64,H,W]->[B,T,64,H,W]`
`5` `计算相邻帧光流`
`flows_forward = [B,T-1,2,H,W]`
`flows_backward=[B,T-1,2,H,W]`
`6` `反向时间传播`
 `得到backward_feats`
 `7 正向时间传播`
 `得到forward_feat`
 `s8 对每一帧`
 `当前帧呢特征` `未来信息` `+过去信息`
 `reconstruction`
 `upsample`
 `加base` `frame`
 `输出增强帧`
 `9` `把所有输出帧` `stack`
 `[B,T,3,H*scale,W*scale]`
 `15 模型训练时怎么用`
 `普通视频增强`
 `model` `= BasicVSRLite(scale=1).cuda()`
 `lq = torch.randn(2,7,3,128,128).cuda(`
 `)gt = torch.randn(2,7,3,128,128).cuda()`
 `pred = model(lq)`
 `loss = F.l1_loss(pred, gt)`
 `loss.backward()`
 `倍视频超分`
 `model = BasicVSRLite(scale=4).cuda()`
`lq = torch.randn(2, 7, 3, 64, 64).cuda()`
`gt = torch.randn(2, 7, 3, 256, 256).cuda()`
`pred = model(lq)`
`loss = F.l1_loss(pred, gt)`
`loss.backward()`
`其中,`
`lq = low quality低质量视频帧`
`gt = ground truth` `高质量视频帧`
`16这个模型的学习重点`
`16.1 CNN特征提取`
`self.feat_extract`
`16.2` `光流对齐`
`TinyFlowNet` `+ flow_warp`
`16.3 双向时间传播`
`backwrd_trunk + forward_trunk`
`16.4` `融合重建`
`reconstruction` `+ upsample + out +base`
`当前帧不清楚的地方,可以从前后帧里找信息补回来`
`17 需要注意的地方`
`1. TinyFlowNet` `光流效果不一定好`
`2. 没有使用预训练SpyNet /` `PWC`
`Net美元BasicVSR+的二阶传播`
`4没有EDVR的可变性卷积对齐`
`5` `只用了基础L1Loss视觉锐度可能一般`



































`
相关推荐
huangdong_1 小时前
1688商品图片采集技术解析:登录态处理与SKU图自动分类
开发语言
chase_my_dream1 小时前
C++ + SLAM 高频面试问题整理
开发语言·c++·面试
Cloud_Shy6181 小时前
解读《Effective Python 3rd Edition》:从练气到老魔(第五章 Item 30 - 32)
开发语言·人工智能·笔记·python·学习方法
天佑木枫2 小时前
15天Python入门系列 · 序
开发语言·python
宋拾壹3 小时前
同时添加多个类目
android·开发语言·javascript
凡人叶枫3 小时前
Effective C++ 条款04:确定对象被使用前已先被初始化
java·linux·开发语言·c++·嵌入式开发
小小龙学IT4 小时前
Go 语言后端开发:从并发模型到生产落地的工程实践
开发语言·后端·golang
ytttr8734 小时前
Qt 数字键盘实现
开发语言·qt
wearegogog1234 小时前
C# .NET 文件比较工具 WinForms
开发语言·c#·.net
再写一行代码就下班4 小时前
Cursor配置Java环境、创建Spring Boot项目的步骤
java·开发语言·spring boot