OCR经典神经网络(三)LayoutLM v2算法原理及其在发票数据集上的应用(NER及RE)

OCR经典神经网络(三)LayoutLM v2算法原理及其在发票数据集上的应用(NER及RE)

  • LayoutLM系列模型是微软发布的、文档理解多模态基础模型 领域最重要和有代表性的工作:
    • LayoutLM v2:在一个单一的多模态框架中对文本(text)、布局(layout)和图像(image)之间的交互进行建模。
    • LayoutXLM:LayoutXLM是 LayoutLMv2的多语言扩展版本。
    • LayoutLM v3:借鉴了ViLT和BEIT,不需要经过预训练的视觉backbone,通过MLM、MIM和WPA进行预训练的多模态Transformer。在以视觉为中心的任务上(如文档图像分类和文档布局分析)和以文本为中心的任务上(表单理解、收据理解、文档问答)都表现很好。
  • 今天,我们来了解下LayoutLM v2模型。

1 LayoutLM v2算法原理

  • LayoutLM v2是一种多模态Transformer模型,该模型在预训练阶段整合了文档文本、版式及视觉信息,实现了在一个框架内端到端地学习跨模态交互。同时,将一种空间感知的自注意力机制融入到了Transformer架构中。
  • 除了掩码视觉语言模型(MVLM)预训练策略外,LayoutLM v2还新增了文本-图像对齐(TIA)和文本-图像匹配(TIM) 作为预训练策略,以强化不同模态间的对齐。
  • LayoutLMv2不仅在传统的富视觉文档理解(VrDU)任务上取得了显著的性能提升并达到当时新的最优水平 ,还在文档图像的视觉问题回答(VQA)任务上实现了新突破,这证明了多模态预训练在富视觉文档理解领域的巨大潜力。

1.1 模型结构

  • 模型结构如下图所示,可以看到LayoutLM v2接收文本、视觉及版式信息作为输入,以建立深度的跨模态交互。另外,将spatial-aware的自注意力机制整合到了transformer中。

  • 这里,我们主要看下Embedding层:

    • 文本嵌入

      • 文本嵌入包含三种嵌入:词嵌入代表词本身,一维位置嵌入表示词的位置索引,而段落嵌入用于区分不同的文本段落。

      t i = T o k E m b ( w i ) + P o s E m b 1 D ( i ) + S e g E m b ( s i ) t_i= TokEmb(w_i)+PosEmb1D(i)+SegEmb(s_i) ti=TokEmb(wi)+PosEmb1D(i)+SegEmb(si)

      • 使用WordPiece对OCR文本序列进行分词,并将每个分词(token)分配给特定的段落。接着,在序列的开始添加[CLS]标记,在每个文本段落的末尾添加[SEP]标记。为了使最终序列的长度恰好等于最大序列长度L,在序列末尾额外添加[PAD]填充符。
    • 视觉嵌入

      • 给定一个文档页面图像I,将其调整大小至224×224像素后输入到视觉主干网络中。之后,输出的特征图通过平均池化到固定尺寸,宽度为W,高度为H。接下来,它被展平为长度为W×H(例如:7×7)的视觉嵌入序列,此序列被称为VisTokEmb(I)。然后对每个视觉token嵌入应用线性投影层,以统一其维度与文本嵌入的维度
      • 由于基于CNN的视觉主干无法捕获位置信息,因此添加一维位置嵌入。
      • 对于段落嵌入,将所有视觉令牌附属于视觉段[C]。

      v i = P r o j ( V i s T o k E m b ( I ) i + P o s E m b 1 D ( i ) + S e g E m b ( [ C ] ) v_i= Proj(VisTokEmb(I)_i+PosEmb1D(i)+SegEmb([C]) vi=Proj(VisTokEmb(I)i+PosEmb1D(i)+SegEmb([C])

    • 布局嵌入(2D Position Embeddings)

      • 将所有的坐标标准化并离散化为[0, 1000]范围内的整数,并使用两个嵌入层分别嵌入x轴特征和y轴特征
      • 给定第i个( 0 ≤ i < W × H + L 0 ≤ i < W×H + L 0≤i<W×H+L)文本/视觉token的标准化边界框 b o x i = ( x m i n , x m a x , y m i n , y m a x , w i d t h , h e i g h t ) box_i = (x_{min}, x_{max}, y_{min}, y_{max}, width, height) boxi=(xmin,xmax,ymin,ymax,width,height),布局嵌入层将这六个边界框特征连接起来构建一个token级的2D位置嵌入,即布局嵌入


  • 由于卷积神经网络(CNNs)执行局部变换,因此视觉token嵌入可以一一映射回图像区域,既没有重叠也没有遗漏。
    • 在计算边界框时,视觉token可以被视为均匀划分的网格。
    • 对于特殊token [CLS]、[SEP]和[PAD],会附加一个空边界框boxPAD = (0, 0, 0, 0, 0, 0) 。这意味着这些特殊符号在空间布局上不占用实际区域,但通过这样的空边界框嵌入,模型能够将它们整合到序列中的相应位置上,同时保持空间信息的一致性

1.2 预训练目标及数据

1.2.1 MVLM

  • 采用了掩码视觉-语言建模(Masked Visual-Language Modeling, MVLM)方法,以便模型在跨模态线索的帮助下更好地学习语言方面。
    • 随机掩蔽一些文本token,并要求模型恢复这些被掩蔽的token。
    • 与此同时,布局信息保持不变,这意味着模型了解每个被掩蔽token在页面上的位置。
    • 为了避免视觉线索泄露,在将原始页面图像输入到视觉编码器之前,会先对应掩蔽掉与被掩蔽文本token相对应的图像区域。

1.2.2 TIA

  • Text-Image Alignment(TIA) :随机遮盖图像,然后识别文本对应图像是否被遮盖了。
    • 为了帮助模型学习图像与边界框坐标的空間位置对应关系,提出了细粒度的跨模态对齐任务------文本-图像对齐(Text-Image Alignment, TIA)。
    • 在TIA任务中,随机选择一些文本行,并在其文档图像上的对应图像区域进行遮盖, 称此操作为"遮盖",以避免与MVLM中的"掩码"操作混淆。
    • 预训练期间,在编码器输出之上构建了一个分类层。该层根据文本令牌是否被遮盖(即,[Covered]或[Not Covered])预测每个文本令牌的标签,并计算二元交叉熵损失
    • 考虑到输入图像的分辨率有限,且某些文档元素(如图表中的符号和线条)可能看起来像被遮盖的文本区域,寻找单词大小的遮盖图像区域的任务可能会存在噪声。因此,遮盖操作是在行级别进行的
    • 当MVLM和TIA同时执行时,MVLM中被掩蔽的令牌的TIA损失不予考虑。这防止了模型学习从[MASK]到[Covered]这种无用但直观的对应关系。

1.2.3 TIM

  • Text-Image Matching(TIM):使用[CLS]来判断给出的图片特征与文本特征是否属于同一个页面。
  • 为了帮助模型学习文档图像与文本内容之间的对应关系,采用了较为粗粒度的跨模态对齐任务,即文本-图像匹配(Text-Image Matching, TIM)。
  • 将[CLS]位置的输出表示送入一个分类器,以预测图像和文本是否来自同一文档页面。正常的配对输入被视为正样本
  • 为了构建负样本,图像要么被另一文档的页面图像替换,要么被移除。
  • 为防止模型通过寻找任务特定特征来作弊,对负面样本中的图像也执行相同的掩码和遮盖操作。在负面样本中,TIA的目标标签全部设置为[Covered]

1.2.4 预训练数据

  • 为了预训练和评估LayoutLMv2模型,作者从富含视觉元素的文档理解领域中选择了广泛的数据集。

  • 使用IIT-CDIP作为预训练数据集。

1.3 模型微调

  • 文档级别分类任务 RVL-CDIP中,使用[CLS]输出以及池化的视觉令牌表示作为全局特征
  • 对于提取式问答任务DocVQA及其他四个实体提取任务,在LayoutLMv2输出的文本部分上构建特定任务的头部层。在DocVQA论文中,实验结果显示,在SQuAD数据集上微调过的BERT模型比原始BERT模型表现更优。受此启发,增加了一个额外的设置:首先在问题生成(Question Generation, QG)数据集上微调LayoutLMv2,随后再在DocVQA数据集上微调。这个QG数据集包含近百万对由训练于SQuAD数据集的生成模型产生的问题-答案对。

1.4 LayoutXLM模型结构

  • LayoutXLM是 LayoutLMv2的多语言扩展版本。为了准确评估LayoutXLM,论文中还引入了一个多语言表单理解基准数据集,名为XFUND ,该数据集包含了7种语言(中文、日语、西班牙语、法语、意大利语、德语、葡萄牙语)的表单理解样本,并为每种语言的手工标注了键值对。
  • 论文链接:https://arxiv.org/pdf/2104.08836
  • LayoutXLM预训练策略,同LayoutLMv2
  • 该框架如下图所示:
    • 模型接收来自三种不同模态的信息,即文本、布局和图像,分别使用文本嵌入、布局嵌入和视觉嵌入层进行编码。文本和图像嵌入被连接在一起,然后加上布局嵌入以获得输入嵌入。
    • 输入嵌入通过带有空间感知自注意力机制的多模态Transformer进行编码。
    • 最后,输出的上下文表示可以用于后续的任务特定层。

1.5 VI-LayoutXLM

  • 百度在PP-StructureV2中,针对 LayoutXLM 进行改进,得到了VI-LayoutXLM。

  • 论文链接:https://arxiv.org/pdf/2210.05391

  • 模型部分改进如下:

    • LayoutLMv2 以及 LayoutXLM 中引入视觉骨干网络,用于提取视觉特征,并与后续的 text embedding 进行联合,作为多模态的输入 embedding。但是该模块为基于 ResNet_x101_64x4d 的特征提取网络,特征抽取阶段耗时严重。
    • 因此,移除视觉特征提取模块,同时仍然保留文本、位置以及布局等信息,最终发现针对 LayoutXLM 进行改进,下游 SER 任务精度无损,针对 LayoutLMv2 进行改进,下游 SER 任务精度仅降低2.1%,而模型大小减小了约340M。

2 VI-LayoutXLM在发票数据集上的应用

  • 关键信息抽取 (Key Information Extraction, KIE)指的是是从文本或者图像中,抽取出关键的信息。

    • 针对文档图像的关键信息抽取任务作为OCR的下游任务,存在非常多的实际应用场景,如表单识别、车票信息抽取、身份证信息抽取等。
    • 文档图像中的KIE一般包含2个子任务,示意图如下图所示。
      • SER: 语义实体识别 (Semantic Entity Recognition),对每一个检测到的文本进行分类,如将其分为姓名,身份证。如下图中的黑色框和红色框。
      • RE: 关系抽取 (Relation Extraction),对每一个检测到的文本进行分类,如将其分为问题 (key) 和答案 (value) 。然后对每一个问题找到对应的答案,相当于完成key-value的匹配过程。如下图中的红色框和黑色框分别代表问题和答案,黄色线代表问题和答案之间的对应关系。
  • 除了视觉特征无关的多模态预训练模型结构,paddleocr中在KIE任务上,还有两个主要的优化策略:

    • TB-YX:考虑阅读顺序的文本行排序逻辑
      • 文本阅读顺序对于信息抽取与文本理解等任务至关重要,传统多模态模型中,没有考虑不同 OCR 工具可能产生的不正确阅读顺序,而模型输入中包含位置编码,阅读顺序会直接影响预测结果
      • 在预处理中,对文本行按照从上到下,从左到右(YX)的顺序进行排序,为防止文本行位置轻微干扰带来的排序结果不稳定问题,在排序的过程中,引入位置偏移阈值 Th,对于 Y 方向距离小于 Th 的2个文本内容,使用 X 方向的位置从左到右进行排序。
    • UDML:联合互学习知识蒸馏策略
      • UDML(Unified-Deep Mutual Learning)联合互学习是 PP-OCRv2 与 PP-OCRv3 中采用的对于文本识别非常有效的提升模型效果的策略。
      • 在训练时,引入2个完全相同的模型进行互学习,计算2个模型之间的互蒸馏损失函数(DML loss),同时对 transformer 中间层的输出结果计算距离损失函数(L2 loss)。
      • 使用该策略,最终 XFUND 数据集上,SER 任务 F1 指标提升0.6%,RE 任务 F1 指标提升5.01%。
  • KIE常用思路有如下两种:

    • 一种是SER:

      • 直接使用SER,获取关键信息的类别;常用于关键信息类别固定的场景。
      • 以身份证场景为例, 关键信息一般包含姓名性别民族等,我们直接将对应的字段标注为特定的类别即可,如下图所示:
      • 注意:

        • 标注过程中,对于无关于KIE关键信息的文本内容,均需要将其标注为other类别,相当于背景信息。如在身份证场景中,如果我们不关注性别信息,那么可以将"性别"与"男"这2个字段的类别均标注为other
        • 标注过程中,需要以文本行为单位进行标注,无需标注单个字符的位置信息。

        数据量方面,一般来说,对于比较固定的场景,50张 左右的训练图片即可达到可以接受的效果,可以使用PPOCRLabel完成KIE的标注过程。

    • 一种是SER+RE:

      • 联合使用SER+RE,先利用SER找到key和value,然后再利用RE进行匹配;常用于关系类别不固定的场景。
      • 以身份证场景为例, 关键信息一般包含姓名性别民族等关键信息。在SER阶段,我们需要识别所有的question (key) 与answer (value) 。每个字段的类别信息(label字段)可以是question、answer或者other(与待抽取的关键信息无关的字段)
      • 在RE阶段,需要标注每个字段的的id与连接信息,如下图所示:
        • 标注过程中,如果value是多个字符,那么linking中可以新增一个key-value对,如[[0, 1], [0, 2]]
        • 数据量方面,一般来说,对于比较固定的场景,50张左右的训练图片即可达到可以接受的效果,可以使用PPOCRLabel完成KIE的标注过程。
    • 我们参考案例:https://aistudio.baidu.com/projectdetail/4823162(项目里提供了发票数据集),来对VI-LayoutXLM模型有更深的认识。

2.1 语义实体识别 (SER)

2.1.1 模型构建

  • 我这里不用命令行执行,在paddleocr\tests目录下创建一个py文件执行训练过程

  • 我们复制一份paddleocr\configs\kie\vi_layoutxlm\ser_vi_layoutxlm_xfund_zh_udml.yml文件到paddleocr\tests\configs进行修改(参考上面项目链接进行修改),发票数据集在上面项目中已提供,模型部分的配置文件如下:

    shell 复制代码
    Architecture:
      model_type: &model_type "kie"
      name: DistillationModel
      algorithm: Distillation
      Models:
        Teacher:
          pretrained:
          freeze_params: false
          return_all_feats: true
          model_type: *model_type
          algorithm: &algorithm "LayoutXLM"
          Transform:
          Backbone:
            name: LayoutXLMForSer
            pretrained: True             # 会利用paddle-nlp加载预训练模型
            # one of base or vi
            mode: vi
            checkpoints:
            num_classes: &num_classes 5  # 采用BIO的标注,训练需要修改
        Student:
          pretrained:
          freeze_params: false
          return_all_feats: true
          model_type: *model_type
          algorithm: *algorithm
          Transform:
          Backbone:
            name: LayoutXLMForSer
            pretrained: True
            # one of base or vi
            mode: vi
            checkpoints:
            num_classes: *num_classes
  • 通过下面的py文件,我们就可以愉快的查看源码了。

python 复制代码
def train_kie_token_ser_demo():
    from tools.train import program, set_seed, main
    # 配置文件的源地址地址: paddleocr\configs\kie\vi_layoutxlm\ser_vi_layoutxlm_xfund_zh_udml.yml
    config, device, logger, vdl_writer = program.preprocess(is_train=True)

    ###############修改配置(也可在yml文件中修改)##################
    # 评估频率
    config["Global"]["eval_batch_step"] = [0, 200]
    # log的打印频率
    config["Global"]["print_batch_step"] = 50
    # 训练的epochs
    config["Global"]["epoch_num"] = 1
    # 随机种子
    seed = config["Global"]["seed"] if "seed" in config["Global"] else 1024
    set_seed(seed)

    ###############模型训练##################
    main(config, device, logger, vdl_writer, seed)


def train_kie_token_re_demo():
    from tools.train import program, set_seed, main
    # 配置文件的源地址地址: paddleocr\configs\kie\vi_layoutxlm\re_vi_layoutxlm_xfund_zh_udml.yml
    config, device, logger, vdl_writer = program.preprocess(is_train=True)

    ###############修改配置(也可在yml文件中修改)##################
    # 评估频率
    config["Global"]["eval_batch_step"] = [0, 200]
    # log的打印频率
    config["Global"]["print_batch_step"] = 50
    # 训练的epochs
    config["Global"]["epoch_num"] = 1
    # 随机种子
    seed = config["Global"]["seed"] if "seed" in config["Global"] else 1024
    set_seed(seed)

    ###############模型训练##################
    main(config, device, logger, vdl_writer, seed)

if __name__ == '__main__':
    train_kie_token_ser_demo()
    # train_kie_token_re_demo()

LayoutXLMForTokenClassification

  • 首先,利用LayoutXLMModel提取特征(文本、布局信息)
  • 然后,利用文本部分的特征进行BIO多分类
python 复制代码
# paddleocr.ppocr.modeling.backbones.vqa_layoutlm.py
class LayoutXLMForTokenClassification(LayoutXLMPretrainedModel):
    def __init__(self, config: LayoutXLMConfig):
        super(LayoutXLMForTokenClassification, self).__init__(config)
        self.num_classes = config.num_labels
        self.layoutxlm = LayoutXLMModel(config)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.classifier = nn.Linear(config.hidden_size, self.num_classes)

    ......

    def forward(
        self,
        input_ids=None,
        bbox=None,
        image=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        labels=None,
    ):
        # 1、经过12层的Transformer Block Encoder
        outputs = self.layoutxlm(
            input_ids=input_ids,
            bbox=bbox,
            image=image,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
        )
        seq_length = input_ids.shape[1]
        
        # sequence out and image out
        # 2、进行BIO多分类
        # sequence_output: (bs, 561, 768) -> (bs, 512, 768) -> (bs, 512, 5)
        sequence_output = outputs[0][:, :seq_length]
        sequence_output = self.dropout(sequence_output)
        logits = self.classifier(sequence_output)

        hidden_states = {
            f"hidden_states_{idx}": outputs[2][f"{idx}_data"] for idx in range(self.layoutxlm.config.num_hidden_layers)
        }
        if self.training:
            outputs = (logits, hidden_states)
        else:
            outputs = (logits,)

        ......

        return outputs

LayoutXLMModel

这里我们主要看下LayoutXLMModel模型中,文本的embedding和视觉部分的embedding。

  • 文本的embedding:

    • word_embeddings:对tokenizer后的input_ids进行word_embeddings,shape变化:(bs, 512) -> (bs, 512, 768)

    • position_embeddings(1D position embedding):对文本部分的position_ids进行embeding,shape变化:(bs, 512) -> (bs, 512, 768)。这里,文本和视觉的position_embeddings是共享的。

    • spatial_position_embeddings:这里shape变化为(bs, 512, 4) -> (bs, 512, 768),是将每一个bbox信息的(x_min, y_min, x_max, y_max, h, w)编码,然后concat得到,代码如下所示。注意:如果一个bbox内的文字,被切分为多个token,那么这些token的bbox信息是一致的。

      python 复制代码
          # paddlenlp.transformers.layoutxlm.modeling.py
          def _cal_spatial_position_embeddings(self, bbox):
              try:
                  # (bs, embdedding_dim) -> (bs, embdedding_dim, 128)
                  left_position_embeddings = self.x_position_embeddings(bbox[:, :, 0])
                  # (bs, embdedding_dim) -> (bs, embdedding_dim, 128)
                  upper_position_embeddings = self.y_position_embeddings(bbox[:, :, 1])
                  # (bs, embdedding_dim) -> (bs, embdedding_dim, 128)
                  right_position_embeddings = self.x_position_embeddings(bbox[:, :, 2])
                  # (bs, embdedding_dim) -> (bs, embdedding_dim, 128)
                  lower_position_embeddings = self.y_position_embeddings(bbox[:, :, 3])
              except IndexError as e:
                  raise IndexError("The :obj:`bbox`coordinate values should be within 0-1000 range.") from e
              # (bs, embdedding_dim) -> (bs, embdedding_dim, 128)
              h_position_embeddings = self.h_position_embeddings(bbox[:, :, 3] - bbox[:, :, 1])
              # (bs, embdedding_dim) -> (bs, embdedding_dim, 128)
              w_position_embeddings = self.w_position_embeddings(bbox[:, :, 2] - bbox[:, :, 0])
      
              #  [x_min, y_min, x_max, y_max, h, w] concat -> (bs, embdedding_dim, 128*6)
              spatial_position_embeddings = paddle.concat(
                  [
                      left_position_embeddings,
                      upper_position_embeddings,
                      right_position_embeddings,
                      lower_position_embeddings,
                      h_position_embeddings,
                      w_position_embeddings,
                  ],
                  axis=-1,
              )
              return spatial_position_embeddings
    • token_type_embeddings:这里的token_type_ids全为0,shape变化为(bs, 512) -> (bs, 512, 768)

  • 视觉部分的embedding:

    • position_embeddings(1D position embedding):shape变化为(bs, 49) -> (bs, 49, 768)。视觉部分的position ids为:[0, 1, 2, ..., 48] -> (bs, 49)。这里虽然去除了视觉提取,但是position ids按照图像224×224经过降采样32倍后的feature map:7×7进行生成。这里,文本和视觉的position_embeddings是共享的;
    • spatial_position_embeddings:视觉部分布局信息,即bbox的生成的核心逻辑是:7×7网格中,每一个小的正方形的坐标(x_min, y_min, x_max, y_max)即为一个视觉token。shape变化为(bs, 49, 4) -> (bs, 49, 768)
    • visual_segment_embedding
  • 最终,将文本的embedding和视觉部分的embedding送入到12层的Transformer Encoder Block提取特征。

python 复制代码
# paddlenlp.transformers.layoutxlm.modeling.py
@register_base_model
class LayoutXLMModel(LayoutXLMPretrainedModel):

    def __init__(self, config: LayoutXLMConfig):
        super(LayoutXLMModel, self).__init__(config)
        self.config = config
        self.use_visual_backbone = config.use_visual_backbone
        self.has_visual_segment_embedding = config.has_visual_segment_embedding
        self.embeddings = LayoutXLMEmbeddings(config)

        if self.use_visual_backbone is True:
            self.visual = VisualBackbone(config)
            self.visual.stop_gradient = True
            self.visual_proj = nn.Linear(config.image_feature_pool_shape[-1], config.hidden_size)

        if self.has_visual_segment_embedding:
            self.visual_segment_embedding = self.create_parameter(
                shape=[
                    config.hidden_size,
                ],
                dtype=paddle.float32,
            )
        self.visual_LayerNorm = nn.LayerNorm(config.hidden_size, epsilon=config.layer_norm_eps)
        self.visual_dropout = nn.Dropout(config.hidden_dropout_prob)

        self.encoder = LayoutXLMEncoder(config)
        self.pooler = LayoutXLMPooler(config)



    def _calc_visual_bbox(self, image_feature_pool_shape, bbox, visual_shape):
        """
           视觉部分布局信息,即bbox的生成:
                 - image_feature_pool_shape:(7, 7, 256)
                 - 文字token的bbox信息:(bs, 512, 4)
                 - visual_shape:[bs, 49]
        """
        # 首先,生成一个序列[0, 1000, 2000, 3000, 4000, 5000, 6000, 7000]
        # 然后,离散化为[0, 1000],即[0, 142, 285, 428, 571, 714, 857, 1000]
        visual_bbox_x = (
            paddle.arange(
                0,
                1000 * (image_feature_pool_shape[1] + 1),
                1000,
                dtype=bbox.dtype,
            )
            // image_feature_pool_shape[1]
        )
        visual_bbox_y = (
            paddle.arange(
                0,
                1000 * (image_feature_pool_shape[0] + 1),
                1000,
                dtype=bbox.dtype,
            )
            // image_feature_pool_shape[0]
        )

        expand_shape = image_feature_pool_shape[0:2] # (7, 7)
        # 7×7网格中,每一个小的正方形的坐标(x_min, y_min, x_max, y_max)即为一个视觉token
        # visual_bbox shape = (7×7, 4)
        visual_bbox = paddle.stack(
            [
                visual_bbox_x[:-1].expand(expand_shape),
                visual_bbox_y[:-1].expand(expand_shape[::-1]).transpose([1, 0]),
                visual_bbox_x[1:].expand(expand_shape),
                visual_bbox_y[1:].expand(expand_shape[::-1]).transpose([1, 0]),
            ],
            axis=-1,
        ).reshape([expand_shape[0] * expand_shape[1], paddle.shape(bbox)[-1]])
        # 扩展到bs个样本, (7×7, 4) -> (bs, 7×7, 4)
        visual_bbox = visual_bbox.expand([visual_shape[0], visual_bbox.shape[0], visual_bbox.shape[1]])
        return visual_bbox

    def _calc_text_embeddings(self, input_ids, bbox, position_ids, token_type_ids):
        """
          文本部分进行embeddings:
                  word_embeddings
                + position_embeddings(文本和视觉的position_embeddings是共享的)
                + spatial_position_embeddings
                + token_type_embeddings
        """
        # (bs, 512) -> (bs, 512, 768)
        words_embeddings = self.embeddings.word_embeddings(input_ids)
        # (bs, 512) -> (bs, 512, 768)
        position_embeddings = self.embeddings.position_embeddings(position_ids)
        # (bs, 512, 4) -> (bs, 512, 768)
        spatial_position_embeddings = self.embeddings._cal_spatial_position_embeddings(bbox)
        # (bs, 512) -> (bs, 512, 768)
        token_type_embeddings = self.embeddings.token_type_embeddings(token_type_ids)
        # 4种embedding相加
        embeddings = words_embeddings + position_embeddings + spatial_position_embeddings + token_type_embeddings
        # LayerNorm + dropout
        embeddings = self.embeddings.LayerNorm(embeddings)
        embeddings = self.embeddings.dropout(embeddings)
        return embeddings


    def _calc_img_embeddings(self, image, bbox, position_ids):
        """
            视觉部分进行embedding:
                    position_embeddings(文本和视觉的position_embeddings是共享的)
                +   spatial_position_embeddings
                +   visual_segment_embedding
        """
        use_image_info = self.use_visual_backbone and image is not None
        # (bs, 49) -> (bs, 49, 768)
        position_embeddings = self.embeddings.position_embeddings(position_ids)
        # (bs, 49, 4) -> (bs, 49, 768)
        spatial_position_embeddings = self.embeddings._cal_spatial_position_embeddings(bbox)
        if use_image_info is True:
            visual_embeddings = self.visual_proj(self.visual(image.astype(paddle.float32)))
            embeddings = visual_embeddings + position_embeddings + spatial_position_embeddings
        else:
            # embedding相加
            embeddings = position_embeddings + spatial_position_embeddings

        if self.has_visual_segment_embedding:
            # self.visual_segment_embedding shape = (768)
            embeddings += self.visual_segment_embedding

        #  visual_LayerNorm + visual_dropout
        embeddings = self.visual_LayerNorm(embeddings)
        embeddings = self.visual_dropout(embeddings)
        return embeddings

    
    def forward(
        self,
        input_ids=None,
        bbox=None,
        image=None,
        token_type_ids=None,
        position_ids=None,
        attention_mask=None,
        head_mask=None,
        output_hidden_states=False,
        output_attentions=False,
    ):
        input_shape = paddle.shape(input_ids)
        visual_shape = list(input_shape)
        visual_shape[1] = self.config.image_feature_pool_shape[0] * self.config.image_feature_pool_shape[1]

        # 视觉部分的bbox的生成
        # 视觉token被视为均匀划分的网格
        # 生成的bbox信息:feature_map(7×7)网格中,每一个小的正方形的坐标(x_min, y_min, x_max, y_max)即为一个视觉token
        visual_bbox = self._calc_visual_bbox(self.config.image_feature_pool_shape, bbox, visual_shape)

        # 1、2D position embedding(文本部分bbox+视觉部分bbox)
        # (bs, 512, 4) + (bs, 49, 4) -> (bs, 561, 4)
        final_bbox = paddle.concat([bbox, visual_bbox], axis=1)
        if attention_mask is None:
            attention_mask = paddle.ones(input_shape)

        if self.use_visual_backbone is True:
            # 使用视觉部分的backbone
            visual_attention_mask = paddle.ones(visual_shape)
        else:
            # 移除视觉特征提取模块,mask全设置为0
            visual_attention_mask = paddle.zeros(visual_shape)

        attention_mask = attention_mask.astype(visual_attention_mask.dtype)
        # concat后attention_mask:(bs, 512) + (bs, 49) -> (bs, 561)
        final_attention_mask = paddle.concat([attention_mask, visual_attention_mask], axis=1)

        if token_type_ids is None:
            token_type_ids = paddle.zeros(input_shape, dtype=paddle.int64)


        # 2、1D position embedding(文本部分+视觉部分) (bs, 512) + (bs, 49) -> (bs, 561)
        if position_ids is None:
            # 文本部分的position embedding
            seq_length = input_shape[1]
            position_ids = self.embeddings.position_ids[:, :seq_length]
            position_ids = position_ids.expand(input_shape)

        # 视觉部分的position embedding
        # [0, 1, 2, ..., 48] -> (bs, 49)
        visual_position_ids = paddle.arange(0, visual_shape[1]).expand([input_shape[0], visual_shape[1]])
        final_position_ids = paddle.concat([position_ids, visual_position_ids], axis=1)

        if bbox is None:
            bbox = paddle.zeros(input_shape + [4])

        # 3、 text embedding & visual  (bs, 512, 768) + (bs, 49, 768) -> (bs, 561, 768)
        # 文本部分进行embdedding (bs, 512) -> (bs, 512, 768)
        text_layout_emb = self._calc_text_embeddings(
            input_ids=input_ids,
            bbox=bbox,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
        )
        # 视觉部分进行embedding(注意此时没有image,仅有视觉的bbox以及position_ids)
        visual_emb = self._calc_img_embeddings(
            image=image,
            bbox=visual_bbox,
            position_ids=visual_position_ids,
        )
        final_emb = paddle.concat([text_layout_emb, visual_emb], axis=1)
        # (bs, 561) -> (bs, 1, 1, 561)
        extended_attention_mask = final_attention_mask.unsqueeze(1).unsqueeze(2)

        extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0

        if head_mask is not None:
            if head_mask.dim() == 1:
                head_mask = head_mask.unsqueeze(0).unsqueeze(0).unsqueeze(-1).unsqueeze(-1)
                head_mask = head_mask.expand(self.config.num_hidden_layers, -1, -1, -1, -1)
            elif head_mask.dim() == 2:
                head_mask = head_mask.unsqueeze(1).unsqueeze(-1).unsqueeze(-1)
        else:
            head_mask = [None] * self.config.num_hidden_layers

        # 经过Transformer Encoder Block(12层)
        encoder_outputs = self.encoder(
            final_emb,                        # 文本&视觉部分的embedding , shape=(bs, 561, 768)
            extended_attention_mask,          # attention_mask        , shape=(bs, 1, 1, 561)
            bbox=final_bbox,                  # 2D position embedding【如果需要相对位置位置编码,加在attention_score上,这里为False】, shape=(bs, 561, 4)
            position_ids=final_position_ids,  # 1D position embedding【如果需要相对位置位置编码,加在attention_score上,这里为False】, shape=(bs, 561)
            head_mask=head_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
        )
        # sequence_output shape = (bs, 561, 768)
        sequence_output = encoder_outputs[0]
        # pooled_output shape = (bs, 768)
        pooled_output = self.pooler(sequence_output)
        return sequence_output, pooled_output, encoder_outputs[1]

2.1.2 损失计算

  • 由于使用了UDML:联合互学习知识蒸馏策略,损失计算的配置如下:
python 复制代码
Loss:
  name: CombinedLoss                     # ppocr.losses.combined_loss.CombinedLoss
  loss_config_list:
  - DistillationVQASerTokenLayoutLMLoss: # GT loss   ppocr.losses.distillation_loss.DistillationVQASerTokenLayoutLMLoss
      weight: 1.0
      model_name_list: ["Student", "Teacher"]
      key: backbone_out
      num_classes: *num_classes
  - DistillationSERDMLLoss:              # DML loss  ppocr.losses.distillation_loss.DistillationSERDMLLoss
      weight: 1.0
      act: "softmax"
      use_log: true
      model_name_pairs:
      - ["Student", "Teacher"]
      key: backbone_out
  - DistillationVQADistanceLoss:         # S5 loss  ppocr.losses.distillation_loss.DistillationVQADistanceLoss
      weight: 0.5
      mode: "l2"
      model_name_pairs:
        - ["Student", "Teacher"]
      key: hidden_states_5
      name: "loss_5"
  - DistillationVQADistanceLoss:         # S8 loss  ppocr.losses.distillation_loss.DistillationVQADistanceLoss
      weight: 0.5
      mode: "l2"
      model_name_pairs:
        - ["Student", "Teacher"]
      key: hidden_states_8
      name: "loss_8"
  • 如下所示,在DistillationModel中,Teacher和Student模型分别进行前向过程
python 复制代码
# paddleocr.ppocr.modeling.architectures.distillation_model.py
class DistillationModel(nn.Layer):
    def __init__(self, config):
        """
        the module for OCR distillation.
        args:
            config (dict): the super parameters for module.
        """
        super().__init__()
        self.model_list = []
        self.model_name_list = []
        for key in config["Models"]:
            model_config = config["Models"][key]
            freeze_params = False
            pretrained = None
            if "freeze_params" in model_config:
                freeze_params = model_config.pop("freeze_params")
            if "pretrained" in model_config:
                pretrained = model_config.pop("pretrained")
            model = BaseModel(model_config)
            if pretrained is not None:
                load_pretrained_params(model, pretrained)
            if freeze_params:
                for param in model.parameters():
                    param.trainable = False
            self.model_list.append(self.add_sublayer(key, model))
            self.model_name_list.append(key)

    def forward(self, x, data=None):
        result_dict = dict()
        # 执行所有模型的前向过程, 例如:Teacher和Student模型
        for idx, model_name in enumerate(self.model_name_list):
            result_dict[model_name] = self.model_list[idx](x, data)
        return result_dict
  • 在CombinedLoss中遍历配置的损失函数,分别计算损失,最后相加最为总损失
python 复制代码
# paddleocr.ppocr.losses.combined_loss.py
class CombinedLoss(nn.Layer):
    """
    CombinedLoss:
        a combionation of loss function
    """

    def __init__(self, loss_config_list=None):
        super().__init__()
        self.loss_func = []
        self.loss_weight = []
        assert isinstance(loss_config_list, list), "operator config should be a list"
        ......

    def forward(self, input, batch, **kargs):
        # input包含Teacher模型以及Student模型的输出结果
        # batch是批次数据,里面包含label
        loss_dict = {}
        loss_all = 0.0
        # 遍历配置的所有的损失函数,计算损失
        for idx, loss_func in enumerate(self.loss_func):
            loss = loss_func(input, batch, **kargs)
            if isinstance(loss, paddle.Tensor):
                loss = {"loss_{}_{}".format(str(loss), idx): loss}

            weight = self.loss_weight[idx]

            loss = {key: loss[key] * weight for key in loss}

            if "loss" in loss:
                loss_all += loss["loss"]
            else:
                loss_all += paddle.add_n(list(loss.values()))
            loss_dict.update(loss)
        loss_dict["loss"] = loss_all
        return loss_dict
  • 我们看下具体配置的损失函数:

    • DistillationVQASerTokenLayoutLMLoss的实质就是每个模型分别计算NER任务的CrossEntropyLoss,即GT loss:

      python 复制代码
      class DistillationVQASerTokenLayoutLMLoss(VQASerTokenLayoutLMLoss):
          def __init__(self, num_classes, model_name_list=[], key=None, name="loss_ser"):
              super().__init__(num_classes=num_classes)
              self.model_name_list = model_name_list
              self.key = key
              self.name = name
      
          def forward(self, predicts, batch):
              loss_dict = dict()
              # 遍历Teacher模型、Student模型
              for idx, model_name in enumerate(self.model_name_list):
                  # 先从predicts取出相关模型的预测结果
                  out = predicts[model_name]
                  # 然后,从out中取出key(即配置文件中配置的backbone_out)的值
                  if self.key is not None:
                      out = out[self.key]
                  # 调用父类,计算损失
                  loss = super().forward(out, batch)
                  loss_dict["{}_{}".format(self.name, model_name)] = loss["loss"]
              return loss_dict
      
      # DistillationVQASerTokenLayoutLMLoss的父类
      class VQASerTokenLayoutLMLoss(nn.Layer):
          def __init__(self, num_classes, key=None):
              super().__init__()
              self.loss_class = nn.CrossEntropyLoss()
              self.num_classes = num_classes
              self.ignore_index = self.loss_class.ignore_index
              self.key = key
      
          def forward(self, predicts, batch):
              if isinstance(predicts, dict) and self.key is not None:
                  predicts = predicts[self.key]
              labels = batch[5]           # (bs, 512)
              attention_mask = batch[2]   # (bs, 512)
              if attention_mask is not None:
                  active_loss = (
                      attention_mask.reshape(
                          [
                              -1,
                          ]
                      )
                      == 1
                  )
                  # active_output_shape = (bs, 512, 5) -> (bs*512, 5)
                  active_output = predicts.reshape([-1, self.num_classes])[active_loss]
                  # active_label_shape = bs*512
                  active_label = labels.reshape(
                      [
                          -1,
                      ]
                  )[active_loss]
                  # 交叉熵损失函数
                  loss = self.loss_class(active_output, active_label)
              else:
                  loss = self.loss_class(
                      predicts.reshape([-1, self.num_classes]),
                      labels.reshape(
                          [
                              -1,
                          ]
                      ),
                  )
              return {"loss": loss}
    • DistillationSERDMLLoss实质是计算Techaer和Student模型之间的互蒸馏损失函数,即KL散度。

      python 复制代码
      class DistillationSERDMLLoss(DMLLoss):
          """ """
      
          def __init__(
              self,
              act="softmax",
              use_log=True,
              num_classes=7,
              model_name_pairs=[],
              key=None,
              name="loss_dml_ser",
          ):
              super().__init__(act=act, use_log=use_log)
              assert isinstance(model_name_pairs, list)
              self.key = key
              self.name = name
              self.num_classes = num_classes
              self.model_name_pairs = model_name_pairs
      
          def forward(self, predicts, batch):
              loss_dict = dict()
              # 遍历Teacher模型、Student模型
              for idx, pair in enumerate(self.model_name_pairs):
                  # 取出Teacher模型以及Student模型中的结果
                  out1 = predicts[pair[0]]
                  out2 = predicts[pair[1]]
                  if self.key is not None:
                      # 取出backbone_out
                      out1 = out1[self.key]
                      out2 = out2[self.key]
                  out1 = out1.reshape([-1, out1.shape[-1]])
                  out2 = out2.reshape([-1, out2.shape[-1]])
      
                  attention_mask = batch[2]
                  if attention_mask is not None:
                      active_output = (
                          attention_mask.reshape(
                              [
                                  -1,
                              ]
                          )
                          == 1
                      )
                      out1 = out1[active_output]
                      out2 = out2[active_output]
                  # 调用父类的方法
                  loss_dict["{}_{}".format(self.name, idx)] = super().forward(out1, out2)
      
              return loss_dict
      
      # DistillationSERDMLLoss的父类   
      class DMLLoss(nn.Layer):
          """
          DMLLoss
          """
      
          def __init__(self, act=None, use_log=False):
              super().__init__()
              if act is not None:
                  assert act in ["softmax", "sigmoid"]
              if act == "softmax":
                  self.act = nn.Softmax(axis=-1)
              elif act == "sigmoid":
                  self.act = nn.Sigmoid()
              else:
                  self.act = None
      
              self.use_log = use_log
              self.jskl_loss = KLJSLoss(mode="kl")
      
          def _kldiv(self, x, target):
              """
                  计算两个概率分布之间的KL散度:
                      KL散度的公式是 KL(P||Q) = ΣP(x) * log(P(x)/Q(x)),这里将其重写为ΣP(x)*(log(P(x)) - log(Q(x)))
                      即target * (paddle.log(target + eps) - x)
              """
              eps = 1.0e-10
              loss = target * (paddle.log(target + eps) - x)
              # batch mean loss
              loss = paddle.sum(loss) / loss.shape[0]
              return loss
      
          def forward(self, out1, out2):
              if self.act is not None:
                  out1 = self.act(out1) + 1e-10
                  out2 = self.act(out2) + 1e-10
              if self.use_log:
                  # 计算KL散度
                  # for recognition distillation, log is needed for feature map
                  log_out1 = paddle.log(out1)
                  log_out2 = paddle.log(out2)
                  loss = (self._kldiv(log_out1, out2) + self._kldiv(log_out2, out1)) / 2.0
              else:
                  # for detection distillation log is not needed
                  loss = self.jskl_loss(out1, out2)
              return loss    
    • DistillationVQADistanceLoss,本质是对 transformer 中间层的输出结果计算距离损失函数(L2 loss)

      python 复制代码
      # DistillationVQADistanceLoss的父类
      class DistanceLoss(nn.Layer):
          """
          DistanceLoss:
              mode: loss mode
          """
      
          def __init__(self, mode="l2", **kargs):
              super().__init__()
              assert mode in ["l1", "l2", "smooth_l1"]
              if mode == "l1":
                  self.loss_func = nn.L1Loss(**kargs)
              elif mode == "l2":
                  self.loss_func = nn.MSELoss(**kargs)
              elif mode == "smooth_l1":
                  self.loss_func = nn.SmoothL1Loss(**kargs)
      
          def forward(self, x, y):
              return self.loss_func(x, y)

      其他部分,诸如数据集的加载、构建优化器、创建评估函数、加载预训练模型、模型训练等,大家可以查看源码,不再赘述。

2.2 关系抽取(RE)

  • 我们这里,看下模型的构建部分代码,其他代码,大家可以查看源码,不再赘述。
python 复制代码
# paddlenlp.transformers.layoutxlm.modeling.py
class LayoutXLMForRelationExtraction(LayoutXLMPretrainedModel):
    def __init__(self, config: LayoutXLMConfig):
        super(LayoutXLMForRelationExtraction, self).__init__(config)

        self.layoutxlm = LayoutXLMModel(config)

        self.extractor = REDecoder(config.hidden_size, config.hidden_dropout_prob)

        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    
    ......
    def forward(
        self,
        input_ids,
        bbox,
        image=None,
        attention_mask=None,
        entities=None,
        relations=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        labels=None,
    ):
        # 1、经过12层的Transformer Block Encoder
        outputs = self.layoutxlm(
            input_ids=input_ids,            # (bs, 512)
            bbox=bbox,                      # (bs, 512, 4)
            image=image,                    # None
            attention_mask=attention_mask,  # (bs, 512)
            token_type_ids=token_type_ids,  # (bs. 512)
            position_ids=position_ids,      # None
            head_mask=head_mask,            # None
        )
        seq_length = input_ids.shape[1]
        # 最后一层输出
        # sequence_output_shape = (bs, 512, 768)
        sequence_output = outputs[0][:, :seq_length]
        sequence_output = self.dropout(sequence_output)

        # 2、计算loss和预测关系
        loss, pred_relations = self.extractor(sequence_output, entities, relations)

        hidden_states = [outputs[2][f"{idx}_data"] for idx in range(self.layoutxlm.config.num_hidden_layers)]
        hidden_states = paddle.stack(hidden_states, axis=1)
        # 3、返回结果
        res = dict(loss=loss, pred_relations=pred_relations, hidden_states=hidden_states)
        return res
  • 主要代码在REDecoder中

    • 首先,构建构建关系对的正负样本
    • 然后,获取关系头(question)、关系尾(answer)对应的特征信息
      • 获取关系头(即question)在input_ids中开始的索引对应token的hidden_states(shape=(100, 768))和关系头(question)经过Embedding后的结果(shape=(100, 768))进行concat
      • 获取关系尾(即answer)在input_ids中开始的索引对应token的hidden_states(shape=(100, 768))和关系尾(answer)经过Embedding后的结果(shape=(100, 768))进行concat
    • 利用提取到的head_repr、tail_repr特征信息进行关系分类
    • 最后,利用预测结果,计算交叉熵损失等
    • 下面,给出一个relations和entities示例,方便理解。
python 复制代码
class REDecoder(nn.Layer):
    def __init__(self, hidden_size=768, hidden_dropout_prob=0.1):
        super(REDecoder, self).__init__()
        self.entity_emb = nn.Embedding(3, hidden_size)
        # 100代表:100个关系对
        # (100, 1536) -> (100, 768) -> (100, 384)
        projection = nn.Sequential(
            nn.Linear(hidden_size * 2, hidden_size),
            nn.ReLU(),
            nn.Dropout(hidden_dropout_prob),
            nn.Linear(hidden_size, hidden_size // 2),
            nn.ReLU(),
            nn.Dropout(hidden_dropout_prob),
        )
        self.ffnn_head = copy.deepcopy(projection)
        self.ffnn_tail = copy.deepcopy(projection)
        # (100, 384) -> (100, 2)
        self.rel_classifier = BiaffineAttention(hidden_size // 2, 2)
        self.loss_fct = CrossEntropyLoss()

    def build_relation(self, relations, entities):
        """
            relations_shape = (bs, 262145, 2)
            entities_shape  = (bs, 513, 3)
            注:
                relations第1个数组代表实际长度,例如:[10, 10],代表:关系对(QUESTION-ANSWER)只有10个,其他为填充
                entities第1个数组代表实际长度,例如:[20, 20, 20],代表:实例(QUESTION或ANSWER)只有20个,其他为填充
        """
        batch_size, max_seq_len = paddle.shape(entities)[:2]
        # new_relations_shape = (bs, 513*513, 3), 初始化为-1
        new_relations = paddle.full(
            shape=[batch_size, max_seq_len * max_seq_len, 3], fill_value=-1, dtype=relations.dtype
        )
        for b in range(batch_size):
            if entities[b, 0, 0] <= 2:
                entitie_new = paddle.full(shape=[512, 3], fill_value=-1, dtype=entities.dtype)
                entitie_new[0, :] = 2
                entitie_new[1:3, 0] = 0  # start
                entitie_new[1:3, 1] = 1  # end
                entitie_new[1:3, 2] = 0  # label
                entities[b] = entitie_new
            # 实体label_shape为: [2, 1, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 1, 1, 2, 2, 2, 2, 1]
            # all_possible_relations1为: [1 , 2 , 4 , 6 , 8 , 10, 12, 13, 14, 19]  QUESTION
            # all_possible_relations2为: [0 , 3 , 5 , 7 , 9 , 11, 15, 16, 17, 18]  ANSWER
            entitie_label = entities[b, 1 : entities[b, 0, 2] + 1, 2]
            all_possible_relations1 = paddle.arange(0, entities[b, 0, 2], dtype=entities.dtype)
            all_possible_relations1 = all_possible_relations1[entitie_label == 1]
            all_possible_relations2 = paddle.arange(0, entities[b, 0, 2], dtype=entities.dtype)
            all_possible_relations2 = all_possible_relations2[entitie_label == 2]

            # 所有可能的关系:all_possible_relations_shape:(100, 2)
            # [
            #   [1, 0],  [1, 3], ... , [1, 18],
            #   [2, 0],  [2, 3], ... , [2, 18],
            #           ......
            #   [19, 0], [19, 3], ... , [19, 18]
            # ]
            all_possible_relations = paddle.stack(
                paddle.meshgrid(all_possible_relations1, all_possible_relations2), axis=2
            ).reshape([-1, 2])
            if len(all_possible_relations) == 0:
                all_possible_relations = paddle.full_like(all_possible_relations, fill_value=-1, dtype=entities.dtype)
                all_possible_relations[0, 0] = 0
                all_possible_relations[0, 1] = 1
            # relation_head: [1 , 2 , 4 , 6 , 8 , 10, 12, 13, 14, 19]
            # relation_tail: [0 , 3 , 5 , 7 , 9 , 11, 17, 15, 16, 18]
            relation_head = relations[b, 1 : relations[b, 0, 0] + 1, 0]
            relation_tail = relations[b, 1 : relations[b, 0, 1] + 1, 1]
            # positive_relations_shape: (10, 2)
            positive_relations = paddle.stack([relation_head, relation_tail], axis=1)
            # (100, 2) -> (100, 10, 2)
            all_possible_relations_repeat = all_possible_relations.unsqueeze(axis=1).tile(
                [1, len(positive_relations), 1]
            )
            # (100, 2) -> (100, 10, 2)
            positive_relations_repeat = positive_relations.unsqueeze(axis=0).tile([len(all_possible_relations), 1, 1])
            # mask shape = (100, 10)
            mask = paddle.all(all_possible_relations_repeat == positive_relations_repeat, axis=2)
            # 获取关系对负样本
            # negative_mask = paddle.any(mask, axis=1) is False
            negative_mask = ~paddle.any(mask, axis=1)
            negative_relations = all_possible_relations[negative_mask]

            # 获取关系对正样本
            # positive_mask = paddle.any(mask, axis=0) is True
            positive_mask = paddle.any(mask, axis=0)
            positive_relations = positive_relations[positive_mask]
            if negative_mask.sum() > 0:
                # positive_relations_shape = (10, 2)
                # negative_relations_shape = (90, 2)
                # reordered_relations_shape = (100, 2)
                reordered_relations = paddle.concat([positive_relations, negative_relations])
            else:
                reordered_relations = positive_relations

            relation_per_doc_label = paddle.zeros([len(reordered_relations), 1], dtype=reordered_relations.dtype)
            relation_per_doc_label[: len(positive_relations)] = 1
            # relation_per_doc shape: (100, 3)
            """
            relation_per_doc = 
            [[1 , 0 , 1 ],# 正样本
             [2 , 3 , 1 ],
             [4 , 5 , 1 ],
             ......
             [19, 18, 1 ],
             [1 , 3 , 0 ],# 负样本
             [1 , 5 , 0 ],
             ......
                        ]
            """
            relation_per_doc = paddle.concat([reordered_relations, relation_per_doc_label], axis=1)
            assert len(relation_per_doc[:, 0]) != 0
            # 第1个元素记录正负样本的长度信息,例如:[100, 100, 100]
            new_relations[b, 0] = paddle.shape(relation_per_doc)[0].astype(new_relations.dtype)
            # 将正负样本放到new_relations中
            new_relations[b, 1 : len(relation_per_doc) + 1] = relation_per_doc
            # new_relations.append(relation_per_doc)
        return new_relations, entities

    def get_predicted_relations(self, logits, relations, entities):
        """
            logits: 预测得到的关系概率, 例如:shape = (100, 2)
            relations: shape = (100, 3)
            entities:  shape = (513, 3)
        """
        pred_relations = []
        for i, pred_label in enumerate(logits.argmax(-1)):
            if pred_label != 1:
                continue
            rel = paddle.full(shape=[7, 2], fill_value=-1, dtype=relations.dtype)
            rel[0, 0] = relations[:, 0][i]
            rel[1, 0] = entities[:, 0][relations[:, 0][i] + 1]
            rel[1, 1] = entities[:, 1][relations[:, 0][i] + 1]
            rel[2, 0] = entities[:, 2][relations[:, 0][i] + 1]
            rel[3, 0] = relations[:, 1][i]
            rel[4, 0] = entities[:, 0][relations[:, 1][i] + 1]
            rel[4, 1] = entities[:, 1][relations[:, 1][i] + 1]
            rel[5, 0] = entities[:, 2][relations[:, 1][i] + 1]
            rel[6, 0] = 1
            pred_relations.append(rel)
        return pred_relations

    def forward(self, hidden_states, entities, relations):
        """
            hidden_states_shape:(bs, 512, 768)
            entities_shape: (bs, 513, 3)    , 其中:513 = 512 + 1,第一个元素记录长度信息
            relations_shape: (bs, 262145, 2),其中:262145 = 512*512 + 1,第一个元素记录长度信息
        """
        batch_size, max_length, _ = paddle.shape(entities)
        # 1、构建关系的正负样本
        # relations_shape: (bs, 263169, 3) , 其中: 263169 = 513 * 513
        # entities_shape: (bs, 513, 3)
        relations, entities = self.build_relation(relations, entities)
        loss = 0
        # 所有预测关系结果
        all_pred_relations = paddle.full(
            shape=[batch_size, max_length * max_length, 7, 2], fill_value=-1, dtype=entities.dtype
        )
        for b in range(batch_size):
            # 2、获取关系头(question)、关系尾(answer)对应的特征信息
            # 取出正负样本关系对, relation_shape = (100, 3)
            relation = relations[b, 1 : relations[b, 0, 0] + 1]
            # 获取关系头(question)、关系尾(answer)、以及关系标签(1表示question和answer是一对,即正样本, 0表示负样本)
            head_entities = relation[:, 0]
            tail_entities = relation[:, 1]
            relation_labels = relation[:, 2]
            # 每一个实体(question或answer)在input_ids中开始的索引
            # 例如:  [0  , 3  , 4  , 8  , 14 , 16 , 23 , 29 , 34 , 37 , 60 , 65 , 82 , 84 ,
            #         87 , 90 , 91 , 96 , 102, 106]
            entities_start_index = paddle.to_tensor(entities[b, 1 : entities[b, 0, 0] + 1, 0])
            # 获取每个实体类型编号,1表示question,2表示answer
            # 例如:[2, 1, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 1, 1, 2, 2, 2, 2, 1]
            entities_labels = paddle.to_tensor(entities[b, 1 : entities[b, 0, 2] + 1, 2])
            # 获取关系头(即question)在input_ids中开始的索引,为了后面获取对应token的hidden_states
            head_index = entities_start_index[head_entities]
            # 获取关系头(question)对应的实体类型编号
            head_label = entities_labels[head_entities]
            # 关系头(question)经过Embedding, head_label_repr_shape = (100, 768)
            head_label_repr = self.entity_emb(head_label)

            # 获取关系尾(即answer)在input_ids中开始的索引,为了后面获取对应token的hidden_states
            tail_index = entities_start_index[tail_entities]
            # 获取关系尾(answer)对应的实体类型编号
            tail_label = entities_labels[tail_entities]
            # 关系尾(answer)经过Embedding, tail_label_repr_shape = (100, 768)
            tail_label_repr = self.entity_emb(tail_label)

            # 获取关系头(question)开始token的hidden_states, tmp_hidden_states shape: (100, 768)
            tmp_hidden_states = hidden_states[b][head_index]
            if len(tmp_hidden_states.shape) == 1:
                tmp_hidden_states = paddle.unsqueeze(tmp_hidden_states, axis=0)
            #  concat, head_repr_shape = (100, 1536)
            head_repr = paddle.concat((tmp_hidden_states, head_label_repr), axis=-1)

            # 获取关系尾(answer)开始token的hidden_states, tmp_hidden_states shape: (100, 768)
            tmp_hidden_states = hidden_states[b][tail_index]
            if len(tmp_hidden_states.shape) == 1:
                tmp_hidden_states = paddle.unsqueeze(tmp_hidden_states, axis=0)
            #  concat, tail_repr_shape = (100, 1536)
            tail_repr = paddle.concat((tmp_hidden_states, tail_label_repr), axis=-1)

            # 3、利用提取到的head_repr、tail_repr进行关系分类
            # heads_shape = (100, 1536) -> (100, 384)
            # tails_shape = (100, 1536) -> (100, 384)
            heads = self.ffnn_head(head_repr)
            tails = self.ffnn_tail(tail_repr)
            # 结合双线性层和线性层,实现对两个输入向量的复杂交互建模
            # logits_shape = (100, 2)
            logits = self.rel_classifier(heads, tails)

            # 4、计算交叉熵损失
            loss += self.loss_fct(logits, relation_labels)
            pred_relations = self.get_predicted_relations(logits, relation, entities[b])
            if len(pred_relations) > 0:
                pred_relations = paddle.stack(pred_relations)
                all_pred_relations[b, 0, :, :] = paddle.shape(pred_relations)[0].astype(all_pred_relations.dtype)
                all_pred_relations[b, 1 : len(pred_relations) + 1, :, :] = pred_relations
        return loss, all_pred_relations
相关推荐
南城花随雪。15 分钟前
蚁群算法(Ant Colony Optimization)详细解读
算法
lLinkl22 分钟前
Java面试经典 150 题.P27. 移除元素(002)
算法
tangguofeng28 分钟前
合并排序算法(C语言版)
算法
ChaoZiLL1 小时前
关于我的数据结构与算法——初阶第二篇(排序)
数据结构·算法
爱编程的古惑仔1 小时前
leetcode刷题笔记——15.三数之和
笔记·算法·leetcode
中杯可乐多加冰1 小时前
【AI应用落地实战】智能文档处理本地部署——可视化文档解析前端TextIn ParseX实践
人工智能·深度学习·大模型·ocr·智能文档处理·acge·textin
T0uken2 小时前
【机器学习】Softmax 函数
神经网络·机器学习·分类
MogulNemenis2 小时前
随机题两题
java·后端·学习·算法
single5942 小时前
【综合算法学习】(第十篇)
java·数据结构·c++·vscode·学习·算法·leetcode
TangKenny2 小时前
荒岛逃生游戏
算法·游戏