本文是对DISTS图像质量评价指标的代码解读,原文解读请看DISTS文章讲解。
本文的代码来源于IQA-Pytorch工程。
1、原文概要
以前的一些IQA方法对于捕捉纹理上的感知一致性有所欠缺,鲁棒性不足。基于此,作者开发了一个能够在图像结构和图像纹理上都具有与人类相同感知判断的指标,在此之上,还希望纹理能够resample(不需要像素级对齐)之后也是一样的,另外区分开退化(JPEG,JPEG会损失纹理)。实现该指标可以分为4个步骤:
- 对图像进行一个初始的变换,从像素空间变换到特征空间。
- 对特征提取所谓纹理的表示,对特征提取所谓结构的表示。
- 利用纹理和结构的表示,加入一些可学习的权重综合计算一个评价指标。
- 利用这个评价指标,进一步优化权重得到纹理区域resample不敏感的指标,且能够有结构和纹理上做感知相似度的模型。
实现后的指标作为优化指标对比其他IQA指标有明显优势,如下图所示。
2、代码结构
代码实现位于pyiqa/archs/dists_arch.py中:
3 、核心代码模块
L2pooling
类
这个类实现了我们前面提到的预处理部分替换max-pool的操作。
python
class L2pooling(nn.Module):
def __init__(self, filter_size=5, stride=2, channels=None, pad_off=0):
super(L2pooling, self).__init__()
self.padding = (filter_size - 2) // 2
self.stride = stride
self.channels = channels
a = np.hanning(filter_size)[1:-1]
g = torch.Tensor(a[:, None] * a[None, :])
g = g / torch.sum(g)
self.register_buffer(
'filter', g[None, None, :, :].repeat((self.channels, 1, 1, 1))
)
def forward(self, input):
input = input**2
out = F.conv2d(
input,
self.filter,
stride=self.stride,
padding=self.padding,
groups=input.shape[1],
)
return (out + 1e-12).sqrt()
这里可以看到前向的过程中作者先是进行了一个平方,然后使用了一个self.filter的滤波器,kernel_size为3的hanning窗,stride=2,且是一个深度可分离的卷积,groups与输入通道一致,这代替max-pool完成了一次抗混叠的下采样,最后进行一个sqrt,这与讲解中展示的公式一致,如下所示:
P ( x ) = g ∗ ( x ∗ x ) P(x)=\sqrt{g*(x*x)} P(x)=g∗(x∗x) 这个 g g g在初始化时被复制了self.channels次,实际它一个通道的数值,读者可以打印如下所示:
0.0625 0.125 0.0625 0.125 0.25 0.125 0.0625 0.125 0.0625 \] \\begin{bmatrix} 0.0625 \& 0.125 \& 0.0625 \\\\ 0.125 \& 0.25 \& 0.125 \\\\ 0.0625 \& 0.125 \& 0.0625 \\end{bmatrix} 0.06250.1250.06250.1250.250.1250.06250.1250.0625 一个典型的低通滤波器,做了一个空间上根据距离的平均。 #### `DISTS` 类 存放着跟实际计算指标相关的代码。 ```python @ARCH_REGISTRY.register() class DISTS(torch.nn.Module): r"""DISTS model. Args: pretrained_model_path (String): Pretrained model path. """ def __init__(self, pretrained=True, pretrained_model_path=None, **kwargs): """Refer to official code https://github.com/dingkeyan93/DISTS""" super(DISTS, self).__init__() vgg_pretrained_features = models.vgg16(weights='IMAGENET1K_V1').features self.stage1 = torch.nn.Sequential() self.stage2 = torch.nn.Sequential() self.stage3 = torch.nn.Sequential() self.stage4 = torch.nn.Sequential() self.stage5 = torch.nn.Sequential() for x in range(0, 4): self.stage1.add_module(str(x), vgg_pretrained_features[x]) self.stage2.add_module(str(4), L2pooling(channels=64)) for x in range(5, 9): self.stage2.add_module(str(x), vgg_pretrained_features[x]) self.stage3.add_module(str(9), L2pooling(channels=128)) for x in range(10, 16): self.stage3.add_module(str(x), vgg_pretrained_features[x]) self.stage4.add_module(str(16), L2pooling(channels=256)) for x in range(17, 23): self.stage4.add_module(str(x), vgg_pretrained_features[x]) self.stage5.add_module(str(23), L2pooling(channels=512)) for x in range(24, 30): self.stage5.add_module(str(x), vgg_pretrained_features[x]) for param in self.parameters(): param.requires_grad = False self.register_buffer( 'mean', torch.tensor([0.485, 0.456, 0.406]).view(1, -1, 1, 1) ) self.register_buffer( 'std', torch.tensor([0.229, 0.224, 0.225]).view(1, -1, 1, 1) ) self.chns = [3, 64, 128, 256, 512, 512] self.register_parameter( 'alpha', nn.Parameter(torch.randn(1, sum(self.chns), 1, 1)) ) self.register_parameter( 'beta', nn.Parameter(torch.randn(1, sum(self.chns), 1, 1)) ) self.alpha.data.normal_(0.1, 0.01) self.beta.data.normal_(0.1, 0.01) if pretrained_model_path is not None: load_pretrained_network(self, pretrained_model_path, False) elif pretrained: load_pretrained_network(self, default_model_urls['url'], False) def forward_once(self, x): h = (x - self.mean) / self.std h = self.stage1(h) h_relu1_2 = h h = self.stage2(h) h_relu2_2 = h h = self.stage3(h) h_relu3_3 = h h = self.stage4(h) h_relu4_3 = h h = self.stage5(h) h_relu5_3 = h return [x, h_relu1_2, h_relu2_2, h_relu3_3, h_relu4_3, h_relu5_3] def forward(self, x, y): r"""Compute IQA using DISTS model. Args: - x: An input tensor with (N, C, H, W) shape. RGB channel order for colour images. - y: An reference tensor with (N, C, H, W) shape. RGB channel order for colour images. Returns: Value of DISTS model. """ feats0 = self.forward_once(x) feats1 = self.forward_once(y) dist1 = 0 dist2 = 0 c1 = 1e-6 c2 = 1e-6 w_sum = self.alpha.sum() + self.beta.sum() alpha = torch.split(self.alpha / w_sum, self.chns, dim=1) beta = torch.split(self.beta / w_sum, self.chns, dim=1) for k in range(len(self.chns)): x_mean = feats0[k].mean([2, 3], keepdim=True) y_mean = feats1[k].mean([2, 3], keepdim=True) S1 = (2 * x_mean * y_mean + c1) / (x_mean**2 + y_mean**2 + c1) dist1 = dist1 + (alpha[k] * S1).sum(1, keepdim=True) x_var = ((feats0[k] - x_mean) ** 2).mean([2, 3], keepdim=True) y_var = ((feats1[k] - y_mean) ** 2).mean([2, 3], keepdim=True) xy_cov = (feats0[k] * feats1[k]).mean( [2, 3], keepdim=True ) - x_mean * y_mean S2 = (2 * xy_cov + c2) / (x_var + y_var + c2) dist2 = dist2 + (beta[k] * S2).sum(1, keepdim=True) score = 1 - (dist1 + dist2) return score.squeeze(-1).squeeze(-1) ``` 3个重点如下: 1. 初始化中首先会插入前面讲到的L2_Pooling,来替换原始的max-pool,其他的就是初始化必要的标准化变量和用于各层结构和纹理的加权系数 α \\alpha α和 β \\beta β,最后导入预训练的网络即可。 2. 前向中调用的forward_once,可以看到总共有6个输出,第一个输出是输入x,即我们讲解中提到的identity的变换,其他5层是事先定义好的输出位置。 3. dists的计算:首先根据权重的大小对alpha和beta进行归一化,随后分层计算我们前面定义好的纹理特征和结构特征的相关性公式,针对于纹理的部分代码中是S1,可以看到S1是利用了特征的在空间上的均值计算的参考图像和待评估图像的相关系数,然后利用alpha对计算好的S1进行加权,得到纹理上相似度dist1;针对于结构的部分代码中是S2,S2是利用了参考图像和待评估图像两个特征的协方差和方差,由于是全局的窗口所以在计算后会求取空间上的一个均值,这样得到了结构上的相似度dist2。最后结合dist1和dist2得到最终的score。dists计算的公式如下,可以对照着公式来查看: l ( x \~ j ( i ) , y \~ j ( i ) ) = 2 μ x \~ j ( i ) μ y \~ j ( i ) + c 1 ( μ x \~ j ( i ) ) 2 + ( μ y \~ j ( i ) ) 2 + c 1 l(\\tilde{x}_j\^{(i)}, \\tilde{y}_j\^{(i)}) = \\frac{2\\mu_{\\tilde{x}_j}\^{(i)}\\mu_{\\tilde{y}_j}\^{(i)} + c_1}{(\\mu_{\\tilde{x}_j}\^{(i)})\^2 + (\\mu_{\\tilde{y}_j}\^{(i)})\^2 + c_1} l(x\~j(i),y\~j(i))=(μx\~j(i))2+(μy\~j(i))2+c12μx\~j(i)μy\~j(i)+c1 s ( x \~ j ( i ) , y \~ j ( i ) ) = 2 σ x \~ j y \~ j ( i ) + c 2 ( σ x \~ j ( i ) ) 2 + ( σ y \~ j ( i ) ) 2 + c 2 , s(\\tilde{x}_j\^{(i)}, \\tilde{y}_j\^{(i)}) = \\frac{2\\sigma_{\\tilde{x}_j\\tilde{y}_j}\^{(i)} + c_2}{(\\sigma_{\\tilde{x}_j}\^{(i)})\^2 + (\\sigma_{\\tilde{y}_j}\^{(i)})\^2 + c_2}, s(x\~j(i),y\~j(i))=(σx\~j(i))2+(σy\~j(i))2+c22σx\~jy\~j(i)+c2, D ( x , y ; α , β ) = 1 − ∑ i = 0 m ∑ j = 1 n i ( α i j l ( x \~ j ( i ) , y \~ j ( i ) ) + β i j s ( x \~ j ( i ) , y \~ j ( i ) ) ) D(x, y; \\alpha, \\beta) = 1 - \\sum_{i = 0}\^{m} \\sum_{j = 1}\^{n_i} \\left( \\alpha_{ij} l(\\tilde{x}_j\^{(i)}, \\tilde{y}_j\^{(i)}) + \\beta_{ij} s(\\tilde{x}_j\^{(i)}, \\tilde{y}_j\^{(i)}) \\right) D(x,y;α,β)=1−i=0∑mj=1∑ni(αijl(x\~j(i),y\~j(i))+βijs(x\~j(i),y\~j(i)))其中, l l l和 s s s分别代表纹理和结构。 ## 3、总结 代码实现核心的部分讲解完毕,DISTS作为一个可以同时捕获结构和纹理相似度的全参考IQA指标,在很多比赛和论文的引用中都可以见到它的身影,实用性是毋庸置疑的。 大家有涉及到数据集筛选、纹理分类、纹理搜索类的任务可以尝试使用DISTS指标,或者是在算法评估中利用它来做一个方面的对比评估。 *** ** * ** *** **感谢阅读,欢迎留言或私信,一起探讨和交流。 如果对你有帮助的话,也希望可以给博主点一个关注,感谢。**