Fine-tuning, training, and testing a custom model with huggingface: multi-task BERT

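Background:

I needed to turn BERT into a multi-task model, but the official classes only support multi-class and binary classification, not multi-task. Converting to multi-task means changing the output layer, the loss, and the evaluation. The same approach also works if you want to append an FC layer or similar to the end of BERT.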

Code

Modifying the model

Here BertForSequenceClassification is changed into a multi-task model.

import torch
import torch.nn as nn
from typing import Optional, Tuple, Union
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss

from transformers import BertPreTrainedModel, BertModel
from transformers.modeling_outputs import SequenceClassifierOutput
from transformers.utils import add_start_docstrings_to_model_forward, add_code_sample_docstrings, add_start_docstrings

_CHECKPOINT_FOR_SEQUENCE_CLASSIFICATION = "textattack/bert-base-uncased-yelp-polarity"
_CONFIG_FOR_DOC = "BertConfig"
_SEQ_CLASS_EXPECTED_OUTPUT = "'LABEL_1'"
_SEQ_CLASS_EXPECTED_LOSS = 0.01
BERT_START_DOCSTRING = r"""

    This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
    library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
    etc.)

    This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
    Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
    and behavior.

    Parameters:
        config ([`BertConfig`]): Model configuration class with all the parameters of the model.
            Initializing with a config file does not load the weights associated with the model, only the
            configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""
BERT_INPUTS_DOCSTRING = r"""
    Args:
        input_ids (`torch.LongTensor` of shape `({0})`):
            Indices of input sequence tokens in the vocabulary.

            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.

            [What are input IDs?](../glossary#input-ids)
        attention_mask (`torch.FloatTensor` of shape `({0})`, *optional*):
            Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:

            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.

            [What are attention masks?](../glossary#attention-mask)
        token_type_ids (`torch.LongTensor` of shape `({0})`, *optional*):
            Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0,
            1]`:

            - 0 corresponds to a *sentence A* token,
            - 1 corresponds to a *sentence B* token.

            [What are token type IDs?](../glossary#token-type-ids)
        position_ids (`torch.LongTensor` of shape `({0})`, *optional*):
            Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
            config.max_position_embeddings - 1]`.

            [What are position IDs?](../glossary#position-ids)
        head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
            Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`:

            - 1 indicates the head is **not masked**,
            - 0 indicates the head is **masked**.

        inputs_embeds (`torch.FloatTensor` of shape `({0}, hidden_size)`, *optional*):
            Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
            is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
            model's internal embedding lookup matrix.
        output_attentions (`bool`, *optional*):
            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
            tensors for more detail.
        output_hidden_states (`bool`, *optional*):
            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
            more detail.
        return_dict (`bool`, *optional*):
            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""

@add_start_docstrings(
    """
    Bert Model transformer with a sequence classification/regression head on top (a linear layer on top of the pooled
    output) e.g. for GLUE tasks.
    """,
    BERT_START_DOCSTRING,
)
class BertForSequenceClassification_Multitask(BertPreTrainedModel):
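    def __init__(self, config, task_output_dims):
        super().__init__(config)
        # Number of classes for each task, e.g. [6, 63] creates two heads.
        self.task_output_dims = task_output_dims

        self.num_labels = config.num_labels
        self.config = config

        self.bert = BertModel(config)
        classifier_dropout = (
            config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob
        )
        self.dropout = nn.Dropout(classifier_dropout)
        # Single head, kept so the non-multi-task branches in forward() still work.
        self.classifier = nn.Linear(config.hidden_size, config.num_labels)
        # One linear head per task; config.hidden_size replaces the hard-coded 768.
        self.classifiers = nn.ModuleList([nn.Linear(config.hidden_size, output_dim) for output_dim in task_output_dims])
        # Initialize weights and apply final processing
        self.post_init()
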
    @add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
    @add_code_sample_docstrings(
        checkpoint=_CHECKPOINT_FOR_SEQUENCE_CLASSIFICATION,
        output_type=SequenceClassifierOutput,
        config_class=_CONFIG_FOR_DOC,
        expected_output=_SEQ_CLASS_EXPECTED_OUTPUT,
        expected_loss=_SEQ_CLASS_EXPECTED_LOSS,
    )
    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple[torch.Tensor], SequenceClassifierOutput]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.bert(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        pooled_output = outputs[1]

        pooled_output = self.dropout(pooled_output)
        # For multi-task, problem_type has to be set explicitly (e.g. in from_pretrained),
        # because the heads are chosen here, before any inference from the labels.
        if self.config.problem_type == "multi_task_classification":
            # One logits tensor per task, each of shape (batch_size, task_output_dims[i]).
            logits = [classifier(pooled_output) for classifier in self.classifiers]
        else:
            logits = self.classifier(pooled_output)

        loss = None
        if labels is not None:
            if self.config.problem_type is None:
                # The original `labels.dtype == list` test was dropped: a tensor dtype
                # is never `list`, so that branch could not be reached.
                if self.num_labels == 1:
                    self.config.problem_type = "regression"
                elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
                    self.config.problem_type = "single_label_classification"
                else:
                    self.config.problem_type = "multi_label_classification"

            if self.config.problem_type == "regression":
                loss_fct = MSELoss()
                if self.num_labels == 1:
                    loss = loss_fct(logits.squeeze(), labels.squeeze())
                else:
                    loss = loss_fct(logits, labels)
            elif self.config.problem_type == "single_label_classification":
                loss_fct = CrossEntropyLoss()
                loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
            elif self.config.problem_type == "multi_label_classification":
                loss_fct = BCEWithLogitsLoss()
                loss = loss_fct(logits, labels)
            elif self.config.problem_type == "multi_task_classification":
                # labels has shape (batch_size, num_tasks); column i holds the class index for task i.
                loss_fct = CrossEntropyLoss()
                loss_list = [loss_fct(logits[i], labels[:, i]) for i in range(len(self.task_output_dims))]
                # Total loss is the unweighted sum of the per-task losses.
                loss = torch.sum(torch.stack(loss_list))
        if not return_dict:
            output = (logits,) + outputs[2:]
            return ((loss,) + output) if loss is not None else output

        return SequenceClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
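
The multi-task loss branch can be checked in isolation, without loading BERT. The following is a minimal sketch, assuming the labels are packed as a LongTensor of shape (batch_size, num_tasks). The random tensor only stands in for the pooled BERT output, and `multi_task_loss` is a hypothetical helper name, not part of transformers.

```python
import torch
import torch.nn as nn


def multi_task_loss(logits_list, labels):
    """Sum of per-task cross-entropy losses.

    logits_list: list of tensors; logits_list[i] has shape (batch_size, num_classes_i)
    labels: LongTensor of shape (batch_size, num_tasks); column i is the target for task i
    """
    loss_fct = nn.CrossEntropyLoss()
    losses = [loss_fct(logits, labels[:, i]) for i, logits in enumerate(logits_list)]
    return torch.stack(losses).sum()


torch.manual_seed(0)
batch_size, hidden_size = 4, 768
task_output_dims = [6, 63]  # two tasks with 6 and 63 classes

# Stand-in for the pooled BERT output (assumption: no real encoder is run here).
pooled = torch.randn(batch_size, hidden_size)
heads = nn.ModuleList([nn.Linear(hidden_size, dim) for dim in task_output_dims])
logits_list = [head(pooled) for head in heads]

# One column of integer class indices per task.
labels = torch.stack(
    [torch.randint(0, dim, (batch_size,)) for dim in task_output_dims], dim=1
)
loss = multi_task_loss(logits_list, labels)
```

Summing gives every task the same weight. If one task dominates training, multiplying each term by a per-task weight before summing is a possible variation, but that is not something the original code does.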

# At call time
# The original call was
model = BertForSequenceClassification.from_pretrained(pretrained_model_name_or_path, num_labels=2, hidden_dropout_prob=dropout)
# The multi-task version is (task_output_dims lists the number of classes for each task)
model = BertForSequenceClassification_Multitask.from_pretrained(pretrained_model_name_or_path, num_labels=len(pjwk_cates), hidden_dropout_prob=dropout, task_output_dims=[6,63], problem_type = "multi_task_classification")
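
The background notes that evaluation also has to change, since the model now returns one logits tensor per task. A minimal sketch of per-task accuracy follows, assuming the predictions arrive as a list of arrays (one per task) and the labels as an array of shape (num_examples, num_tasks). `multi_task_accuracy` and the metric names are hypothetical, and how the function is wired into a training loop depends on your setup.

```python
import numpy as np


def multi_task_accuracy(logits_list, labels):
    """Accuracy for each task plus their unweighted mean.

    logits_list: list of arrays; logits_list[i] has shape (num_examples, num_classes_i)
    labels: array of shape (num_examples, num_tasks)
    """
    labels = np.asarray(labels)
    per_task = []
    result = {}
    for i, logits in enumerate(logits_list):
        preds = np.argmax(np.asarray(logits), axis=-1)
        acc = float((preds == labels[:, i]).mean())
        result[f"task_{i}_accuracy"] = acc
        per_task.append(acc)
    result["mean_accuracy"] = float(np.mean(per_task))
    return result


# Toy data: 4 examples, task 0 has 3 classes, task 1 has 2 classes.
logits_task0 = np.array([[2.0, 0.1, 0.3], [0.2, 1.5, 0.1], [0.1, 0.2, 3.0], [1.0, 0.9, 0.8]])
logits_task1 = np.array([[0.3, 0.7], [0.9, 0.1], [0.2, 0.8], [0.6, 0.4]])
labels = np.array([[0, 1], [1, 0], [2, 0], [1, 1]])

metrics = multi_task_accuracy([logits_task0, logits_task1], labels)
# task 0: 3 of 4 correct, task 1: 2 of 4 correct
```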

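Loading the model at test time

At test time, load_checkpoint fails because the library's original file does not list problem_type = "multi_task_classification" as an allowed value, so it has to be added. You can simply add it wherever the error is raised. In my case the file is /home/anaconda3/envs/bert/lib/python3.8/site-packages/transformers/configuration_utils.py, line 347.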

# Add multi_task_classification to the allowed values
allowed_problem_types = ("regression", "single_label_classification", "multi_label_classification","multi_task_classification")
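
Editing a file under site-packages has to be repeated in every environment. One possible alternative is a config subclass that sets `problem_type` only after the parent constructor has run. The sketch below assumes a transformers 4.x release in which the check lives in `PretrainedConfig.__init__`. `MultiTaskBertConfig` is a hypothetical name, not something transformers provides, and this has not been verified against the setup described in this post.

```python
from transformers import BertConfig


class MultiTaskBertConfig(BertConfig):
    """BertConfig variant that tolerates problem_type="multi_task_classification".

    Assumption: the allowed_problem_types check runs only inside the parent
    constructor (transformers 4.x behaviour).
    """

    def __init__(self, **kwargs):
        # Take problem_type out before the parent constructor validates it.
        problem_type = kwargs.pop("problem_type", None)
        super().__init__(**kwargs)
        # Plain attribute assignment, done after the parent's check has run.
        self.problem_type = problem_type


config = MultiTaskBertConfig(problem_type="multi_task_classification")
# Round trip through a dict, which is roughly what saving and reloading
# config.json amounts to.
reloaded = MultiTaskBertConfig.from_dict(config.to_dict())
```

To have from_pretrained use this class, the model class would also need `config_class = MultiTaskBertConfig` as a class attribute.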