SSAN代码解析

文章目录

run_docred.py

下面是带有详细中文注释的代码说明,包括输入和输出的介绍:

python 复制代码
def load_and_cache_examples(args, tokenizer, evaluate=False, predict=False):
    # 如果在分布式训练中,并且不是第一个进程且不在评估模式下,则等待第一个进程处理数据集
    if args.local_rank not in [-1, 0] and not evaluate:
        torch.distributed.barrier()  # 确保只有第一个进程处理数据集,其余进程使用缓存

    processor = DocREDProcessor()  # 初始化DocRED处理器
    # 加载数据
    logger.info("Creating features from dataset file at %s", args.data_dir)
    label_map = processor.get_label_map(args.data_dir)  # 获取标签映射

    # 根据evaluate和predict标志决定加载哪种数据集
    if evaluate:
        examples = processor.get_dev_examples(args.data_dir)  # 获取验证集样本
    elif predict:
        examples = processor.get_test_examples(args.data_dir)  # 获取测试集样本
    else:
        examples = processor.get_train_examples(args.data_dir)  # 获取训练集样本

    # 将样本转换为特征
    features = convert_examples_to_features(
        examples,
        args.model_type,
        tokenizer,
        max_length=args.max_seq_length,
        max_ent_cnt=args.max_ent_cnt,
        label_map=label_map
    )

    # 如果在分布式训练中,并且是第一个进程且不在评估模式下,则等待处理完数据集
    if args.local_rank == 0 and not evaluate:
        torch.distributed.barrier()  # 确保只有第一个进程处理数据集,其余进程使用缓存

    # 将特征转换为Tensor并构建数据集
    all_input_ids = torch.tensor([f.input_ids for f in features], dtype=torch.long)  # 输入IDs
    all_attention_mask = torch.tensor([f.attention_mask for f in features], dtype=torch.long)  # 注意力掩码
    all_token_type_ids = torch.tensor([f.token_type_ids for f in features], dtype=torch.long)  # 句子类型IDs
    all_ent_mask = torch.tensor([f.ent_mask for f in features], dtype=torch.float)  # 实体掩码
    all_ent_ner = torch.tensor([f.ent_ner for f in features], dtype=torch.long)  # 实体命名实体识别(NER)标签
    all_ent_pos = torch.tensor([f.ent_pos for f in features], dtype=torch.long)  # 实体位置
    all_ent_distance = torch.tensor([f.ent_distance for f in features], dtype=torch.long)  # 实体距离
    all_structure_mask = torch.tensor([f.structure_mask for f in features], dtype=torch.bool)  # 结构掩码
    all_label = torch.tensor([f.label for f in features], dtype=torch.bool)  # 标签
    all_label_mask = torch.tensor([f.label_mask for f in features], dtype=torch.bool)  # 标签掩码

    # 创建TensorDataset
    dataset = TensorDataset(all_input_ids, all_attention_mask, all_token_type_ids,
                            all_ent_mask, all_ent_ner, all_ent_pos, all_ent_distance,
                            all_structure_mask, all_label, all_label_mask)

    return dataset  # 返回构建的数据集

详细介绍每一行代码

  1. 函数定义:

    python 复制代码
    def load_and_cache_examples(args, tokenizer, evaluate=False, predict=False):
    • 功能: 加载并缓存样本数据。
    • 输入 : args(参数配置),tokenizer(分词器),evaluate(是否评估),predict(是否预测)。
    • 输出: 返回构建的TensorDataset对象。
  2. 处理分布式训练的屏障:

    python 复制代码
    if args.local_rank not in [-1, 0] and not evaluate:
        torch.distributed.barrier()
    • 功能: 确保只有第一个进程处理数据集,其余进程等待使用缓存。
  3. 初始化处理器和加载标签映射:

    python 复制代码
    processor = DocREDProcessor()
    logger.info("Creating features from dataset file at %s", args.data_dir)
    label_map = processor.get_label_map(args.data_dir)
    • 功能: 初始化DocRED处理器并加载标签映射。
  4. 根据模式加载数据集:

    python 复制代码
    if evaluate:
        examples = processor.get_dev_examples(args.data_dir)
    elif predict:
        examples = processor.get_test_examples(args.data_dir)
    else:
        examples = processor.get_train_examples(args.data_dir)
    • 功能 : 根据evaluatepredict标志,加载验证集、测试集或训练集的样本。
  5. 将样本转换为特征:

    python 复制代码
    features = convert_examples_to_features(
        examples,
        args.model_type,
        tokenizer,
        max_length=args.max_seq_length,
        max_ent_cnt=args.max_ent_cnt,
        label_map=label_map
    )
    • 功能: 将样本转换为特征,包括输入IDs、注意力掩码、句子类型IDs等。
  6. 处理分布式训练的屏障(第二次):

    python 复制代码
    if args.local_rank == 0 and not evaluate:
        torch.distributed.barrier()
    • 功能: 确保只有第一个进程处理数据集,其余进程等待使用缓存。
  7. 将特征转换为Tensor并构建数据集:

    python 复制代码
    all_input_ids = torch.tensor([f.input_ids for f in features], dtype=torch.long)
    all_attention_mask = torch.tensor([f.attention_mask for f in features], dtype=torch.long)
    all_token_type_ids = torch.tensor([f.token_type_ids for f in features], dtype=torch.long)
    all_ent_mask = torch.tensor([f.ent_mask for f in features], dtype=torch.float)
    all_ent_ner = torch.tensor([f.ent_ner for f in features], dtype=torch.long)
    all_ent_pos = torch.tensor([f.ent_pos for f in features], dtype=torch.long)
    all_ent_distance = torch.tensor([f.ent_distance for f in features], dtype=torch.long)
    all_structure_mask = torch.tensor([f.structure_mask for f in features], dtype=torch.bool)
    all_label = torch.tensor([f.label for f in features], dtype=torch.bool)
    all_label_mask = torch.tensor([f.label_mask for f in features], dtype=torch.bool)
    • 功能: 将每个特征属性转换为Tensor格式。
  8. 创建TensorDataset:

    python 复制代码
    dataset = TensorDataset(all_input_ids, all_attention_mask, all_token_type_ids,
                            all_ent_mask, all_ent_ner, all_ent_pos, all_ent_distance,
                            all_structure_mask, all_label, all_label_mask)
    • 功能: 使用上述Tensor创建一个TensorDataset对象。
  9. 返回构建的数据集:

    python 复制代码
    return dataset
    • 功能: 返回创建的TensorDataset对象。

以上代码实现了数据的加载、处理和转换,最终返回一个PyTorch中的TensorDataset对象,可用于模型训练、评估或预测。

dataset.py

当然,下面是带有详细中文注释的代码说明,包括输入和输出的介绍:

python 复制代码
def norm_mask(input_mask):
    # 创建一个与input_mask形状相同的全零矩阵
    output_mask = numpy.zeros(input_mask.shape)
    
    # 遍历input_mask的每一行
    for i in range(len(input_mask)):
        # 如果当前行的所有元素不全为零
        if not numpy.all(input_mask[i] == 0):
            # 将当前行除以当前行元素之和,并将结果存储在output_mask的相应位置
            output_mask[i] = input_mask[i] / sum(input_mask[i])
    
    # 返回归一化后的掩码矩阵
    return output_mask

详细介绍每一行代码

  1. 函数定义:

    python 复制代码
    def norm_mask(input_mask):
    • 功能 : 定义一个函数norm_mask,用于归一化输入掩码矩阵的每一行。
    • 输入 : input_mask(numpy数组,形状为(m, n),每行表示一个掩码)。
    • 输出: 返回一个与输入掩码矩阵相同形状的归一化后的掩码矩阵。
  2. 创建一个与input_mask形状相同的全零矩阵:

    python 复制代码
    output_mask = numpy.zeros(input_mask.shape)
    • 功能 : 初始化一个与input_mask形状相同的全零矩阵output_mask,用于存储归一化后的结果。
  3. 遍历input_mask的每一行:

    python 复制代码
    for i in range(len(input_mask)):
    • 功能 : 使用for循环遍历input_mask的每一行。
  4. 检查当前行的所有元素是否不全为零:

    python 复制代码
    if not numpy.all(input_mask[i] == 0):
    • 功能 : 使用numpy.all检查当前行的所有元素是否不全为零。如果当前行的所有元素都为零,则跳过该行。
  5. 归一化当前行并存储在output_mask的相应位置:

    python 复制代码
    output_mask[i] = input_mask[i] / sum(input_mask[i])
    • 功能 : 将当前行的每个元素除以该行元素的总和,得到归一化后的结果,并存储在output_mask的相应位置。
  6. 返回归一化后的掩码矩阵:

    python 复制代码
    return output_mask
    • 功能 : 返回output_mask,即归一化后的掩码矩阵。

输入和输出

  • 输入 :
    • input_mask:一个形状为(m, n)的numpy数组,表示多个掩码矩阵。
  • 输出 :
    • output_mask:一个形状为(m, n)的numpy数组,表示归一化后的掩码矩阵。每一行的元素之和为1(如果该行不全为零)。

示例

假设input_mask为:

python 复制代码
input_mask = numpy.array([
    [1, 2, 3],
    [0, 0, 0],
    [4, 5, 6]
])

调用norm_mask函数:

python 复制代码
output_mask = norm_mask(input_mask)

得到的output_mask为:

python 复制代码
output_mask = numpy.array([
    [0.16666667, 0.33333333, 0.5       ],
    [0.        , 0.        , 0.        ],
    [0.26666667, 0.33333333, 0.4       ]
])

在上述示例中,第一行和第三行的元素被归一化为它们各自的总和,而第二行保持不变(全零)。

docred_convert_examples_to_features

python 复制代码
def docred_convert_examples_to_features(
    examples,
    model_type,
    tokenizer,
    max_length=512,
    max_ent_cnt=42,
    label_map=None,
    pad_token=0,
):
    # 初始化存储特征的列表
    features = []

    # 命名实体识别标签映射
    ner_map = {'PAD':0, 'ORG':1, 'LOC':2, 'NUM':3, 'TIME':4, 'MISC':5, 'PER':6}
    
    # 初始化距离桶,用于实体间距离编码
    distance_buckets = numpy.zeros((512), dtype='int64')
    distance_buckets[1] = 1
    distance_buckets[2:] = 2
    distance_buckets[4:] = 3
    distance_buckets[8:] = 4
    distance_buckets[16:] = 5
    distance_buckets[32:] = 6
    distance_buckets[64:] = 7
    distance_buckets[128:] = 8
    distance_buckets[256:] = 9

    # 遍历每个样本
    for (ex_index, example) in enumerate(examples):
        len_examples = len(examples)

        # 每处理500个样本,打印日志信息
        if ex_index % 500 == 0:
            logger.info("Writing example %d/%d" % (ex_index, len_examples))

        # 初始化存储token的列表,以及token到句子和单词的映射
        input_tokens = []
        tok_to_sent = []
        tok_to_word = []
        
        # 遍历每个句子
        for sent_idx, sent in enumerate(example.sents):
            # 遍历句子中的每个单词
            for word_idx, word in enumerate(sent):
                tokens_tmp = tokenizer.tokenize(word, add_prefix_space=True)
                input_tokens += tokens_tmp
                tok_to_sent += [sent_idx] * len(tokens_tmp)
                tok_to_word += [word_idx] * len(tokens_tmp)

        # 如果token数量小于等于最大长度减去2
        if len(input_tokens) <= max_length - 2:
            # 根据模型类型添加特殊token
            if model_type == 'roberta':
                input_tokens = [tokenizer.bos_token] + input_tokens + [tokenizer.eos_token]
            else:
                input_tokens = [tokenizer.cls_token] + input_tokens + [tokenizer.sep_token]
            tok_to_sent = [None] + tok_to_sent + [None]
            tok_to_word = [None] + tok_to_word + [None]
            input_ids = tokenizer.convert_tokens_to_ids(input_tokens)
            attention_mask = [1] * len(input_ids)
            token_type_ids = [0] * len(input_ids)
            # padding
            padding = [None] * (max_length - len(input_ids))
            tok_to_sent += padding
            tok_to_word += padding
            padding = [0] * (max_length - len(input_ids))
            attention_mask += padding
            token_type_ids += padding
            padding = [pad_token] * (max_length - len(input_ids))
            input_ids += padding
        else:
            # 截断超长的token序列
            input_tokens = input_tokens[:max_length - 2]
            tok_to_sent = tok_to_sent[:max_length - 2]
            tok_to_word = tok_to_word[:max_length - 2]
            # 根据模型类型添加特殊token
            if model_type == 'roberta':
                input_tokens = [tokenizer.bos_token] + input_tokens + [tokenizer.eos_token]
            else:
                input_tokens = [tokenizer.cls_token] + input_tokens + [tokenizer.sep_token]
            tok_to_sent = [None] + tok_to_sent + [None]
            tok_to_word = [None] + tok_to_word + [None]
            input_ids = tokenizer.convert_tokens_to_ids(input_tokens)
            attention_mask = [1] * len(input_ids)
            token_type_ids = [pad_token] * len(input_ids)

        # 实体掩码和NER / 共指特征
        ent_mask = numpy.zeros((max_ent_cnt, max_length), dtype='int64')
        ent_ner = [0] * max_length
        ent_pos = [0] * max_length
        tok_to_ent = [-1] * max_length
        ents = example.vertexSet
        for ent_idx, ent in enumerate(ents):
            for mention in ent:
                for tok_idx in range(len(input_ids)):
                    if tok_to_sent[tok_idx] == mention['sent_id'] and mention['pos'][0] <= tok_to_word[tok_idx] < mention['pos'][1]:
                        ent_mask[ent_idx][tok_idx] = 1
                        ent_ner[tok_idx] = ner_map[ent[0]['type']]
                        ent_pos[tok_idx] = ent_idx + 1
                        tok_to_ent[tok_idx] = ent_idx

        # 距离特征
        ent_first_appearance = [0] * max_ent_cnt
        ent_distance = numpy.zeros((max_ent_cnt, max_ent_cnt), dtype='int8')  # padding id is 10
        for i in range(len(ents)):
            if numpy.all(ent_mask[i] == 0):
                continue
            else:
                ent_first_appearance[i] = numpy.where(ent_mask[i] == 1)[0][0]
        for i in range(len(ents)):
            for j in range(len(ents)):
                if ent_first_appearance[i] != 0 and ent_first_appearance[j] != 0:
                    if ent_first_appearance[i] >= ent_first_appearance[j]:
                        ent_distance[i][j] = distance_buckets[ent_first_appearance[i] - ent_first_appearance[j]]
                    else:
                        ent_distance[i][j] = -distance_buckets[-ent_first_appearance[i] + ent_first_appearance[j]]
        ent_distance += 10  # norm from [-9, 9] to [1, 19]

        # 结构掩码
        structure_mask = numpy.zeros((5, max_length, max_length), dtype='float')
        for i in range(max_length):
            if attention_mask[i] == 0:
                break
            else:
                if tok_to_ent[i] != -1:
                    for j in range(max_length):
                        if tok_to_sent[j] is None:
                            continue
                        # intra
                        if tok_to_sent[j] == tok_to_sent[i]:
                            # intra-coref
                            if tok_to_ent[j] == tok_to_ent[i]:
                                structure_mask[0][i][j] = 1
                            # intra-relate
                            elif tok_to_ent[j] != -1:
                                structure_mask[1][i][j] = 1
                            # intra-NA
                            else:
                                structure_mask[2][i][j] = 1
                        # inter
                        else:
                            # inter-coref
                            if tok_to_ent[j] == tok_to_ent[i]:
                                structure_mask[3][i][j] = 1
                            # inter-relate
                            elif tok_to_ent[j] != -1:
                                structure_mask[4][i][j] = 1

        # 标签
        label_ids = numpy.zeros((max_ent_cnt, max_ent_cnt, len(label_map.keys())), dtype='bool')
        # 测试文件没有"labels"
        if example.labels is not None:
            labels = example.labels
            for label in labels:
                label_ids[label['h']][label['t']][label_map[label['r']]] = 1
        for h in range(len(ents)):
            for t in range(len(ents)):
                if numpy.all(label_ids[h][t] == 0):
                    label_ids[h][t][0] = 1

        # 标签掩码
        label_mask = numpy.zeros((max_ent_cnt, max_ent_cnt), dtype='bool')
        label_mask[:len(ents), :len(ents)] = 1
        for ent in range(len(ents)):
            label_mask[ent][ent] = 0
        for ent in range(len(ents)):
            if numpy.all(ent_mask[ent] == 0):
                label_mask[ent, :] = 0
                label_mask[:, ent] = 0

        # 归一化实体掩码
        ent_mask = norm_mask(ent_mask)

        # 断言检查特征维度
        assert len(input_ids) == max_length
        assert len(attention_mask) == max_length
        assert len(token_type_ids) == max_length
        assert ent_mask.shape == (max_ent_cnt, max_length)
        assert label_ids.shape == (max_ent_cnt, max_ent_cnt, len(label_map.keys()))
        assert label_mask.shape == (max_ent_cnt, max_ent_cnt)
        assert len(ent_ner) == max_length
        assert len(ent_pos) == max_length
        assert ent_distance.shape == (max_ent_cnt, max_ent_cnt)
        assert structure_mask.shape == (5, max_length, max_length)

        # 打印日志信息
        if ex_index == 42:
            logger.info("*** Example ***")
            logger.info("guid: %s" % example.guid)
            logger.info("doc: %s" % (' '.join([' '.join(sent) for sent in example.sents])))
            logger.info("input_ids: %s" % (" ".join([str(x) for x in input_ids])))
            logger.info("attention_mask: %s" % (" ".join([str(x) for x in attention_mask])))
            logger.info("token_type_ids: %s" % (" ".join([str(x) for x in token_type_ids])))
            logger.info("ent_mask for first ent: %s" % (" ".join([str(x) for x in ent_mask[0]])))
            logger.info("label for ent pair 0-1: %s" % (" ".join([str(x) for x in label_ids[0][1]])))
            logger.info("label_mask for first ent: %s" % (" ".join([str(x) for x in label_mask[0]])))
            logger.info("ent_ner: %s" % (" ".join([str(x) for x in ent_ner])))
            logger.info("ent_pos: %s" % (" ".join([str(x) for x in ent_pos])))
            logger.info("ent_distance for first ent: %s" % (" ".join([str(x) for x in ent_distance[0]])))

        # 添加特征到列表中
        features.append(
            DocREDInputFeatures(
                input_ids=input_ids,
                attention_mask=attention_mask,
                token_type_ids=token_type_ids,
                ent_mask=ent_mask,
                ent_ner=ent_ner,
                ent_pos=ent_pos,
                ent_distance=ent_distance,
                structure_mask=structure_mask,
                label=label_ids,
                label_mask=label_mask,
            )
        )

    return features  # 返回特征列表

下面是带有详细中文注释的代码说明,包括输入和输出的介绍:

详细介绍每一行代码

  1. 函数定义 :

    python 复制代码
    def docred_convert_examples_to_features(
        examples,
        model_type,
        tokenizer,
        max_length=512,
        max_ent_cnt=42,
        label_map=None,
        pad_token=0,
    ):
    • 功能: 将DocRED样本转换为模型特征。
    • 输入 :
      • examples(样本列表)
      • model_type(模型类型)
      • tokenizer(分词器)
      • max_length(最大序列长度,默认为512)
      • max_ent_cnt(最大实体数量,默认为42)
      • label_map(标签映射,默认为None)
      • pad_token(填充标记,默认为0)
    • 输出: 返回特征列表。

DocREDProcessor

下面是带有详细中文注释的代码说明,包括输入和输出的介绍:

python 复制代码
class DocREDProcessor(object):
    """Processor for the DocRED data set."""
    # DocRED数据集处理器类

    def get_example_from_tensor_dict(self, tensor_dict):
        """See base class."""
        # 从tensor字典中创建一个DocREDExample实例
        return DocREDExample(
            tensor_dict["guid"].numpy(),  # 样本的唯一标识符
            tensor_dict["title"].numpy(),  # 文档标题
            tensor_dict["vertexSet"].numpy(),  # 实体集
            tensor_dict["sents"].numpy(),  # 句子
            tensor_dict["labels"].numpy(),  # 标签
        )

    def get_train_examples(self, data_dir):
        """See base class."""
        # 从指定目录加载训练集样本
        with open(os.path.join(data_dir, "train_annotated.json"), 'r') as f:
            examples = json.load(f)
        return self._create_examples(examples, 'train')

    def get_distant_examples(self, data_dir):
        """See base class."""
        # 从指定目录加载远程监督的训练集样本
        with open(os.path.join(data_dir, "train_distant.json"), 'r') as f:
            examples = json.load(f)
        return self._create_examples(examples, 'train')

    def get_dev_examples(self, data_dir):
        """See base class."""
        # 从指定目录加载开发集样本
        with open(os.path.join(data_dir, "dev.json"), 'r') as f:
            examples = json.load(f)
        return self._create_examples(examples, 'dev')

    def get_test_examples(self, data_dir):
        """See base class."""
        # 从指定目录加载测试集样本
        with open(os.path.join(data_dir, "test.json"), 'r') as f:
            examples = json.load(f)
        return self._create_examples(examples, 'test')

    def get_label_map(self, data_dir):
        """See base class."""
        # 从指定目录加载标签映射
        with open(os.path.join(data_dir, "label_map.json"), 'r') as f:
            label_map = json.load(f)
        return label_map

    def _create_examples(self, instances, set_type):
        """Creates examples for the training and dev sets."""
        # 创建训练和开发集的样本
        examples = []
        for (i, ins) in enumerate(instances):
            guid = "%s-%s" % (set_type, i)  # 生成样本的唯一标识符
            examples.append(DocREDExample(guid=guid,
                                          title=ins['title'],
                                          vertexSet=ins['vertexSet'],
                                          sents=ins['sents'],
                                          labels=ins['labels'] if set_type != "test" else None))
        return examples  # 返回创建的样本列表

详细介绍每一行代码

  1. 类定义:

    python 复制代码
    class DocREDProcessor(object):
    • 功能 : 定义一个DocREDProcessor类,用于处理DocRED数据集。
  2. 方法get_example_from_tensor_dict:

    python 复制代码
    def get_example_from_tensor_dict(self, tensor_dict):
        """See base class."""
        return DocREDExample(
            tensor_dict["guid"].numpy(),
            tensor_dict["title"].numpy(),
            tensor_dict["vertexSet"].numpy(),
            tensor_dict["sents"].numpy(),
            tensor_dict["labels"].numpy(),
        )
    • 功能 : 从一个包含tensor的字典中创建一个DocREDExample实例。
    • 输入 : tensor_dict(包含多个tensor的字典)。
    • 输出 : 返回一个DocREDExample实例。
  3. 方法get_train_examples:

    python 复制代码
    def get_train_examples(self, data_dir):
        """See base class."""
        with open(os.path.join(data_dir, "train_annotated.json"), 'r') as f:
            examples = json.load(f)
        return self._create_examples(examples, 'train')
    • 功能: 从指定目录加载训练集样本。
    • 输入 : data_dir(数据目录)。
    • 输出: 返回训练集样本列表。
  4. 方法get_distant_examples:

    python 复制代码
    def get_distant_examples(self, data_dir):
        """See base class."""
        with open(os.path.join(data_dir, "train_distant.json"), 'r') as f:
            examples = json.load(f)
        return self._create_examples(examples, 'train')
    • 功能: 从指定目录加载远程监督的训练集样本。
    • 输入 : data_dir(数据目录)。
    • 输出: 返回远程监督的训练集样本列表。
  5. 方法get_dev_examples:

    python 复制代码
    def get_dev_examples(self, data_dir):
        """See base class."""
        with open(os.path.join(data_dir, "dev.json"), 'r') as f:
            examples = json.load(f)
        return self._create_examples(examples, 'dev')
    • 功能: 从指定目录加载开发集样本。
    • 输入 : data_dir(数据目录)。
    • 输出: 返回开发集样本列表。
  6. 方法get_test_examples:

    python 复制代码
    def get_test_examples(self, data_dir):
        """See base class."""
        with open(os.path.join(data_dir, "test.json"), 'r') as f:
            examples = json.load(f)
        return self._create_examples(examples, 'test')
    • 功能: 从指定目录加载测试集样本。
    • 输入 : data_dir(数据目录)。
    • 输出: 返回测试集样本列表。
  7. 方法get_label_map:

    python 复制代码
    def get_label_map(self, data_dir):
        """See base class."""
        with open(os.path.join(data_dir, "label_map.json"), 'r') as f:
            label_map = json.load(f)
        return label_map
    • 功能: 从指定目录加载标签映射。
    • 输入 : data_dir(数据目录)。
    • 输出: 返回标签映射。
  8. 方法_create_examples:

    python 复制代码
    def _create_examples(self, instances, set_type):
        """Creates examples for the training and dev sets."""
        examples = []
        for (i, ins) in enumerate(instances):
            guid = "%s-%s" % (set_type, i)
            examples.append(DocREDExample(guid=guid,
                                          title=ins['title'],
                                          vertexSet=ins['vertexSet'],
                                          sents=ins['sents'],
                                          labels=ins['labels'] if set_type != "test" else None))
        return examples
    • 功能: 创建训练和开发集的样本。
    • 输入 :
      • instances(样本实例列表)
      • set_type(数据集类型,'train'或'dev'或'test')
    • 输出: 返回创建的样本列表。

输入和输出

  • 方法get_example_from_tensor_dict:

    • 输入 : tensor_dict(包含多个tensor的字典)。
    • 输出 : 返回一个DocREDExample实例。
  • 方法get_train_examples:

    • 输入 : data_dir(数据目录)。
    • 输出: 返回训练集样本列表。
  • 方法get_distant_examples:

    • 输入 : data_dir(数据目录)。
    • 输出: 返回远程监督的训练集样本列表。
  • 方法get_dev_examples:

    • 输入 : data_dir(数据目录)。
    • 输出: 返回开发集样本列表。
  • 方法get_test_examples:

    • 输入 : data_dir(数据目录)。
    • 输出: 返回测试集样本列表。
  • 方法get_label_map:

    • 输入 : data_dir(数据目录)。
    • 输出: 返回标签映射。
  • 方法_create_examples:

    • 输入 :
      • instances(样本实例列表)
      • set_type(数据集类型,'train'或'dev'或'test')
    • 输出: 返回创建的样本列表。
相关推荐
数据智能老司机3 小时前
精通 Python 设计模式——分布式系统模式
python·设计模式·架构
数据智能老司机4 小时前
精通 Python 设计模式——并发与异步模式
python·设计模式·编程语言
数据智能老司机4 小时前
精通 Python 设计模式——测试模式
python·设计模式·架构
数据智能老司机4 小时前
精通 Python 设计模式——性能模式
python·设计模式·架构
c8i4 小时前
drf初步梳理
python·django
每日AI新事件4 小时前
python的异步函数
python
这里有鱼汤5 小时前
miniQMT下载历史行情数据太慢怎么办?一招提速10倍!
前端·python
databook14 小时前
Manim实现脉冲闪烁特效
后端·python·动效
程序设计实验室15 小时前
2025年了,在 Django 之外,Python Web 框架还能怎么选?
python
倔强青铜三16 小时前
苦练Python第46天:文件写入与上下文管理器
人工智能·python·面试