jupyter快速实现单标签及多标签多分类的文本分类BERT模型

jupyter实现pytorch版BERT(单标签分类版)

nlp-notebooks/Text classification with BERT in PyTorch.ipynb

通过改写上述代码,实现多标签分类

参考解决方案 ,我选择的解决方案是继承BertForSequenceClassification并改写,即将上述代码的ln [9] 改为以下内容:

python 复制代码
from transformers.modeling_bert import BertForSequenceClassification
from transformers.modeling_outputs import SequenceClassifierOutput

class BertForMultilabelSequenceClassification(BertForSequenceClassification):
   def __init__(self, config):
     super().__init__(config)

   def forward(self,
       input_ids=None,
       attention_mask=None,
       token_type_ids=None,
       position_ids=None,
       head_mask=None,
       inputs_embeds=None,
       labels=None,
       output_attentions=None,
       output_hidden_states=None,
       return_dict=None):
       return_dict = return_dict if return_dict is not None else self.config.use_return_dict

       outputs = self.bert(input_ids,
           attention_mask=attention_mask,
           token_type_ids=token_type_ids,
           position_ids=position_ids,
           head_mask=head_mask,
           inputs_embeds=inputs_embeds,
           output_attentions=output_attentions,
           output_hidden_states=output_hidden_states,
           return_dict=return_dict)

       pooled_output = outputs[1]
       pooled_output = self.dropout(pooled_output)
       logits = self.classifier(pooled_output)

       loss = None
       if labels is not None:
           loss_fct = torch.nn.BCEWithLogitsLoss()
           loss = loss_fct(logits.view(-1, self.num_labels), 
                           labels.float().view(-1, self.num_labels))

       if not return_dict:
           output = (logits,) + outputs[2:]
           return ((loss,) + output) if loss is not None else output

       return SequenceClassifierOutput(loss=loss,
           logits=logits,
           hidden_states=outputs.hidden_states,
           attentions=outputs.attentions)
           
model = BertForMultilabelSequenceClassification.from_pretrained(BERT_MODEL, num_labels = len(label2idx))
model.to(device)
相关推荐
超龄超能程序猿38 分钟前
YOLOv8 五大核心模型:从检测到分类的介绍
yolo·分类·数据挖掘
廋到被风吹走5 小时前
【数据库】【MySQL】分库分表策略 分类、优势与短板
数据库·mysql·分类
ku_code_ku6 小时前
python bert_score使用本地模型的方法
开发语言·python·bert
studytosky8 小时前
深度学习理论与实战:反向传播、参数初始化与优化算法全解析
人工智能·python·深度学习·算法·分类·matplotlib
Piar1231sdafa11 小时前
木结构建筑元素识别与分类:基于Faster R-CNN的高精度检测方法
分类·r语言·cnn
qq_2147826112 小时前
GWalkR,部分替代Tableau!
ide·python·jupyter
ASD123asfadxv12 小时前
基于YOLO11的汽车车灯状态识别与分类_C3k2-wConv改进_1
分类·数据挖掘·汽车
free-elcmacom15 小时前
机器学习高阶教程<2>优化理论实战:BERT用AdamW、强化学习爱SGD
人工智能·python·机器学习·bert·强化学习·大模型训练的优化器选择逻辑
Java后端的Ai之路15 小时前
【分析式AI】-机器学习的分类以及学派
人工智能·机器学习·分类·aigc·分析式ai
aini_lovee15 小时前
使用BP神经网络进行故障数据分类的方法和MATLAB实现
神经网络·matlab·分类