(即插即用模块-特征处理部分) 十九、(NeurIPS 2023) Prompt Block 提示生成 / 交互模块

文章目录

  • 1、Prompt Block
  • 2、代码实现

paper:PromptIR: Prompting for All-in-One Blind Image Restoration

Code:https://github.com/va1shn9v/PromptIR


1、Prompt Block

在解决现有图像恢复模型时,现有研究存在一些局限性: 现有的图像恢复模型通常针对特定的退化类型(如去噪、去雾、去雨)进行训练,这会缺乏泛化能力,难以适应多种退化类型和级别。此外,现有的多退化图像恢复模型通常需要知道输入图像的退化类型,才能选择合适的模型进行恢复,这在实际应用中都是不太现实的。最后,现有的多退化图像恢复模型需要为每种退化类型和级别训练单独的模型,这会导致训练负担过重,且难以在资源受限的平台(如移动设备和边缘设备)上部署。为此,这篇论文提出一种 Prompt Block,其通过引入可学习的提示参数,将退化相关的信息编码到网络中,从而引导网络进行自适应的图像恢复。

Prompt Block 可以分为两个部分:即 Prompt Generation Module(PGM)Prompt Interaction Module(PIM)。具体来说,PGM 的目标是根据输入图像的特征动态生成 prompt 参数,使其能够更好地适应不同的退化类型。而 PIM 通过将 prompt P 与输入特征沿通道维度进行拼接,然后通过 Transformer block 进行处理,实现特征与 prompt 的交互。

对于一个输入特征 X,Prompt Block 的实现过程:

Prompt Generation Module:

  1. 对输入特征进行全局平均池化 (GAP),得到特征向量 v。
  2. 使用 1x1 卷积层对特征向量进行降维,得到紧凑的特征向量。
  3. 对降维后的特征向量进行 softmax 操作,得到 prompt 权重 w。
  4. 使用 prompt 权重 w 对 prompt 组件 Pc 进行加权求和,得到输入条件 prompt P。

Prompt Interaction Module:

  1. 首先将 prompt P 与输入特征 Fl 沿通道维度进行拼接。
  2. 将拼接后的特征通过 Transformer block 进行处理。
  3. 最后将特征经两层卷积处理,输出特征即为经过 Prompt Block 调整后的特征。

Prompt Generation / Interaction Module 结构图:

2、代码实现

python 复制代码
import torch
from torch import nn, einsum
import torch.nn.functional as F


class PromptGenBlock(nn.Module):
    def __init__(self, prompt_dim, prompt_len=5, prompt_size=96, lin_dim=192):
        super(PromptGenBlock, self).__init__()
        self.prompt_param = nn.Parameter(torch.rand(1, prompt_len, prompt_dim, prompt_size, prompt_size))
        self.linear_layer = nn.Linear(lin_dim, prompt_len)
        self.conv3x3 = nn.Conv2d(prompt_dim, prompt_dim, kernel_size=3, stride=1, padding=1, bias=False)

    def forward(self, x):
        B, C, H, W = x.shape
        emb = x.mean(dim=(-2, -1))
        prompt_weights = F.softmax(self.linear_layer(emb), dim=1)
        prompt = prompt_weights.unsqueeze(-1).unsqueeze(-1).unsqueeze(-1) * self.prompt_param.unsqueeze(0).repeat(B, 1,
                                                                                                                  1, 1,
                                                                                                                  1,
                                                                                                                  1).squeeze(
            1)
        prompt = torch.sum(prompt, dim=1)
        prompt = F.interpolate(prompt, (H, W), mode="bilinear")
        prompt = self.conv3x3(prompt)

        return prompt


if __name__ == '__main__':
    x = torch.randn(4, 3, 64, 64).cuda()
    model = PromptGenBlock(3, lin_dim=3).cuda()
    out = model(x)
    print(out.shape)
相关推荐
安思派Anspire43 分钟前
LangGraph + MCP + Ollama:构建强大代理 AI 的关键(一)
前端·深度学习·架构
FF-Studio1 小时前
大语言模型(LLM)课程学习(Curriculum Learning)、数据课程(data curriculum)指南:从原理到实践
人工智能·python·深度学习·神经网络·机器学习·语言模型·自然语言处理
CoovallyAIHub2 小时前
YOLO模型优化全攻略:从“准”到“快”,全靠这些招!
深度学习·算法·计算机视觉
G.E.N.3 小时前
开源!RAG竞技场(2):标准RAG算法
大数据·人工智能·深度学习·神经网络·算法·llm·rag
zm-v-159304339865 小时前
ArcGIS 水文分析升级:基于深度学习的流域洪水演进过程模拟
人工智能·深度学习·arcgis
SHIPKING3937 小时前
【机器学习&深度学习】什么是下游任务模型?
人工智能·深度学习·机器学习
伍哥的传说13 小时前
React 各颜色转换方法、颜色值换算工具HEX、RGB/RGBA、HSL/HSLA、HSV、CMYK
深度学习·神经网络·react.js
超龄超能程序猿14 小时前
(三)PS识别:基于噪声分析PS识别的技术实现
图像处理·人工智能·计算机视觉
要努力啊啊啊14 小时前
YOLOv3-SPP Auto-Anchor 聚类调试指南!
人工智能·深度学习·yolo·目标检测·目标跟踪·数据挖掘
**梯度已爆炸**16 小时前
NLP文本预处理
人工智能·深度学习·nlp