目录

(即插即用模块-特征处理部分) 十九、(NeurIPS 2023) Prompt Block 提示生成 / 交互模块

文章目录

  • 1、Prompt Block
  • 2、代码实现

paper:PromptIR: Prompting for All-in-One Blind Image Restoration

Code:https://github.com/va1shn9v/PromptIR


1、Prompt Block

在解决现有图像恢复模型时,现有研究存在一些局限性: 现有的图像恢复模型通常针对特定的退化类型(如去噪、去雾、去雨)进行训练,这会缺乏泛化能力,难以适应多种退化类型和级别。此外,现有的多退化图像恢复模型通常需要知道输入图像的退化类型,才能选择合适的模型进行恢复,这在实际应用中都是不太现实的。最后,现有的多退化图像恢复模型需要为每种退化类型和级别训练单独的模型,这会导致训练负担过重,且难以在资源受限的平台(如移动设备和边缘设备)上部署。为此,这篇论文提出一种 Prompt Block,其通过引入可学习的提示参数,将退化相关的信息编码到网络中,从而引导网络进行自适应的图像恢复。

Prompt Block 可以分为两个部分:即 Prompt Generation Module(PGM)Prompt Interaction Module(PIM)。具体来说,PGM 的目标是根据输入图像的特征动态生成 prompt 参数,使其能够更好地适应不同的退化类型。而 PIM 通过将 prompt P 与输入特征沿通道维度进行拼接,然后通过 Transformer block 进行处理,实现特征与 prompt 的交互。

对于一个输入特征 X,Prompt Block 的实现过程:

Prompt Generation Module:

  1. 对输入特征进行全局平均池化 (GAP),得到特征向量 v。
  2. 使用 1x1 卷积层对特征向量进行降维,得到紧凑的特征向量。
  3. 对降维后的特征向量进行 softmax 操作,得到 prompt 权重 w。
  4. 使用 prompt 权重 w 对 prompt 组件 Pc 进行加权求和,得到输入条件 prompt P。

Prompt Interaction Module:

  1. 首先将 prompt P 与输入特征 Fl 沿通道维度进行拼接。
  2. 将拼接后的特征通过 Transformer block 进行处理。
  3. 最后将特征经两层卷积处理,输出特征即为经过 Prompt Block 调整后的特征。

Prompt Generation / Interaction Module 结构图:

2、代码实现

python 复制代码
import torch
from torch import nn, einsum
import torch.nn.functional as F


class PromptGenBlock(nn.Module):
    def __init__(self, prompt_dim, prompt_len=5, prompt_size=96, lin_dim=192):
        super(PromptGenBlock, self).__init__()
        self.prompt_param = nn.Parameter(torch.rand(1, prompt_len, prompt_dim, prompt_size, prompt_size))
        self.linear_layer = nn.Linear(lin_dim, prompt_len)
        self.conv3x3 = nn.Conv2d(prompt_dim, prompt_dim, kernel_size=3, stride=1, padding=1, bias=False)

    def forward(self, x):
        B, C, H, W = x.shape
        emb = x.mean(dim=(-2, -1))
        prompt_weights = F.softmax(self.linear_layer(emb), dim=1)
        prompt = prompt_weights.unsqueeze(-1).unsqueeze(-1).unsqueeze(-1) * self.prompt_param.unsqueeze(0).repeat(B, 1,
                                                                                                                  1, 1,
                                                                                                                  1,
                                                                                                                  1).squeeze(
            1)
        prompt = torch.sum(prompt, dim=1)
        prompt = F.interpolate(prompt, (H, W), mode="bilinear")
        prompt = self.conv3x3(prompt)

        return prompt


if __name__ == '__main__':
    x = torch.randn(4, 3, 64, 64).cuda()
    model = PromptGenBlock(3, lin_dim=3).cuda()
    out = model(x)
    print(out.shape)
本文是转载文章,点击查看原文
如有侵权,请联系 xyy@jishuzhan.net 删除
相关推荐
东木月3 分钟前
基于金融产品深度学习推荐算法详解【附源码】
深度学习·金融·推荐算法
明朝百晓生1 小时前
【PyTorch][chapter-35][MLA]
人工智能·深度学习·transformer
cyong8884 小时前
深度学习中的向量的样子-DCN
人工智能·深度学习
Python数据分析与机器学习5 小时前
《基于深度学习的高分卫星图像配准模型研发与应用》开题报告
图像处理·人工智能·python·深度学习·神经网络·机器学习
从零开始学习人工智能6 小时前
深度学习模型压缩:非结构化剪枝与结构化剪枝的定义与对比
人工智能·深度学习·剪枝
go54631584657 小时前
在办公电脑上本地部署 70b 的 DeepSeek 模型并实现相应功能的大致步骤
深度学习
一个处女座的程序猿O(∩_∩)O8 小时前
人工智能中神经网络是如何进行预测的
人工智能·深度学习·神经网络
小白的高手之路8 小时前
如何安装旧版本的Pytorch
人工智能·pytorch·python·深度学习·机器学习·conda
一头大学牲8 小时前
NN:神经网络
人工智能·深度学习·神经网络
earthzhang20218 小时前
《Python深度学习》第四讲:计算机视觉中的深度学习
人工智能·python·深度学习·算法·计算机视觉·numpy·1024程序员节