SSAN代码解析

文章目录

run_docred.py

下面是带有详细中文注释的代码说明,包括输入和输出的介绍:

python 复制代码
def load_and_cache_examples(args, tokenizer, evaluate=False, predict=False):
    # 如果在分布式训练中,并且不是第一个进程且不在评估模式下,则等待第一个进程处理数据集
    if args.local_rank not in [-1, 0] and not evaluate:
        torch.distributed.barrier()  # 确保只有第一个进程处理数据集,其余进程使用缓存

    processor = DocREDProcessor()  # 初始化DocRED处理器
    # 加载数据
    logger.info("Creating features from dataset file at %s", args.data_dir)
    label_map = processor.get_label_map(args.data_dir)  # 获取标签映射

    # 根据evaluate和predict标志决定加载哪种数据集
    if evaluate:
        examples = processor.get_dev_examples(args.data_dir)  # 获取验证集样本
    elif predict:
        examples = processor.get_test_examples(args.data_dir)  # 获取测试集样本
    else:
        examples = processor.get_train_examples(args.data_dir)  # 获取训练集样本

    # 将样本转换为特征
    features = convert_examples_to_features(
        examples,
        args.model_type,
        tokenizer,
        max_length=args.max_seq_length,
        max_ent_cnt=args.max_ent_cnt,
        label_map=label_map
    )

    # 如果在分布式训练中,并且是第一个进程且不在评估模式下,则等待处理完数据集
    if args.local_rank == 0 and not evaluate:
        torch.distributed.barrier()  # 确保只有第一个进程处理数据集,其余进程使用缓存

    # 将特征转换为Tensor并构建数据集
    all_input_ids = torch.tensor([f.input_ids for f in features], dtype=torch.long)  # 输入IDs
    all_attention_mask = torch.tensor([f.attention_mask for f in features], dtype=torch.long)  # 注意力掩码
    all_token_type_ids = torch.tensor([f.token_type_ids for f in features], dtype=torch.long)  # 句子类型IDs
    all_ent_mask = torch.tensor([f.ent_mask for f in features], dtype=torch.float)  # 实体掩码
    all_ent_ner = torch.tensor([f.ent_ner for f in features], dtype=torch.long)  # 实体命名实体识别(NER)标签
    all_ent_pos = torch.tensor([f.ent_pos for f in features], dtype=torch.long)  # 实体位置
    all_ent_distance = torch.tensor([f.ent_distance for f in features], dtype=torch.long)  # 实体距离
    all_structure_mask = torch.tensor([f.structure_mask for f in features], dtype=torch.bool)  # 结构掩码
    all_label = torch.tensor([f.label for f in features], dtype=torch.bool)  # 标签
    all_label_mask = torch.tensor([f.label_mask for f in features], dtype=torch.bool)  # 标签掩码

    # 创建TensorDataset
    dataset = TensorDataset(all_input_ids, all_attention_mask, all_token_type_ids,
                            all_ent_mask, all_ent_ner, all_ent_pos, all_ent_distance,
                            all_structure_mask, all_label, all_label_mask)

    return dataset  # 返回构建的数据集

详细介绍每一行代码

  1. 函数定义:

    python 复制代码
    def load_and_cache_examples(args, tokenizer, evaluate=False, predict=False):
    • 功能: 加载并缓存样本数据。
    • 输入 : args(参数配置),tokenizer(分词器),evaluate(是否评估),predict(是否预测)。
    • 输出: 返回构建的TensorDataset对象。
  2. 处理分布式训练的屏障:

    python 复制代码
    if args.local_rank not in [-1, 0] and not evaluate:
        torch.distributed.barrier()
    • 功能: 确保只有第一个进程处理数据集,其余进程等待使用缓存。
  3. 初始化处理器和加载标签映射:

    python 复制代码
    processor = DocREDProcessor()
    logger.info("Creating features from dataset file at %s", args.data_dir)
    label_map = processor.get_label_map(args.data_dir)
    • 功能: 初始化DocRED处理器并加载标签映射。
  4. 根据模式加载数据集:

    python 复制代码
    if evaluate:
        examples = processor.get_dev_examples(args.data_dir)
    elif predict:
        examples = processor.get_test_examples(args.data_dir)
    else:
        examples = processor.get_train_examples(args.data_dir)
    • 功能 : 根据evaluatepredict标志,加载验证集、测试集或训练集的样本。
  5. 将样本转换为特征:

    python 复制代码
    features = convert_examples_to_features(
        examples,
        args.model_type,
        tokenizer,
        max_length=args.max_seq_length,
        max_ent_cnt=args.max_ent_cnt,
        label_map=label_map
    )
    • 功能: 将样本转换为特征,包括输入IDs、注意力掩码、句子类型IDs等。
  6. 处理分布式训练的屏障(第二次):

    python 复制代码
    if args.local_rank == 0 and not evaluate:
        torch.distributed.barrier()
    • 功能: 确保只有第一个进程处理数据集,其余进程等待使用缓存。
  7. 将特征转换为Tensor并构建数据集:

    python 复制代码
    all_input_ids = torch.tensor([f.input_ids for f in features], dtype=torch.long)
    all_attention_mask = torch.tensor([f.attention_mask for f in features], dtype=torch.long)
    all_token_type_ids = torch.tensor([f.token_type_ids for f in features], dtype=torch.long)
    all_ent_mask = torch.tensor([f.ent_mask for f in features], dtype=torch.float)
    all_ent_ner = torch.tensor([f.ent_ner for f in features], dtype=torch.long)
    all_ent_pos = torch.tensor([f.ent_pos for f in features], dtype=torch.long)
    all_ent_distance = torch.tensor([f.ent_distance for f in features], dtype=torch.long)
    all_structure_mask = torch.tensor([f.structure_mask for f in features], dtype=torch.bool)
    all_label = torch.tensor([f.label for f in features], dtype=torch.bool)
    all_label_mask = torch.tensor([f.label_mask for f in features], dtype=torch.bool)
    • 功能: 将每个特征属性转换为Tensor格式。
  8. 创建TensorDataset:

    python 复制代码
    dataset = TensorDataset(all_input_ids, all_attention_mask, all_token_type_ids,
                            all_ent_mask, all_ent_ner, all_ent_pos, all_ent_distance,
                            all_structure_mask, all_label, all_label_mask)
    • 功能: 使用上述Tensor创建一个TensorDataset对象。
  9. 返回构建的数据集:

    python 复制代码
    return dataset
    • 功能: 返回创建的TensorDataset对象。

以上代码实现了数据的加载、处理和转换,最终返回一个PyTorch中的TensorDataset对象,可用于模型训练、评估或预测。

dataset.py

当然,下面是带有详细中文注释的代码说明,包括输入和输出的介绍:

python 复制代码
def norm_mask(input_mask):
    # 创建一个与input_mask形状相同的全零矩阵
    output_mask = numpy.zeros(input_mask.shape)
    
    # 遍历input_mask的每一行
    for i in range(len(input_mask)):
        # 如果当前行的所有元素不全为零
        if not numpy.all(input_mask[i] == 0):
            # 将当前行除以当前行元素之和,并将结果存储在output_mask的相应位置
            output_mask[i] = input_mask[i] / sum(input_mask[i])
    
    # 返回归一化后的掩码矩阵
    return output_mask

详细介绍每一行代码

  1. 函数定义:

    python 复制代码
    def norm_mask(input_mask):
    • 功能 : 定义一个函数norm_mask,用于归一化输入掩码矩阵的每一行。
    • 输入 : input_mask(numpy数组,形状为(m, n),每行表示一个掩码)。
    • 输出: 返回一个与输入掩码矩阵相同形状的归一化后的掩码矩阵。
  2. 创建一个与input_mask形状相同的全零矩阵:

    python 复制代码
    output_mask = numpy.zeros(input_mask.shape)
    • 功能 : 初始化一个与input_mask形状相同的全零矩阵output_mask,用于存储归一化后的结果。
  3. 遍历input_mask的每一行:

    python 复制代码
    for i in range(len(input_mask)):
    • 功能 : 使用for循环遍历input_mask的每一行。
  4. 检查当前行的所有元素是否不全为零:

    python 复制代码
    if not numpy.all(input_mask[i] == 0):
    • 功能 : 使用numpy.all检查当前行的所有元素是否不全为零。如果当前行的所有元素都为零,则跳过该行。
  5. 归一化当前行并存储在output_mask的相应位置:

    python 复制代码
    output_mask[i] = input_mask[i] / sum(input_mask[i])
    • 功能 : 将当前行的每个元素除以该行元素的总和,得到归一化后的结果,并存储在output_mask的相应位置。
  6. 返回归一化后的掩码矩阵:

    python 复制代码
    return output_mask
    • 功能 : 返回output_mask,即归一化后的掩码矩阵。

输入和输出

  • 输入 :
    • input_mask:一个形状为(m, n)的numpy数组,表示多个掩码矩阵。
  • 输出 :
    • output_mask:一个形状为(m, n)的numpy数组,表示归一化后的掩码矩阵。每一行的元素之和为1(如果该行不全为零)。

示例

假设input_mask为:

python 复制代码
input_mask = numpy.array([
    [1, 2, 3],
    [0, 0, 0],
    [4, 5, 6]
])

调用norm_mask函数:

python 复制代码
output_mask = norm_mask(input_mask)

得到的output_mask为:

python 复制代码
output_mask = numpy.array([
    [0.16666667, 0.33333333, 0.5       ],
    [0.        , 0.        , 0.        ],
    [0.26666667, 0.33333333, 0.4       ]
])

在上述示例中,第一行和第三行的元素被归一化为它们各自的总和,而第二行保持不变(全零)。

docred_convert_examples_to_features

python 复制代码
def docred_convert_examples_to_features(
    examples,
    model_type,
    tokenizer,
    max_length=512,
    max_ent_cnt=42,
    label_map=None,
    pad_token=0,
):
    # 初始化存储特征的列表
    features = []

    # 命名实体识别标签映射
    ner_map = {'PAD':0, 'ORG':1, 'LOC':2, 'NUM':3, 'TIME':4, 'MISC':5, 'PER':6}
    
    # 初始化距离桶,用于实体间距离编码
    distance_buckets = numpy.zeros((512), dtype='int64')
    distance_buckets[1] = 1
    distance_buckets[2:] = 2
    distance_buckets[4:] = 3
    distance_buckets[8:] = 4
    distance_buckets[16:] = 5
    distance_buckets[32:] = 6
    distance_buckets[64:] = 7
    distance_buckets[128:] = 8
    distance_buckets[256:] = 9

    # 遍历每个样本
    for (ex_index, example) in enumerate(examples):
        len_examples = len(examples)

        # 每处理500个样本,打印日志信息
        if ex_index % 500 == 0:
            logger.info("Writing example %d/%d" % (ex_index, len_examples))

        # 初始化存储token的列表,以及token到句子和单词的映射
        input_tokens = []
        tok_to_sent = []
        tok_to_word = []
        
        # 遍历每个句子
        for sent_idx, sent in enumerate(example.sents):
            # 遍历句子中的每个单词
            for word_idx, word in enumerate(sent):
                tokens_tmp = tokenizer.tokenize(word, add_prefix_space=True)
                input_tokens += tokens_tmp
                tok_to_sent += [sent_idx] * len(tokens_tmp)
                tok_to_word += [word_idx] * len(tokens_tmp)

        # 如果token数量小于等于最大长度减去2
        if len(input_tokens) <= max_length - 2:
            # 根据模型类型添加特殊token
            if model_type == 'roberta':
                input_tokens = [tokenizer.bos_token] + input_tokens + [tokenizer.eos_token]
            else:
                input_tokens = [tokenizer.cls_token] + input_tokens + [tokenizer.sep_token]
            tok_to_sent = [None] + tok_to_sent + [None]
            tok_to_word = [None] + tok_to_word + [None]
            input_ids = tokenizer.convert_tokens_to_ids(input_tokens)
            attention_mask = [1] * len(input_ids)
            token_type_ids = [0] * len(input_ids)
            # padding
            padding = [None] * (max_length - len(input_ids))
            tok_to_sent += padding
            tok_to_word += padding
            padding = [0] * (max_length - len(input_ids))
            attention_mask += padding
            token_type_ids += padding
            padding = [pad_token] * (max_length - len(input_ids))
            input_ids += padding
        else:
            # 截断超长的token序列
            input_tokens = input_tokens[:max_length - 2]
            tok_to_sent = tok_to_sent[:max_length - 2]
            tok_to_word = tok_to_word[:max_length - 2]
            # 根据模型类型添加特殊token
            if model_type == 'roberta':
                input_tokens = [tokenizer.bos_token] + input_tokens + [tokenizer.eos_token]
            else:
                input_tokens = [tokenizer.cls_token] + input_tokens + [tokenizer.sep_token]
            tok_to_sent = [None] + tok_to_sent + [None]
            tok_to_word = [None] + tok_to_word + [None]
            input_ids = tokenizer.convert_tokens_to_ids(input_tokens)
            attention_mask = [1] * len(input_ids)
            token_type_ids = [pad_token] * len(input_ids)

        # 实体掩码和NER / 共指特征
        ent_mask = numpy.zeros((max_ent_cnt, max_length), dtype='int64')
        ent_ner = [0] * max_length
        ent_pos = [0] * max_length
        tok_to_ent = [-1] * max_length
        ents = example.vertexSet
        for ent_idx, ent in enumerate(ents):
            for mention in ent:
                for tok_idx in range(len(input_ids)):
                    if tok_to_sent[tok_idx] == mention['sent_id'] and mention['pos'][0] <= tok_to_word[tok_idx] < mention['pos'][1]:
                        ent_mask[ent_idx][tok_idx] = 1
                        ent_ner[tok_idx] = ner_map[ent[0]['type']]
                        ent_pos[tok_idx] = ent_idx + 1
                        tok_to_ent[tok_idx] = ent_idx

        # 距离特征
        ent_first_appearance = [0] * max_ent_cnt
        ent_distance = numpy.zeros((max_ent_cnt, max_ent_cnt), dtype='int8')  # padding id is 10
        for i in range(len(ents)):
            if numpy.all(ent_mask[i] == 0):
                continue
            else:
                ent_first_appearance[i] = numpy.where(ent_mask[i] == 1)[0][0]
        for i in range(len(ents)):
            for j in range(len(ents)):
                if ent_first_appearance[i] != 0 and ent_first_appearance[j] != 0:
                    if ent_first_appearance[i] >= ent_first_appearance[j]:
                        ent_distance[i][j] = distance_buckets[ent_first_appearance[i] - ent_first_appearance[j]]
                    else:
                        ent_distance[i][j] = -distance_buckets[-ent_first_appearance[i] + ent_first_appearance[j]]
        ent_distance += 10  # norm from [-9, 9] to [1, 19]

        # 结构掩码
        structure_mask = numpy.zeros((5, max_length, max_length), dtype='float')
        for i in range(max_length):
            if attention_mask[i] == 0:
                break
            else:
                if tok_to_ent[i] != -1:
                    for j in range(max_length):
                        if tok_to_sent[j] is None:
                            continue
                        # intra
                        if tok_to_sent[j] == tok_to_sent[i]:
                            # intra-coref
                            if tok_to_ent[j] == tok_to_ent[i]:
                                structure_mask[0][i][j] = 1
                            # intra-relate
                            elif tok_to_ent[j] != -1:
                                structure_mask[1][i][j] = 1
                            # intra-NA
                            else:
                                structure_mask[2][i][j] = 1
                        # inter
                        else:
                            # inter-coref
                            if tok_to_ent[j] == tok_to_ent[i]:
                                structure_mask[3][i][j] = 1
                            # inter-relate
                            elif tok_to_ent[j] != -1:
                                structure_mask[4][i][j] = 1

        # 标签
        label_ids = numpy.zeros((max_ent_cnt, max_ent_cnt, len(label_map.keys())), dtype='bool')
        # 测试文件没有"labels"
        if example.labels is not None:
            labels = example.labels
            for label in labels:
                label_ids[label['h']][label['t']][label_map[label['r']]] = 1
        for h in range(len(ents)):
            for t in range(len(ents)):
                if numpy.all(label_ids[h][t] == 0):
                    label_ids[h][t][0] = 1

        # 标签掩码
        label_mask = numpy.zeros((max_ent_cnt, max_ent_cnt), dtype='bool')
        label_mask[:len(ents), :len(ents)] = 1
        for ent in range(len(ents)):
            label_mask[ent][ent] = 0
        for ent in range(len(ents)):
            if numpy.all(ent_mask[ent] == 0):
                label_mask[ent, :] = 0
                label_mask[:, ent] = 0

        # 归一化实体掩码
        ent_mask = norm_mask(ent_mask)

        # 断言检查特征维度
        assert len(input_ids) == max_length
        assert len(attention_mask) == max_length
        assert len(token_type_ids) == max_length
        assert ent_mask.shape == (max_ent_cnt, max_length)
        assert label_ids.shape == (max_ent_cnt, max_ent_cnt, len(label_map.keys()))
        assert label_mask.shape == (max_ent_cnt, max_ent_cnt)
        assert len(ent_ner) == max_length
        assert len(ent_pos) == max_length
        assert ent_distance.shape == (max_ent_cnt, max_ent_cnt)
        assert structure_mask.shape == (5, max_length, max_length)

        # 打印日志信息
        if ex_index == 42:
            logger.info("*** Example ***")
            logger.info("guid: %s" % example.guid)
            logger.info("doc: %s" % (' '.join([' '.join(sent) for sent in example.sents])))
            logger.info("input_ids: %s" % (" ".join([str(x) for x in input_ids])))
            logger.info("attention_mask: %s" % (" ".join([str(x) for x in attention_mask])))
            logger.info("token_type_ids: %s" % (" ".join([str(x) for x in token_type_ids])))
            logger.info("ent_mask for first ent: %s" % (" ".join([str(x) for x in ent_mask[0]])))
            logger.info("label for ent pair 0-1: %s" % (" ".join([str(x) for x in label_ids[0][1]])))
            logger.info("label_mask for first ent: %s" % (" ".join([str(x) for x in label_mask[0]])))
            logger.info("ent_ner: %s" % (" ".join([str(x) for x in ent_ner])))
            logger.info("ent_pos: %s" % (" ".join([str(x) for x in ent_pos])))
            logger.info("ent_distance for first ent: %s" % (" ".join([str(x) for x in ent_distance[0]])))

        # 添加特征到列表中
        features.append(
            DocREDInputFeatures(
                input_ids=input_ids,
                attention_mask=attention_mask,
                token_type_ids=token_type_ids,
                ent_mask=ent_mask,
                ent_ner=ent_ner,
                ent_pos=ent_pos,
                ent_distance=ent_distance,
                structure_mask=structure_mask,
                label=label_ids,
                label_mask=label_mask,
            )
        )

    return features  # 返回特征列表

下面是带有详细中文注释的代码说明,包括输入和输出的介绍:

详细介绍每一行代码

  1. 函数定义 :

    python 复制代码
    def docred_convert_examples_to_features(
        examples,
        model_type,
        tokenizer,
        max_length=512,
        max_ent_cnt=42,
        label_map=None,
        pad_token=0,
    ):
    • 功能: 将DocRED样本转换为模型特征。
    • 输入 :
      • examples(样本列表)
      • model_type(模型类型)
      • tokenizer(分词器)
      • max_length(最大序列长度,默认为512)
      • max_ent_cnt(最大实体数量,默认为42)
      • label_map(标签映射,默认为None)
      • pad_token(填充标记,默认为0)
    • 输出: 返回特征列表。

DocREDProcessor

下面是带有详细中文注释的代码说明,包括输入和输出的介绍:

python 复制代码
class DocREDProcessor(object):
    """Processor for the DocRED data set."""
    # DocRED数据集处理器类

    def get_example_from_tensor_dict(self, tensor_dict):
        """See base class."""
        # 从tensor字典中创建一个DocREDExample实例
        return DocREDExample(
            tensor_dict["guid"].numpy(),  # 样本的唯一标识符
            tensor_dict["title"].numpy(),  # 文档标题
            tensor_dict["vertexSet"].numpy(),  # 实体集
            tensor_dict["sents"].numpy(),  # 句子
            tensor_dict["labels"].numpy(),  # 标签
        )

    def get_train_examples(self, data_dir):
        """See base class."""
        # 从指定目录加载训练集样本
        with open(os.path.join(data_dir, "train_annotated.json"), 'r') as f:
            examples = json.load(f)
        return self._create_examples(examples, 'train')

    def get_distant_examples(self, data_dir):
        """See base class."""
        # 从指定目录加载远程监督的训练集样本
        with open(os.path.join(data_dir, "train_distant.json"), 'r') as f:
            examples = json.load(f)
        return self._create_examples(examples, 'train')

    def get_dev_examples(self, data_dir):
        """See base class."""
        # 从指定目录加载开发集样本
        with open(os.path.join(data_dir, "dev.json"), 'r') as f:
            examples = json.load(f)
        return self._create_examples(examples, 'dev')

    def get_test_examples(self, data_dir):
        """See base class."""
        # 从指定目录加载测试集样本
        with open(os.path.join(data_dir, "test.json"), 'r') as f:
            examples = json.load(f)
        return self._create_examples(examples, 'test')

    def get_label_map(self, data_dir):
        """See base class."""
        # 从指定目录加载标签映射
        with open(os.path.join(data_dir, "label_map.json"), 'r') as f:
            label_map = json.load(f)
        return label_map

    def _create_examples(self, instances, set_type):
        """Creates examples for the training and dev sets."""
        # 创建训练和开发集的样本
        examples = []
        for (i, ins) in enumerate(instances):
            guid = "%s-%s" % (set_type, i)  # 生成样本的唯一标识符
            examples.append(DocREDExample(guid=guid,
                                          title=ins['title'],
                                          vertexSet=ins['vertexSet'],
                                          sents=ins['sents'],
                                          labels=ins['labels'] if set_type != "test" else None))
        return examples  # 返回创建的样本列表

详细介绍每一行代码

  1. 类定义:

    python 复制代码
    class DocREDProcessor(object):
    • 功能 : 定义一个DocREDProcessor类,用于处理DocRED数据集。
  2. 方法get_example_from_tensor_dict:

    python 复制代码
    def get_example_from_tensor_dict(self, tensor_dict):
        """See base class."""
        return DocREDExample(
            tensor_dict["guid"].numpy(),
            tensor_dict["title"].numpy(),
            tensor_dict["vertexSet"].numpy(),
            tensor_dict["sents"].numpy(),
            tensor_dict["labels"].numpy(),
        )
    • 功能 : 从一个包含tensor的字典中创建一个DocREDExample实例。
    • 输入 : tensor_dict(包含多个tensor的字典)。
    • 输出 : 返回一个DocREDExample实例。
  3. 方法get_train_examples:

    python 复制代码
    def get_train_examples(self, data_dir):
        """See base class."""
        with open(os.path.join(data_dir, "train_annotated.json"), 'r') as f:
            examples = json.load(f)
        return self._create_examples(examples, 'train')
    • 功能: 从指定目录加载训练集样本。
    • 输入 : data_dir(数据目录)。
    • 输出: 返回训练集样本列表。
  4. 方法get_distant_examples:

    python 复制代码
    def get_distant_examples(self, data_dir):
        """See base class."""
        with open(os.path.join(data_dir, "train_distant.json"), 'r') as f:
            examples = json.load(f)
        return self._create_examples(examples, 'train')
    • 功能: 从指定目录加载远程监督的训练集样本。
    • 输入 : data_dir(数据目录)。
    • 输出: 返回远程监督的训练集样本列表。
  5. 方法get_dev_examples:

    python 复制代码
    def get_dev_examples(self, data_dir):
        """See base class."""
        with open(os.path.join(data_dir, "dev.json"), 'r') as f:
            examples = json.load(f)
        return self._create_examples(examples, 'dev')
    • 功能: 从指定目录加载开发集样本。
    • 输入 : data_dir(数据目录)。
    • 输出: 返回开发集样本列表。
  6. 方法get_test_examples:

    python 复制代码
    def get_test_examples(self, data_dir):
        """See base class."""
        with open(os.path.join(data_dir, "test.json"), 'r') as f:
            examples = json.load(f)
        return self._create_examples(examples, 'test')
    • 功能: 从指定目录加载测试集样本。
    • 输入 : data_dir(数据目录)。
    • 输出: 返回测试集样本列表。
  7. 方法get_label_map:

    python 复制代码
    def get_label_map(self, data_dir):
        """See base class."""
        with open(os.path.join(data_dir, "label_map.json"), 'r') as f:
            label_map = json.load(f)
        return label_map
    • 功能: 从指定目录加载标签映射。
    • 输入 : data_dir(数据目录)。
    • 输出: 返回标签映射。
  8. 方法_create_examples:

    python 复制代码
    def _create_examples(self, instances, set_type):
        """Creates examples for the training and dev sets."""
        examples = []
        for (i, ins) in enumerate(instances):
            guid = "%s-%s" % (set_type, i)
            examples.append(DocREDExample(guid=guid,
                                          title=ins['title'],
                                          vertexSet=ins['vertexSet'],
                                          sents=ins['sents'],
                                          labels=ins['labels'] if set_type != "test" else None))
        return examples
    • 功能: 创建训练和开发集的样本。
    • 输入 :
      • instances(样本实例列表)
      • set_type(数据集类型,'train'或'dev'或'test')
    • 输出: 返回创建的样本列表。

输入和输出

  • 方法get_example_from_tensor_dict:

    • 输入 : tensor_dict(包含多个tensor的字典)。
    • 输出 : 返回一个DocREDExample实例。
  • 方法get_train_examples:

    • 输入 : data_dir(数据目录)。
    • 输出: 返回训练集样本列表。
  • 方法get_distant_examples:

    • 输入 : data_dir(数据目录)。
    • 输出: 返回远程监督的训练集样本列表。
  • 方法get_dev_examples:

    • 输入 : data_dir(数据目录)。
    • 输出: 返回开发集样本列表。
  • 方法get_test_examples:

    • 输入 : data_dir(数据目录)。
    • 输出: 返回测试集样本列表。
  • 方法get_label_map:

    • 输入 : data_dir(数据目录)。
    • 输出: 返回标签映射。
  • 方法_create_examples:

    • 输入 :
      • instances(样本实例列表)
      • set_type(数据集类型,'train'或'dev'或'test')
    • 输出: 返回创建的样本列表。
相关推荐
西猫雷婶8 分钟前
python学opencv|读取图像(二十一)使用cv2.circle()绘制圆形进阶
开发语言·python·opencv
老刘莱国瑞42 分钟前
STM32 与 AS608 指纹模块的调试与应用
python·物联网·阿里云
一只敲代码的猪2 小时前
Llama 3 模型系列解析(一)
大数据·python·llama
Hello_WOAIAI2 小时前
批量将 Word 文件转换为 HTML:Python 实现指南
python·html·word
winfredzhang2 小时前
使用Python开发PPT图片提取与九宫格合并工具
python·powerpoint·提取·九宫格·照片
矩阵推荐官hy147623 小时前
短视频矩阵系统种类繁多,应该如何对比选择?
人工智能·python·矩阵·流量运营
测试19983 小时前
外包干了2年,技术退步明显....
自动化测试·软件测试·python·功能测试·测试工具·面试·职场和发展
码银3 小时前
【python】银行客户流失预测预处理部分,独热编码·标签编码·数据离散化处理·数据筛选·数据分割
开发语言·python
小木_.3 小时前
【python 逆向分析某有道翻译】分析有道翻译公开的密文内容,webpack类型,全程扣代码,最后实现接口调用翻译,仅供学习参考
javascript·python·学习·webpack·分享·逆向分析
R-sz3 小时前
14: curl#6 - “Could not resolve host: mirrorlist.centos.org; 未知的错误“
linux·python·centos