【NLP 36、CRF条件随机场 —— 源码解读】

目录

[一、CRF ------ 条件随机场:](#一、CRF —— 条件随机场:)

[1.CRF - 转移矩阵](#1.CRF - 转移矩阵)

2.发射矩阵

3.结合发射矩阵和转移矩阵

[4.CRF ------ Loss定义](#4.CRF —— Loss定义)

[二、CRF ------ 源码解读](#二、CRF —— 源码解读)

1.初始化CRF模块

2.随机初始化CRF参数

3.前向计算

4.维特比算法解码

5.验证输入张量

6.计算分数

7.计算归一化因子

8.解码标签序列

9.源码

核心算法


难只是暂时的,顶住压力后,希望能见到你心中的那束太阳

------ 25.3.11

一、CRF ------ 条件随机场:

CRF是为了解决,当预测某一个字为一种实体的左边界时,则其右边不可能是其余实体的内部或右边界,我们运用另一个矩阵控制序列前后转移的概率(相关性)

CRF的本质是在神经网络中加入一个CRF - 转移矩阵

1.CRF - 转移矩阵

**CRF - 转移矩阵:**标签数量 × 标签数量,本质上学习的是字和字之间两两标签转移的概率

START 和 END 可以看作两个特殊的符号,标记句子的开始和句子的结束


2.发射矩阵

**发射矩阵:**对于一句话中的每一个字进行四分类预测,判断其作为词的左右边界、词的内部、单字的概率。


3.结合发射矩阵和转移矩阵

CRF - 转移矩阵可以分别学习到某个类别的字转移到其他类别字的概率,然后与 发射矩阵学习到的输入向量过神经网络预测到的两字间的概率值相加,总和进行比较,对输入序列进行预测

CRF - 条件随机场输出的转移矩阵 可以与 向量经过神经网络后得出的发射矩阵结合 使用,输出一个更优的预测结果

转移矩阵可以影响发射矩阵的结果,相当于在神经网络结构中加入一层神经网络

**作用:**规避一些不合理的序列输出


4.CRF ------ Loss定义

① 输入序列 X,输出序列为 y的路径分数:A 为转移矩阵(代表前一个字向后一个字转移的概率),P 为发射矩阵(过神经网络的每个字对应的概率值),**s(X, y)**代表任意一条路径的正确概率得分

s(X, y) = log(A * P) = logA + logP(这里的路径分数可以看作结合两矩阵,再做 log 运算后的)

② 输入序列X,预测输出序列为y的概率:对上式做softmax,对 步骤 ① 得到的所有路径分数做归一化

③ 对上式取log,目标为最大化该值(方便计算,与 p(y | X) 成正比):

依然希望这个log (p(y | X)) 路径分数是最大的

④ 对上式取相反数做loss,目标为最小化该值:

其他路径的总概率得分之和的 log 值 **-**正确路径的总概率得分

CRF会明显拖慢训练速度,以效率的角度考虑可以不使用CRF

序列标注任务需要位置对应

而如果使用Bert模型,则做序列标注任务时,label标签在前后都需要加一个占位符,将Bert模型的CLS和SCP标识符包括

文本分类任务与序列标注任务模型结构的主要区别:pooling 归一化层


二、CRF ------ 源码解读

1.初始化CRF模块

初始化CRF模块

  • 检查 num_tags 是否有效。
  • 定义 start_transitions(起始转移分数)、end_transitions(结束转移分数)和 transitions(标签间转移分数)为可训练参数。
  • 调用 reset_parameters 初始化这些参数。

**num_tags:**标签的数量,定义CRF的标签空间大小

**batch_first:**输入张量的第一个维度是否为批次大小,控制输入张量的维度顺序

**start_transitions:**起始状态的转移分数,大小为num_tags,表示从开始状态到每个标签的转移分数

**end_transitions:**结束状态的转移分数,大小为num_tags,表示从每个标签到结束状态的转移分数

**transitions:**标签之间的转移分数,大小为num_tags, num_tags,表示标签之间的转移概率

**nn.Parameter():**将张量标记为模型参数,使其在训练过程中可以被优化

参数 类型 描述
data torch.Tensor 要标记为参数的张量。
requires_grad bool 是否计算梯度(默认 True

**torch.empty():**创建一个未初始化的张量

参数 类型 描述
size tuple 张量的形状。
dtype torch.dtype 张量的数据类型(可选)。
device torch.device 张量的设备(可选)。

**raise ValueError():**抛出一个值错误异常。

参数 类型 描述
message str 错误信息。
python 复制代码
    def __init__(self, num_tags: int, batch_first: bool = False) -> None:
        if num_tags <= 0:
            raise ValueError(f'invalid number of tags: {num_tags}')
        super().__init__()
        self.num_tags = num_tags
        self.batch_first = batch_first
        self.start_transitions = nn.Parameter(torch.empty(num_tags))
        self.end_transitions = nn.Parameter(torch.empty(num_tags))
        self.transitions = nn.Parameter(torch.empty(num_tags, num_tags))
        self.reset_parameters()

2.随机初始化CRF参数

随机初始化CRF的参数

  • 使用均匀分布(范围 [-0.1, 0.1])初始化 start_transitionsend_transitionstransitions

**nn.init.uniform_():**用均匀分布初始化张量

参数 类型 描述
tensor torch.Tensor 要初始化的张量。
a float 均匀分布的下界。
b float 均匀分布的上界。
python 复制代码
    def reset_parameters(self) -> None:
        nn.init.uniform_(self.start_transitions, -0.1, 0.1)
        nn.init.uniform_(self.end_transitions, -0.1, 0.1)
        nn.init.uniform_(self.transitions, -0.1, 0.1)

3.前向计算

Pytorch封装好的CRF模型的forward前向计算过程中,计算的是正确路径的概率,作为Loss使用需要添加负号,用概率的相反数作为损失

计算给定标签序列的对数似然

  • 验证输入张量的形状和有效性。
  • 如果 mask 为空,创建一个全 1 的掩码。
  • 如果 batch_firstTrue,调整张量的维度顺序。
  • 计算标签序列的分数(numerator)和归一化因子(denominator)。
  • 计算对数似然(llh = numerator - denominator)。
  • 根据 reduction 参数返回结果(summeantoken_mean

**emissions:**发射分数张量,大小为 (seq_length, batch_size, num_tags)(batch_size, seq_length, num_tags)

**tags:**标签序列张量,大小为 (seq_length, batch_size)(batch_size, seq_length)

**mask:**掩码张量,大小为 (seq_length, batch_size)(batch_size, seq_length)

**reduction:**输出的缩减方式,可选 nonesummeantoken_mean

**transpose():**交换张量的两个维度

参数 类型 描述
dim0 int 第一个维度。
dim1 int 第二个维度。

**raise ValueError():**抛出一个值错误异常。

参数 类型 描述
message str 错误信息。

**torch.ones_like():**创建一个与输入张量形状相同的全 1 张量。

参数 类型 描述
input torch.Tensor 输入张量。
dtype torch.dtype 张量的数据类型(可选)。
device torch.device 张量的设备(可选)。

**sum():**计算张量的元素和

参数 类型 描述
dim int 沿指定维度求和(可选)。
keepdim bool 是否保持维度(可选)。

**float():**将张量转换为浮点类型。

python 复制代码
    def forward(
            self,
            emissions: torch.Tensor,
            tags: torch.LongTensor,
            mask: Optional[torch.ByteTensor] = None,
            reduction: str = 'sum',
    ) -> torch.Tensor:
        self._validate(emissions, tags=tags, mask=mask)
        if reduction not in ('none', 'sum', 'mean', 'token_mean'):
            raise ValueError(f'invalid reduction: {reduction}')
        if mask is None:
            mask = torch.ones_like(tags, dtype=torch.uint8)

        if self.batch_first:
            emissions = emissions.transpose(0, 1)
            tags = tags.transpose(0, 1)
            mask = mask.transpose(0, 1)

        numerator = self._compute_score(emissions, tags, mask)
        denominator = self._compute_normalizer(emissions, mask)
        llh = numerator - denominator

        if reduction == 'none':
            return llh
        if reduction == 'sum':
            return llh.sum()
        if reduction == 'mean':
            return llh.mean()
        assert reduction == 'token_mean'
        return llh.sum() / mask.float().sum()

4.维特比算法解码

使用 Viterbi 算法解码最可能的标签序列

  • 验证输入张量的形状和有效性。
  • 如果 mask 为空,创建一个全 1 的掩码。
  • 如果 batch_firstTrue,调整张量的维度顺序。
  • 调用 _viterbi_decode 解码最可能的标签序列。

**emissions:**发射分数张量,大小为 (seq_length, batch_size, num_tags)(batch_size, seq_length, num_tags)

**mask:**掩码张量,大小为 (seq_length, batch_size)(batch_size, seq_length)

**transpose():**交换张量的两个维度。

参数 类型 描述
dim0 int 第一个维度。
dim1 int 第二个维度。
python 复制代码
    def decode(self, emissions: torch.Tensor,
               mask: Optional[torch.ByteTensor] = None) -> List[List[int]]:
        self._validate(emissions, mask=mask)
        if mask is None:
            mask = emissions.new_ones(emissions.shape[:2], dtype=torch.uint8)

        if self.batch_first:
            emissions = emissions.transpose(0, 1)
            mask = mask.transpose(0, 1)

        return self._viterbi_decode(emissions, mask)

5.验证输入张量

验证输入张量的形状和有效性

  • 检查 emissions 的维度是否为 3。
  • 检查 emissions 的最后一个维度是否等于 num_tags
  • 检查 emissionstagsmask 的形状是否匹配。
  • 检查 mask 的第一个时间步是否全部为 1。

**emissions:**发射分数张量

**tags:**标签序列张量

**mask:**掩码张量

**raise ValueError():**抛出一个值错误异常。

参数 类型 描述
message str 错误信息。

**.dim():**返回张量的维度数。

**.size():**返回张量的形状。

参数 类型 描述
dim int 指定维度(可选)。

**tuple():**将可迭代对象转换为元组

参数 类型 描述
iterable iterable 要转换的可迭代对象。

**.shape():**返回张量的形状(与 .size() 相同)。

python 复制代码
    def _validate(
            self,
            emissions: torch.Tensor,
            tags: Optional[torch.LongTensor] = None,
            mask: Optional[torch.ByteTensor] = None) -> None:
        if emissions.dim() != 3:
            raise ValueError(f'emissions must have dimension of 3, got {emissions.dim()}')
        if emissions.size(2) != self.num_tags:
            raise ValueError(
                f'expected last dimension of emissions is {self.num_tags}, '
                f'got {emissions.size(2)}')

        if tags is not None:
            if emissions.shape[:2] != tags.shape:
                raise ValueError(
                    'the first two dimensions of emissions and tags must match, '
                    f'got {tuple(emissions.shape[:2])} and {tuple(tags.shape)}')

        if mask is not None:
            if emissions.shape[:2] != mask.shape:
                raise ValueError(
                    'the first two dimensions of emissions and mask must match, '
                    f'got {tuple(emissions.shape[:2])} and {tuple(mask.shape)}')
            no_empty_seq = not self.batch_first and mask[0].all()
            no_empty_seq_bf = self.batch_first and mask[:, 0].all()
            if not no_empty_seq and not no_empty_seq_bf:
                raise ValueError('mask of the first timestep must all be on')

6.计算分数

计算给定标签序列的分数

  • 验证输入张量的形状和有效性。
  • 初始化分数为起始转移分数和第一个时间步的发射分数。
  • 遍历时间步,累加转移分数和发射分数。
  • 添加结束转移分数。
  • 返回标签序列的分数。

**emissions:**发射分数张量

**tags:**标签序列张量

**mask:**掩码张量

**.dim():**返回张量的维度数。

**.size():**返回张量的形状。

参数 类型 描述
dim int 指定维度(可选)。

**.shape():**返回张量的形状(与 .size() 相同)。

**.float():**将张量转换为浮点类型。

**torch.arange():**创建一个等差序列张量

参数 类型 描述
start int 起始值(默认 0)。
end int 结束值(不包括)。
step int 步长(默认 1)。
dtype torch.dtype 张量的数据类型(可选)。
device torch.device 张量的设备(可选)。
python 复制代码
    def _compute_score(
            self, emissions: torch.Tensor, tags: torch.LongTensor,
            mask: torch.ByteTensor) -> torch.Tensor:
        assert emissions.dim() == 3 and tags.dim() == 2
        assert emissions.shape[:2] == tags.shape
        assert emissions.size(2) == self.num_tags
        assert mask.shape == tags.shape
        assert mask[0].all()

        seq_length, batch_size = tags.shape
        mask = mask.float()

        score = self.start_transitions[tags[0]]
        score += emissions[0, torch.arange(batch_size), tags[0]]

        for i in range(1, seq_length):
            score += self.transitions[tags[i - 1], tags[i]] * mask[i]
            score += emissions[i, torch.arange(batch_size), tags[i]] * mask[i]

        seq_ends = mask.long().sum(dim=0) - 1
        last_tags = tags[seq_ends, torch.arange(batch_size)]
        score += self.end_transitions[last_tags]

        return score

7.计算归一化因子

计算归一化因子(配分函数)

  • 验证输入张量的形状和有效性。
  • 初始化分数为起始转移分数和第一个时间步的发射分数。
  • 遍历时间步,计算所有可能标签序列的分数。
  • 使用 logsumexp 计算归一化因子。
  • 返回归一化因子。

**emissions:**发射分数张量

**mask:**掩码张量

**.all():**检查张量中的所有元素是否为真。

参数 类型 描述
dim int 沿指定维度检查(可选)。
keepdim bool 是否保持维度(可选)。

**.dim():**返回张量的维度数。

**.size():**返回张量的形状。

参数 类型 描述
dim int 指定维度(可选)。

**.shape():**返回张量的形状(与 .size() 相同)。

**torch.logsumexp():**计算张量的对数求和指数

参数 类型 描述
input torch.Tensor 输入张量。
dim int 沿指定维度计算(可选)。
keepdim bool 是否保持维度(可选)。

**torch.where():**根据条件选择元素。

参数 类型 描述
condition torch.Tensor 条件张量。
x torch.Tensor 条件为真时的值。
y torch.Tensor 条件为假时的值。
python 复制代码
    def _compute_normalizer(
            self, emissions: torch.Tensor, mask: torch.ByteTensor) -> torch.Tensor:
        assert emissions.dim() == 3 and mask.dim() == 2
        assert emissions.shape[:2] == mask.shape
        assert emissions.size(2) == self.num_tags
        assert mask[0].all()

        seq_length = emissions.size(0)

        score = self.start_transitions + emissions[0]

        for i in range(1, seq_length):
            broadcast_score = score.unsqueeze(2)
            broadcast_emissions = emissions[i].unsqueeze(1)
            next_score = broadcast_score + self.transitions + broadcast_emissions
            next_score = torch.logsumexp(next_score, dim=1)
            score = torch.where(mask[i].unsqueeze(1), next_score, score)

        score += self.end_transitions
        return torch.logsumexp(score, dim=1)

8.解码标签序列

使用 Viterbi 算法解码最可能的标签序列

  • 验证输入张量的形状和有效性。
  • 初始化分数为起始转移分数和第一个时间步的发射分数。
  • 遍历时间步,计算所有可能标签序列的分数,并记录最大分数的索引。
  • 添加结束转移分数。
  • 回溯 Viterbi 路径,找到最可能的标签序列。
  • 返回最可能的标签序列。

**emissions:**发射分数张量

**mask:**掩码张量

**.dim():**返回张量的维度数。

**.size():**返回张量的形状。

参数 类型 描述
dim int 指定维度(可选)。

**.shape():**返回张量的形状(与 .size() 相同)。

**unsqueeze():**在指定维度上增加一个大小为 1 的维度。

参数 类型 描述
dim int 要增加维度的位置。

**max():**返回张量的最大值,或沿指定维度的最大值及其索引。

参数 类型 描述
dim int 沿指定维度计算最大值(可选)。
keepdim bool 是否保持维度(可选)。

**torch.where():**根据条件选择元素。

参数 类型 描述
condition torch.Tensor 条件张量。
x torch.Tensor 条件为真时的值。
y torch.Tensor 条件为假时的值。

**long().sum():**将张量转换为长整型并求和

参数 类型 描述
dim int 沿指定维度求和(可选)。
keepdim bool 是否保持维度(可选)。

**append():**在列表末尾添加元素

参数 类型 描述
item any 要添加的元素。

**reverse():**反转列表中的元素顺序。

python 复制代码
    def _viterbi_decode(self, emissions: torch.FloatTensor,
                        mask: torch.ByteTensor) -> List[List[int]]:
        assert emissions.dim() == 3 and mask.dim() == 2
        assert emissions.shape[:2] == mask.shape
        assert emissions.size(2) == self.num_tags
        assert mask[0].all()

        seq_length, batch_size = mask.shape

        score = self.start_transitions + emissions[0]
        history = []

        for i in range(1, seq_length):
            broadcast_score = score.unsqueeze(2)
            broadcast_emission = emissions[i].unsqueeze(1)
            next_score = broadcast_score + self.transitions + broadcast_emission
            next_score, indices = next_score.max(dim=1)
            score = torch.where(mask[i].unsqueeze(1), next_score, score)
            history.append(indices)

        score += self.end_transitions

        seq_ends = mask.long().sum(dim=0) - 1
        best_tags_list = []

        for idx in range(batch_size):
            _, best_last_tag = score[idx].max(dim=0)
            best_tags = [best_last_tag.item()]

            for hist in reversed(history[:seq_ends[idx]]):
                best_last_tag = hist[idx][best_tags[-1]]
                best_tags.append(best_last_tag.item())

            best_tags.reverse()
            best_tags_list.append(best_tags)

        return best_tags_list

9.源码

核心算法

  1. CRF 的对数似然计算

    • 对数似然 = 标签序列的分数 - 归一化因子。
    • 标签序列的分数 = 起始转移分数 + 发射分数 + 转移分数 + 结束转移分数。
    • 归一化因子 = 所有可能标签序列的分数之和(使用 logsumexp 计算)。
  2. Viterbi 算法

    • 动态规划算法,用于找到最可能的标签序列。
    • 在每个时间步,计算所有可能标签序列的分数,并记录最大分数的索引。
    • 回溯路径,得到最可能的标签序列。
python 复制代码
__version__ = '0.7.2'

from typing import List, Optional

import torch
import torch.nn as nn


class CRF(nn.Module):
    """Conditional random field.

    This module implements a conditional random field [LMP01]_. The forward computation
    of this class computes the log likelihood of the given sequence of tags and
    emission score tensor. This class also has `~CRF.decode` method which finds
    the best tag sequence given an emission score tensor using `Viterbi algorithm`_.

    Args:
        num_tags: Number of tags.
        batch_first: Whether the first dimension corresponds to the size of a minibatch.

    Attributes:
        start_transitions (`~torch.nn.Parameter`): Start transition score tensor of size
            ``(num_tags,)``.
        end_transitions (`~torch.nn.Parameter`): End transition score tensor of size
            ``(num_tags,)``.
        transitions (`~torch.nn.Parameter`): Transition score tensor of size
            ``(num_tags, num_tags)``.


    .. [LMP01] Lafferty, J., McCallum, A., Pereira, F. (2001).
       "Conditional random fields: Probabilistic models for segmenting and
       labeling sequence data". *Proc. 18th International Conf. on Machine
       Learning*. Morgan Kaufmann. pp. 282--289.

    .. _Viterbi algorithm: https://en.wikipedia.org/wiki/Viterbi_algorithm
    """

    def __init__(self, num_tags: int, batch_first: bool = False) -> None:
        if num_tags <= 0:
            raise ValueError(f'invalid number of tags: {num_tags}')
        super().__init__()
        self.num_tags = num_tags
        self.batch_first = batch_first
        self.start_transitions = nn.Parameter(torch.empty(num_tags))
        self.end_transitions = nn.Parameter(torch.empty(num_tags))
        self.transitions = nn.Parameter(torch.empty(num_tags, num_tags))

        # 随机初始化参数
        self.reset_parameters()

    def reset_parameters(self) -> None:
        """Initialize the transition parameters.

        The parameters will be initialized randomly from a uniform distribution
        between -0.1 and 0.1.
        """
        nn.init.uniform_(self.start_transitions, -0.1, 0.1)
        nn.init.uniform_(self.end_transitions, -0.1, 0.1)
        nn.init.uniform_(self.transitions, -0.1, 0.1)

    def __repr__(self) -> str:
        return f'{self.__class__.__name__}(num_tags={self.num_tags})'

    def forward(
            self,
            emissions: torch.Tensor,
            tags: torch.LongTensor,
            mask: Optional[torch.ByteTensor] = None,
            reduction: str = 'sum',
    ) -> torch.Tensor:
        """Compute the conditional log likelihood of a sequence of tags given emission scores.

        Args:
            emissions (`~torch.Tensor`): Emission score tensor of size
                ``(seq_length, batch_size, num_tags)`` if ``batch_first`` is ``False``,
                ``(batch_size, seq_length, num_tags)`` otherwise.
            tags (`~torch.LongTensor`): Sequence of tags tensor of size
                ``(seq_length, batch_size)`` if ``batch_first`` is ``False``,
                ``(batch_size, seq_length)`` otherwise.
            mask (`~torch.ByteTensor`): Mask tensor of size ``(seq_length, batch_size)``
                if ``batch_first`` is ``False``, ``(batch_size, seq_length)`` otherwise.
            reduction: Specifies  the reduction to apply to the output:
                ``none|sum|mean|token_mean``. ``none``: no reduction will be applied.
                ``sum``: the output will be summed over batches. ``mean``: the output will be
                averaged over batches. ``token_mean``: the output will be averaged over tokens.

        Returns:
            `~torch.Tensor`: The log likelihood. This will have size ``(batch_size,)`` if
            reduction is ``none``, ``()`` otherwise.
        """
        self._validate(emissions, tags=tags, mask=mask)
        if reduction not in ('none', 'sum', 'mean', 'token_mean'):
            raise ValueError(f'invalid reduction: {reduction}')
        if mask is None:
            mask = torch.ones_like(tags, dtype=torch.uint8)

        if self.batch_first:
            emissions = emissions.transpose(0, 1)
            tags = tags.transpose(0, 1)
            mask = mask.transpose(0, 1)

        # shape: (batch_size,)
        numerator = self._compute_score(emissions, tags, mask)
        # shape: (batch_size,)
        denominator = self._compute_normalizer(emissions, mask)
        # shape: (batch_size,)
        llh = numerator - denominator

        if reduction == 'none':
            return llh
        if reduction == 'sum':
            return llh.sum()
        if reduction == 'mean':
            return llh.mean()
        assert reduction == 'token_mean'
        return llh.sum() / mask.float().sum()

    def decode(self, emissions: torch.Tensor,
               mask: Optional[torch.ByteTensor] = None) -> List[List[int]]:
        """Find the most likely tag sequence using Viterbi algorithm.

        Args:
            emissions (`~torch.Tensor`): Emission score tensor of size
                ``(seq_length, batch_size, num_tags)`` if ``batch_first`` is ``False``,
                ``(batch_size, seq_length, num_tags)`` otherwise.
            mask (`~torch.ByteTensor`): Mask tensor of size ``(seq_length, batch_size)``
                if ``batch_first`` is ``False``, ``(batch_size, seq_length)`` otherwise.

        Returns:
            List of list containing the best tag sequence for each batch.
        """
        self._validate(emissions, mask=mask)
        if mask is None:
            mask = emissions.new_ones(emissions.shape[:2], dtype=torch.uint8)

        if self.batch_first:
            emissions = emissions.transpose(0, 1)
            mask = mask.transpose(0, 1)

        return self._viterbi_decode(emissions, mask)

    def _validate(
            self,
            emissions: torch.Tensor,
            tags: Optional[torch.LongTensor] = None,
            mask: Optional[torch.ByteTensor] = None) -> None:
        if emissions.dim() != 3:
            raise ValueError(f'emissions must have dimension of 3, got {emissions.dim()}')
        if emissions.size(2) != self.num_tags:
            raise ValueError(
                f'expected last dimension of emissions is {self.num_tags}, '
                f'got {emissions.size(2)}')

        if tags is not None:
            if emissions.shape[:2] != tags.shape:
                raise ValueError(
                    'the first two dimensions of emissions and tags must match, '
                    f'got {tuple(emissions.shape[:2])} and {tuple(tags.shape)}')

        if mask is not None:
            if emissions.shape[:2] != mask.shape:
                raise ValueError(
                    'the first two dimensions of emissions and mask must match, '
                    f'got {tuple(emissions.shape[:2])} and {tuple(mask.shape)}')
            no_empty_seq = not self.batch_first and mask[0].all()
            no_empty_seq_bf = self.batch_first and mask[:, 0].all()
            if not no_empty_seq and not no_empty_seq_bf:
                raise ValueError('mask of the first timestep must all be on')

    def _compute_score(
            self, emissions: torch.Tensor, tags: torch.LongTensor,
            mask: torch.ByteTensor) -> torch.Tensor:
        # emissions: (seq_length, batch_size, num_tags)
        # tags: (seq_length, batch_size)
        # mask: (seq_length, batch_size)
        assert emissions.dim() == 3 and tags.dim() == 2
        assert emissions.shape[:2] == tags.shape
        assert emissions.size(2) == self.num_tags
        assert mask.shape == tags.shape
        assert mask[0].all()

        seq_length, batch_size = tags.shape
        mask = mask.float()

        # Start transition score and first emission
        # shape: (batch_size,)
        score = self.start_transitions[tags[0]]
        score += emissions[0, torch.arange(batch_size), tags[0]]

        for i in range(1, seq_length):
            # Transition score to next tag, only added if next timestep is valid (mask == 1)
            # shape: (batch_size,)
            score += self.transitions[tags[i - 1], tags[i]] * mask[i]

            # Emission score for next tag, only added if next timestep is valid (mask == 1)
            # shape: (batch_size,)
            score += emissions[i, torch.arange(batch_size), tags[i]] * mask[i]

        # End transition score
        # shape: (batch_size,)
        seq_ends = mask.long().sum(dim=0) - 1
        # shape: (batch_size,)
        last_tags = tags[seq_ends, torch.arange(batch_size)]
        # shape: (batch_size,)
        score += self.end_transitions[last_tags]

        return score

    def _compute_normalizer(
            self, emissions: torch.Tensor, mask: torch.ByteTensor) -> torch.Tensor:
        # emissions: (seq_length, batch_size, num_tags)
        # mask: (seq_length, batch_size)
        assert emissions.dim() == 3 and mask.dim() == 2
        assert emissions.shape[:2] == mask.shape
        assert emissions.size(2) == self.num_tags
        assert mask[0].all()

        seq_length = emissions.size(0)

        # Start transition score and first emission; score has size of
        # (batch_size, num_tags) where for each batch, the j-th column stores
        # the score that the first timestep has tag j
        # shape: (batch_size, num_tags)
        score = self.start_transitions + emissions[0]

        for i in range(1, seq_length):
            # Broadcast score for every possible next tag
            # shape: (batch_size, num_tags, 1)
            broadcast_score = score.unsqueeze(2)

            # Broadcast emission score for every possible current tag
            # shape: (batch_size, 1, num_tags)
            broadcast_emissions = emissions[i].unsqueeze(1)

            # Compute the score tensor of size (batch_size, num_tags, num_tags) where
            # for each sample, entry at row i and column j stores the sum of scores of all
            # possible tag sequences so far that end with transitioning from tag i to tag j
            # and emitting
            # shape: (batch_size, num_tags, num_tags)
            next_score = broadcast_score + self.transitions + broadcast_emissions

            # Sum over all possible current tags, but we're in score space, so a sum
            # becomes a log-sum-exp: for each sample, entry i stores the sum of scores of
            # all possible tag sequences so far, that end in tag i
            # shape: (batch_size, num_tags)
            next_score = torch.logsumexp(next_score, dim=1)

            # Set score to the next score if this timestep is valid (mask == 1)
            # shape: (batch_size, num_tags)
            score = torch.where(mask[i].unsqueeze(1), next_score, score)

        # End transition score
        # shape: (batch_size, num_tags)
        score += self.end_transitions

        # Sum (log-sum-exp) over all possible tags
        # shape: (batch_size,)
        return torch.logsumexp(score, dim=1)

    def _viterbi_decode(self, emissions: torch.FloatTensor,
                        mask: torch.ByteTensor) -> List[List[int]]:
        # emissions: (seq_length, batch_size, num_tags)
        # mask: (seq_length, batch_size)
        assert emissions.dim() == 3 and mask.dim() == 2
        assert emissions.shape[:2] == mask.shape
        assert emissions.size(2) == self.num_tags
        assert mask[0].all()

        seq_length, batch_size = mask.shape

        # Start transition and first emission
        # shape: (batch_size, num_tags)
        score = self.start_transitions + emissions[0]
        history = []

        # score is a tensor of size (batch_size, num_tags) where for every batch,
        # value at column j stores the score of the best tag sequence so far that ends
        # with tag j
        # history saves where the best tags candidate transitioned from; this is used
        # when we trace back the best tag sequence

        # Viterbi algorithm recursive case: we compute the score of the best tag sequence
        # for every possible next tag
        for i in range(1, seq_length):
            # Broadcast viterbi score for every possible next tag
            # shape: (batch_size, num_tags, 1)
            broadcast_score = score.unsqueeze(2)

            # Broadcast emission score for every possible current tag
            # shape: (batch_size, 1, num_tags)
            broadcast_emission = emissions[i].unsqueeze(1)

            # Compute the score tensor of size (batch_size, num_tags, num_tags) where
            # for each sample, entry at row i and column j stores the score of the best
            # tag sequence so far that ends with transitioning from tag i to tag j and emitting
            # shape: (batch_size, num_tags, num_tags)
            next_score = broadcast_score + self.transitions + broadcast_emission

            # Find the maximum score over all possible current tag
            # shape: (batch_size, num_tags)
            next_score, indices = next_score.max(dim=1)

            # Set score to the next score if this timestep is valid (mask == 1)
            # and save the index that produces the next score
            # shape: (batch_size, num_tags)
            score = torch.where(mask[i].unsqueeze(1), next_score, score)
            history.append(indices)

        # End transition score
        # shape: (batch_size, num_tags)
        score += self.end_transitions

        # Now, compute the best path for each sample

        # shape: (batch_size,)
        seq_ends = mask.long().sum(dim=0) - 1
        best_tags_list = []

        for idx in range(batch_size):
            # Find the tag which maximizes the score at the last timestep; this is our best tag
            # for the last timestep
            _, best_last_tag = score[idx].max(dim=0)
            best_tags = [best_last_tag.item()]

            # We trace back where the best last tag comes from, append that to our best tag
            # sequence, and trace it back again, and so on
            for hist in reversed(history[:seq_ends[idx]]):
                best_last_tag = hist[idx][best_tags[-1]]
                best_tags.append(best_last_tag.item())

            # Reverse the order because we start from the last timestep
            best_tags.reverse()
            best_tags_list.append(best_tags)

        return best_tags_list
相关推荐
要努力啊啊啊23 分钟前
GaLore:基于梯度低秩投影的大语言模型高效训练方法详解一
论文阅读·人工智能·语言模型·自然语言处理
天才测试猿23 分钟前
接口自动化测试之pytest接口关联框架封装
自动化测试·软件测试·python·测试工具·职场和发展·测试用例·pytest
先做个垃圾出来………34 分钟前
《机器学习系统设计》
人工智能·机器学习
s1533540 分钟前
6.RV1126-OPENCV 形态学基础膨胀及腐蚀
人工智能·opencv·计算机视觉
jndingxin42 分钟前
OpenCV CUDA模块特征检测------角点检测的接口createMinEigenValCorner()
人工智能·opencv·计算机视觉
先做个垃圾出来………1 小时前
Python中使用pandas
开发语言·python·pandas
Tianyanxiao1 小时前
宇树科技更名“股份有限公司”深度解析:机器人企业IPO前奏与资本化路径
人工智能
道可云1 小时前
道可云人工智能每日资讯|北京农业人工智能与机器人研究院揭牌
人工智能·机器人·ar·deepseek
不爱吃山楂罐头1 小时前
第三十三天打卡复习
python·深度学习
艾醒(AiXing-w)2 小时前
探索大语言模型(LLM):参数量背后的“黄金公式”与Scaling Law的启示
人工智能·语言模型·自然语言处理