(即插即用模块-特征处理部分) 十九、(NeurIPS 2023) Prompt Block 提示生成 / 交互模块

文章目录

  • 1、Prompt Block
  • 2、代码实现

paper:PromptIR: Prompting for All-in-One Blind Image Restoration

Code:https://github.com/va1shn9v/PromptIR


1、Prompt Block

在解决现有图像恢复模型时,现有研究存在一些局限性: 现有的图像恢复模型通常针对特定的退化类型(如去噪、去雾、去雨)进行训练,这会缺乏泛化能力,难以适应多种退化类型和级别。此外,现有的多退化图像恢复模型通常需要知道输入图像的退化类型,才能选择合适的模型进行恢复,这在实际应用中都是不太现实的。最后,现有的多退化图像恢复模型需要为每种退化类型和级别训练单独的模型,这会导致训练负担过重,且难以在资源受限的平台(如移动设备和边缘设备)上部署。为此,这篇论文提出一种 Prompt Block,其通过引入可学习的提示参数,将退化相关的信息编码到网络中,从而引导网络进行自适应的图像恢复。

Prompt Block 可以分为两个部分:即 Prompt Generation Module(PGM)Prompt Interaction Module(PIM)。具体来说,PGM 的目标是根据输入图像的特征动态生成 prompt 参数,使其能够更好地适应不同的退化类型。而 PIM 通过将 prompt P 与输入特征沿通道维度进行拼接,然后通过 Transformer block 进行处理,实现特征与 prompt 的交互。

对于一个输入特征 X,Prompt Block 的实现过程:

Prompt Generation Module:

  1. 对输入特征进行全局平均池化 (GAP),得到特征向量 v。
  2. 使用 1x1 卷积层对特征向量进行降维,得到紧凑的特征向量。
  3. 对降维后的特征向量进行 softmax 操作,得到 prompt 权重 w。
  4. 使用 prompt 权重 w 对 prompt 组件 Pc 进行加权求和,得到输入条件 prompt P。

Prompt Interaction Module:

  1. 首先将 prompt P 与输入特征 Fl 沿通道维度进行拼接。
  2. 将拼接后的特征通过 Transformer block 进行处理。
  3. 最后将特征经两层卷积处理,输出特征即为经过 Prompt Block 调整后的特征。

Prompt Generation / Interaction Module 结构图:

2、代码实现

python 复制代码
import torch
from torch import nn, einsum
import torch.nn.functional as F


class PromptGenBlock(nn.Module):
    def __init__(self, prompt_dim, prompt_len=5, prompt_size=96, lin_dim=192):
        super(PromptGenBlock, self).__init__()
        self.prompt_param = nn.Parameter(torch.rand(1, prompt_len, prompt_dim, prompt_size, prompt_size))
        self.linear_layer = nn.Linear(lin_dim, prompt_len)
        self.conv3x3 = nn.Conv2d(prompt_dim, prompt_dim, kernel_size=3, stride=1, padding=1, bias=False)

    def forward(self, x):
        B, C, H, W = x.shape
        emb = x.mean(dim=(-2, -1))
        prompt_weights = F.softmax(self.linear_layer(emb), dim=1)
        prompt = prompt_weights.unsqueeze(-1).unsqueeze(-1).unsqueeze(-1) * self.prompt_param.unsqueeze(0).repeat(B, 1,
                                                                                                                  1, 1,
                                                                                                                  1,
                                                                                                                  1).squeeze(
            1)
        prompt = torch.sum(prompt, dim=1)
        prompt = F.interpolate(prompt, (H, W), mode="bilinear")
        prompt = self.conv3x3(prompt)

        return prompt


if __name__ == '__main__':
    x = torch.randn(4, 3, 64, 64).cuda()
    model = PromptGenBlock(3, lin_dim=3).cuda()
    out = model(x)
    print(out.shape)
相关推荐
童话名剑16 分钟前
锚框 与 完整YOLO示例(吴恩达深度学习笔记)
笔记·深度学习·yolo··anchor box
wyw000029 分钟前
目标检测之Faster R-CNN
计算机视觉
Hcoco_me2 小时前
大模型面试题62:PD分离
人工智能·深度学习·机器学习·chatgpt·机器人
OpenCSG2 小时前
AgenticOps 如何重构企业 AI 的全生命周期管理体系
大数据·人工智能·深度学习
All The Way North-3 小时前
RNN基本介绍
rnn·深度学习·nlp·循环神经网络·时间序列
yatingliu20193 小时前
将深度学习环境迁移至老旧系统| 个人学习笔记
笔记·深度学习·学习
撬动未来的支点3 小时前
【AI】光速理解YOLO框架
人工智能·yolo·计算机视觉
kebijuelun3 小时前
REAP the Experts:去掉 MoE 一半专家还能保持性能不变
人工智能·gpt·深度学习·语言模型·transformer
ldccorpora3 小时前
Multiple-Translation Arabic (MTA) Part 2数据集介绍,官网编号LDC2005T05
人工智能·深度学习·自然语言处理·动态规划·语音识别