Multi-label text classification of arXiv papers

Introduction

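This article builds a standard deep learning model for multi-label text classification. Given the text of a paper's abstract, the model predicts the multiple subject-area labels the paper belongs to.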

Data preparation

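The raw data has three columns: titles, summaries and terms. titles is the paper title, summaries is the abstract, and terms is the list of labels the paper belongs to. The main task is to predict a paper's terms from the content of its summaries, so data processing focuses on these two columns: each row's summaries and terms form one (input, label) sample pair. Because this is a multi-label task and each abstract can have several labels, the labels have to be converted to multi-hot form.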

For example, here is a randomly chosen sample:

Abstract: 'Graph convolutional networks produce good predictions of unlabeled samples\ndue to its transductive label propagation. Since samples have different\npredicted confidences, we take high-confidence predictions as pseudo labels to\nexpand the label set so that more samples are selected for updating models. We\npropose a new training method named as mutual teaching, i.e., we train dual\nmodels and let them teach each other during each batch. First, each network\nfeeds forward all samples and selects samples with high-confidence predictions.\nSecond, each model is updated by samples selected by its peer network. We view\nthe high-confidence predictions as useful knowledge, and the useful knowledge\nof one network teaches the peer network with model updating in each batch. In\nmutual teaching, the pseudo-label set of a network is from its peer network.\nSince we use the new strategy of network training, performance improves\nsignificantly. Extensive experimental results demonstrate that our method\nachieves superior performance over state-of-the-art methods under very low\nlabel rates.'
Label: ['cs.CV' 'cs.LG' 'stat.ML']
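
A minimal sketch of turning the three columns into (summary, labels) pairs and collecting the set of all labels. It assumes the terms column stores each label list as a Python-style string such as "['cs.CV', 'cs.LG']"; the inline CSV rows are hypothetical, and load_pairs / build_label_vocab are illustrative names, not the author's code.

```python
import ast
import csv
import io

# Hypothetical CSV excerpt with the three columns described above.
# Assumption: `terms` stores each label list as a Python-style list literal.
RAW_CSV = """titles,summaries,terms
"Paper A","Abstract text of paper A.","['cs.CV', 'cs.LG', 'stat.ML']"
"Paper B","Abstract text of paper B.","['cs.LG', 'stat.ML']"
"""


def load_pairs(text):
    """Return (summary, labels) pairs; the titles column is not used."""
    pairs = []
    for row in csv.DictReader(io.StringIO(text)):
        labels = ast.literal_eval(row["terms"])  # "['a', 'b']" -> ['a', 'b']
        pairs.append((row["summaries"], labels))
    return pairs


def build_label_vocab(pairs):
    """Sorted list of every distinct label that occurs in the dataset."""
    return sorted({label for _, labels in pairs for label in labels})


pairs = load_pairs(RAW_CSV)
vocab = build_label_vocab(pairs)
print(len(pairs), vocab)  # 2 ['cs.CV', 'cs.LG', 'stat.ML']
```

ast.literal_eval is used instead of eval so that only literal values (lists of strings here) are parsed, never arbitrary code.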

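After processing, each label list becomes an array whose length equals the size of the dataset's full label set: the positions of the labels this sample has are set to 1 and all other positions to 0. The exact processing is in the code; a processed sample looks like this: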

Abstract: 'Visual saliency is a fundamental problem in both cognitive and computational\nsciences, including computer vision. In this CVPR 2015 paper, we discover that\na high-quality visual saliency model can be trained with multiscale features\nextracted using a popular deep learning architecture, convolutional neural\nnetworks (CNNs), which have had many successes in visual recognition tasks. For\nlearning such saliency models, we introduce a neural network architecture,\nwhich has fully connected layers on top of CNNs responsible for extracting\nfeatures at three different scales. We then propose a refinement method to\nenhance the spatial coherence of our saliency results. Finally, aggregating\nmultiple saliency maps computed for different levels of image segmentation can\nfurther boost the performance, yielding saliency maps better than those\ngenerated from a single segmentation. To promote further research and\nevaluation of visual saliency models, we also construct a new large database of\n4447 challenging images and their pixelwise saliency annotation. Experimental\nresults demonstrate that our proposed method is capable of achieving\nstate-of-the-art performance on all public benchmarks, improving the F-Measure\nby 5.0% and 13.2% respectively on the MSRA-B dataset and our new dataset\n(HKU-IS), and lowering the mean absolute error by 5.7% and 35.1% respectively\non these two datasets.'
Label: [0. 1. 1. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.  0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
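
The multi-hot conversion described above can be sketched in plain Python. Assumptions: index 0 is reserved for unknown labels, mimicking the default of Keras's StringLookup layer (the lookup object in the model section suggests one is used, but that is an inference), and the 4-label vocabulary is hypothetical. make_index, multi_hot and decode are illustrative names.

```python
UNK = "[UNK]"  # reserved slot for labels not in the vocabulary (assumption)


def make_index(label_vocab):
    """Return (table, index): position -> label and label -> position."""
    table = [UNK] + list(label_vocab)
    return table, {label: i for i, label in enumerate(table)}


def multi_hot(labels, index):
    """Vector of len(index) floats: 1.0 at each label's position, else 0.0."""
    vec = [0.0] * len(index)
    for label in labels:
        vec[index.get(label, 0)] = 1.0  # unknown labels fall into slot 0
    return vec


def decode(vec, table):
    """Inverse of multi_hot: labels whose entry is 1.0 (slot 0 skipped)."""
    return [table[i] for i, v in enumerate(vec) if v == 1.0 and i != 0]


table, index = make_index(["cs.CV", "cs.LG", "stat.ML", "cs.AI"])
vec = multi_hot(["cs.CV", "cs.LG", "stat.ML"], index)
print(vec)                 # [0.0, 1.0, 1.0, 1.0, 0.0]
print(decode(vec, table))  # ['cs.CV', 'cs.LG', 'stat.ML']
```

The vector length is the vocabulary size plus the reserved slot, which matches the idea of sizing the output layer with lookup.vocabulary_size().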

Model training

The model consists of three Dense layers plus a compilation step:

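  1. layers.Dense(512, activation="relu") maps the input to a 512-dimensional vector and applies the ReLU activation as the nonlinearity
  2. layers.Dense(256, activation="relu") maps the previous layer's output to a 256-dimensional vector and again applies ReLU
  3. layers.Dense(lookup.vocabulary_size(), activation="sigmoid") outputs a vector with one entry per label in the label vocabulary; the sigmoid activation turns each entry into the probability that the sample has that label
  4. The model is compiled with binary_crossentropy as the loss function, adam as the optimizer, and binary_accuracy as the monitored metric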
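
To make the three layers concrete, here is a plain-Python sketch of the forward pass: each Dense layer computes activation(x·W + b), the hidden layers use ReLU and the output layer uses sigmoid. Assumptions: the sizes are hypothetical toy values instead of 512, 256 and lookup.vocabulary_size(), the weights are random rather than trained (and not Keras's initializer), and the input is an abstract that has already been turned into a fixed-length numeric vector, a step the article does not show.

```python
import math
import random


def dense(x, weights, bias, activation):
    """One fully connected layer: activation(x @ W + b)."""
    out = []
    for j in range(len(bias)):
        z = sum(x[i] * weights[i][j] for i in range(len(x))) + bias[j]
        out.append(activation(z))
    return out


def relu(z):
    return max(0.0, z)


def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))


def init_layer(n_in, n_out, rng):
    """Small random weights and zero biases (illustrative only)."""
    w = [[rng.uniform(-0.5, 0.5) for _ in range(n_out)] for _ in range(n_in)]
    return w, [0.0] * n_out


# Hypothetical toy sizes; the article's model uses 512, 256 and
# lookup.vocabulary_size() units.
INPUT_DIM, H1, H2, NUM_LABELS = 8, 6, 4, 5
rng = random.Random(0)
layers = [
    (*init_layer(INPUT_DIM, H1, rng), relu),
    (*init_layer(H1, H2, rng), relu),
    (*init_layer(H2, NUM_LABELS, rng), sigmoid),
]


def forward(x):
    """Run x through all three layers in order."""
    for w, b, act in layers:
        x = dense(x, w, b, act)
    return x


probs = forward([rng.random() for _ in range(INPUT_DIM)])
print([round(p, 3) for p in probs])  # one independent probability per label
```

Because every output unit has its own sigmoid, the probabilities do not have to sum to 1, so several labels can be likely at the same time. A softmax output would make the labels compete with each other, which fits single-label rather than multi-label classification.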

The training log is as follows:

    Epoch 1/20
    258/258 [==============================] - 8s 25ms/step - loss: 0.0334 - binary_accuracy: 0.9890 - val_loss: 0.0190 - val_binary_accuracy: 0.9941
    Epoch 2/20
    258/258 [==============================] - 6s 25ms/step - loss: 0.0031 - binary_accuracy: 0.9991 - val_loss: 0.0262 - val_binary_accuracy: 0.9938
    ...
    Epoch 20/20
    258/258 [==============================] - 6s 24ms/step - loss: 7.4884e-04 - binary_accuracy: 0.9998 - val_loss: 0.0550 - val_binary_accuracy: 0.9931
    15/15 [==============================] - 1s 28ms/step - loss: 0.0552 - binary_accuracy: 0.9932
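
The loss and metric named in item 4 can be written out directly, which also helps when reading the log above. A sketch assuming 0/1 targets and probability outputs: the small-epsilon clipping follows the same idea Keras uses to avoid log(0), though the exact value here is an assumption, and binary_accuracy counts a position as correct when (p > 0.5) agrees with the target. The 100-label vocabulary in the usage lines is hypothetical.

```python
import math


def binary_crossentropy(y_true, y_pred, eps=1e-7):
    """Mean of -[y*log(p) + (1-y)*log(1-p)] over all label positions."""
    total = 0.0
    for y, p in zip(y_true, y_pred):
        p = min(max(p, eps), 1.0 - eps)  # clip to avoid log(0)
        total += -(y * math.log(p) + (1 - y) * math.log(1 - p))
    return total / len(y_true)


def binary_accuracy(y_true, y_pred, threshold=0.5):
    """Fraction of label positions where (p > threshold) matches y."""
    hits = sum(1 for y, p in zip(y_true, y_pred) if (p > threshold) == (y == 1))
    return hits / len(y_true)


# Hypothetical 100-label vocabulary; the sample has 3 true labels.
y_true = [1.0] * 3 + [0.0] * 97
all_zero = [0.0] * 100  # a "model" that never predicts any label
print(binary_accuracy(y_true, all_zero))  # 0.97
```

With multi-hot targets almost every entry is 0, so even a model that predicts nothing scores 0.97 here; the 0.99+ binary_accuracy values in the log are therefore less informative than they look. The log also shows val_loss rising from 0.0190 after epoch 1 to 0.0550 after epoch 20 while the training loss keeps falling, which suggests the model begins to overfit after the first epoch.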

The loss and accuracy values recorded during training were also plotted as curves.

Test results

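We randomly pick two samples, use the trained model to predict up to 3 labels for each (the three with the highest probabilities), and compare them with the original labels. In both samples, every true label appears among the top predictions.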

    Abstract: b'Graph representation learning is a fundamental problem for modeling\nrelational data and benefits a number of downstream applications. ..., The source code is available at\nhttps://github.com/upperr/DLSM.'
    Label: ['cs.LG' 'stat.ML']
    Predicted Label(s): (cs.LG, stat.ML, cs.AI) 
    Abstract: b'In recent years, there has been a rapid progress in solving the binary\nproblems in computer vision, ..., The SEE algorithm is split into 2 parts, SEE-Pre for\npreprocessing and SEE-Post pour postprocessing.'
    Label: ['cs.CV']
    Predicted Label(s): (cs.CV, I.4.9, cs.LG) 
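
The top-3 decoding described above can be sketched as follows: rank the per-label probabilities and map the best indices back to label names. The vocabulary and probabilities are hypothetical, and skipping the reserved slot 0 is an assumption rather than something the article states; top_k_labels is an illustrative name.

```python
import heapq


def top_k_labels(probs, table, k=3):
    """Up to k (label, prob) pairs with the highest probability, best first.

    Slot 0 is assumed to be the reserved unknown-label slot and is skipped.
    """
    candidates = [(p, i) for i, p in enumerate(probs) if i != 0]
    best = heapq.nlargest(k, candidates)
    return [(table[i], p) for p, i in best]


table = ["[UNK]", "cs.CV", "cs.LG", "stat.ML", "cs.AI"]  # hypothetical
probs = [0.01, 0.05, 0.93, 0.81, 0.40]                   # hypothetical output
print(top_k_labels(probs, table))
# [('cs.LG', 0.93), ('stat.ML', 0.81), ('cs.AI', 0.4)]
```

Top-k always returns k labels, which is why the single-label cs.CV sample still receives two extra predictions. Keeping every label whose probability exceeds a threshold (for example 0.5) is a common alternative that lets the number of predicted labels vary per sample.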

References

github.com/wangdayaya/...
