jupyter快速实现单标签及多标签多分类的文本分类BERT模型

jupyter实现pytorch版BERT(单标签分类版)

nlp-notebooks/Text classification with BERT in PyTorch.ipynb

通过改写上述代码,实现多标签分类

参考解决方案 ,我选择的解决方案是继承BertForSequenceClassification并改写,即将上述代码的ln [9] 改为以下内容:

python 复制代码
from transformers.modeling_bert import BertForSequenceClassification
from transformers.modeling_outputs import SequenceClassifierOutput

class BertForMultilabelSequenceClassification(BertForSequenceClassification):
   def __init__(self, config):
     super().__init__(config)

   def forward(self,
       input_ids=None,
       attention_mask=None,
       token_type_ids=None,
       position_ids=None,
       head_mask=None,
       inputs_embeds=None,
       labels=None,
       output_attentions=None,
       output_hidden_states=None,
       return_dict=None):
       return_dict = return_dict if return_dict is not None else self.config.use_return_dict

       outputs = self.bert(input_ids,
           attention_mask=attention_mask,
           token_type_ids=token_type_ids,
           position_ids=position_ids,
           head_mask=head_mask,
           inputs_embeds=inputs_embeds,
           output_attentions=output_attentions,
           output_hidden_states=output_hidden_states,
           return_dict=return_dict)

       pooled_output = outputs[1]
       pooled_output = self.dropout(pooled_output)
       logits = self.classifier(pooled_output)

       loss = None
       if labels is not None:
           loss_fct = torch.nn.BCEWithLogitsLoss()
           loss = loss_fct(logits.view(-1, self.num_labels), 
                           labels.float().view(-1, self.num_labels))

       if not return_dict:
           output = (logits,) + outputs[2:]
           return ((loss,) + output) if loss is not None else output

       return SequenceClassifierOutput(loss=loss,
           logits=logits,
           hidden_states=outputs.hidden_states,
           attentions=outputs.attentions)
           
model = BertForMultilabelSequenceClassification.from_pretrained(BERT_MODEL, num_labels = len(label2idx))
model.to(device)
相关推荐
song5011 天前
多卡训练加速:HCCL 集合通信实战
分布式·python·flutter·ci/cd·分类
图码1 天前
二分查找进阶:如何在有序数组中快速找到Upper Bound?
数据结构·算法·面试·分类·柔性数组
AI技术控1 天前
NeuroH-TGL 论文解读:面向脑疾病诊断的神经异质性引导时序图学习方法
人工智能·语言模型·自然语言处理·langchain·nlp
z小猫不吃鱼2 天前
15 BEiT 论文精读:BERT Pre-Training of Image Transformers
人工智能·深度学习·bert
动物园猫2 天前
人脸表情七种表情数据集分享(适用于YOLO系列深度学习分类检测任务)
深度学习·yolo·分类
许彰午2 天前
从LIKE暴力匹配到LLM智能分类——遗留系统数据分析实战
人工智能·分类·数据分析
阿文的代码库2 天前
机器学习评价指标之转换化为二分类任务
人工智能·分类·数据挖掘
kcuwu.2 天前
FastText技术博客:从原理到实战
自然语言处理·nlp
renhongxia12 天前
用知识图谱重构搜索引擎
人工智能·搜索引擎·重构·分类·语音识别·知识图谱
AI技术控2 天前
Long-range Brain Graph Transformer 论文解读:用长程依赖建模理解脑网络通信
人工智能·python·深度学习·分类