(即插即用模块-特征处理部分) 十九、(NeurIPS 2023) Prompt Block 提示生成 / 交互模块

文章目录

  • 1、Prompt Block
  • 2、代码实现

paper:PromptIR: Prompting for All-in-One Blind Image Restoration

Code:https://github.com/va1shn9v/PromptIR


1、Prompt Block

在解决现有图像恢复模型时,现有研究存在一些局限性: 现有的图像恢复模型通常针对特定的退化类型(如去噪、去雾、去雨)进行训练,这会缺乏泛化能力,难以适应多种退化类型和级别。此外,现有的多退化图像恢复模型通常需要知道输入图像的退化类型,才能选择合适的模型进行恢复,这在实际应用中都是不太现实的。最后,现有的多退化图像恢复模型需要为每种退化类型和级别训练单独的模型,这会导致训练负担过重,且难以在资源受限的平台(如移动设备和边缘设备)上部署。为此,这篇论文提出一种 Prompt Block,其通过引入可学习的提示参数,将退化相关的信息编码到网络中,从而引导网络进行自适应的图像恢复。

Prompt Block 可以分为两个部分:即 Prompt Generation Module(PGM)Prompt Interaction Module(PIM)。具体来说,PGM 的目标是根据输入图像的特征动态生成 prompt 参数,使其能够更好地适应不同的退化类型。而 PIM 通过将 prompt P 与输入特征沿通道维度进行拼接,然后通过 Transformer block 进行处理,实现特征与 prompt 的交互。

对于一个输入特征 X,Prompt Block 的实现过程:

Prompt Generation Module:

  1. 对输入特征进行全局平均池化 (GAP),得到特征向量 v。
  2. 使用 1x1 卷积层对特征向量进行降维,得到紧凑的特征向量。
  3. 对降维后的特征向量进行 softmax 操作,得到 prompt 权重 w。
  4. 使用 prompt 权重 w 对 prompt 组件 Pc 进行加权求和,得到输入条件 prompt P。

Prompt Interaction Module:

  1. 首先将 prompt P 与输入特征 Fl 沿通道维度进行拼接。
  2. 将拼接后的特征通过 Transformer block 进行处理。
  3. 最后将特征经两层卷积处理,输出特征即为经过 Prompt Block 调整后的特征。

Prompt Generation / Interaction Module 结构图:

2、代码实现

python 复制代码
import torch
from torch import nn, einsum
import torch.nn.functional as F


class PromptGenBlock(nn.Module):
    def __init__(self, prompt_dim, prompt_len=5, prompt_size=96, lin_dim=192):
        super(PromptGenBlock, self).__init__()
        self.prompt_param = nn.Parameter(torch.rand(1, prompt_len, prompt_dim, prompt_size, prompt_size))
        self.linear_layer = nn.Linear(lin_dim, prompt_len)
        self.conv3x3 = nn.Conv2d(prompt_dim, prompt_dim, kernel_size=3, stride=1, padding=1, bias=False)

    def forward(self, x):
        B, C, H, W = x.shape
        emb = x.mean(dim=(-2, -1))
        prompt_weights = F.softmax(self.linear_layer(emb), dim=1)
        prompt = prompt_weights.unsqueeze(-1).unsqueeze(-1).unsqueeze(-1) * self.prompt_param.unsqueeze(0).repeat(B, 1,
                                                                                                                  1, 1,
                                                                                                                  1,
                                                                                                                  1).squeeze(
            1)
        prompt = torch.sum(prompt, dim=1)
        prompt = F.interpolate(prompt, (H, W), mode="bilinear")
        prompt = self.conv3x3(prompt)

        return prompt


if __name__ == '__main__':
    x = torch.randn(4, 3, 64, 64).cuda()
    model = PromptGenBlock(3, lin_dim=3).cuda()
    out = model(x)
    print(out.shape)
相关推荐
懷淰メ22 分钟前
python3GUI--【AI加持】基于PyQt5+YOLOv8+DeepSeek的智能球体检测系统:(详细介绍)
yolo·目标检测·计算机视觉·pyqt·检测系统·deepseek·球体检测
0***1435 分钟前
React计算机视觉应用
前端·react.js·计算机视觉
阿龙AI日记1 小时前
详解Transformer04:Decoder的结构
人工智能·深度学习·自然语言处理
xier_ran6 小时前
深度学习:生成对抗网络(GAN)详解
人工智能·深度学习·机器学习·gan
海边夕阳20067 小时前
【每天一个AI小知识】:什么是循环神经网络?
人工智能·经验分享·rnn·深度学习·神经网络·机器学习
CV实验室7 小时前
CV论文速递:覆盖视频生成与理解、3D视觉与运动迁移、多模态与跨模态智能、专用场景视觉技术等方向 (11.17-11.21)
人工智能·计算机视觉·3d·论文·音视频·视频生成
【建模先锋】8 小时前
论文复现!基于SAM-BiGRU网络的锂电池RUL预测
深度学习·论文复现·锂电池寿命预测·锂电池数据集·寿命预测
清云逸仙10 小时前
AI Prompt 工程最佳实践:打造结构化的Prompt
人工智能·经验分享·深度学习·ai·ai编程
松岛雾奈.23011 小时前
深度学习--TensorFlow框架使用
深度学习·tensorflow·neo4j
中杯可乐多加冰11 小时前
逻辑控制案例详解|基于smardaten实现OA一体化办公系统逻辑交互
人工智能·深度学习·低代码·oa办公·无代码·一体化平台·逻辑控制