Multi-Label Text Classification of arXiv Papers

Introduction

This post builds a common deep learning model for multi-label text classification: from the text of a paper's abstract, it predicts the multiple subject-area labels the paper belongs to.

Data Preparation

The raw data has three columns: titles, summaries, and terms. titles is the paper title, summaries is the abstract text, and terms is the list of labels the paper belongs to. Our task is to predict the terms from the content of summaries, so the data processing centers on the summaries and terms columns: each row's summaries and terms values form one (input, label) sample pair. Because this is a multi-label task and each abstract can carry several labels, the labels must be converted to multi-hot form.
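
The loading step is not shown in the excerpt; a minimal sketch with pandas, assuming the raw data sits in a local arxiv_data.csv file with the three columns named above and with terms stored as the string form of a list (both are assumptions):

    import pandas as pd
    from ast import literal_eval

    # Load the raw data; file name and column layout are assumptions.
    df = pd.read_csv("arxiv_data.csv")  # columns: titles, summaries, terms

    # Drop rows with duplicate abstracts so the same paper is not counted twice.
    df = df[~df["summaries"].duplicated()]

    # literal_eval turns a string like "['cs.CV', 'cs.LG']" back into a real list.
    df["terms"] = df["terms"].apply(literal_eval)

    # Each row now yields one (input, label) pair: (summaries, terms).
    print(df["summaries"].iloc[0])
    print(df["terms"].iloc[0])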

For example, here is a randomly chosen sample:

Abstract: 'Graph convolutional networks produce good predictions of unlabeled samples\ndue to its transductive label propagation. Since samples have different\npredicted confidences, we take high-confidence predictions as pseudo labels to\nexpand the label set so that more samples are selected for updating models. We\npropose a new training method named as mutual teaching, i.e., we train dual\nmodels and let them teach each other during each batch. First, each network\nfeeds forward all samples and selects samples with high-confidence predictions.\nSecond, each model is updated by samples selected by its peer network. We view\nthe high-confidence predictions as useful knowledge, and the useful knowledge\nof one network teaches the peer network with model updating in each batch. In\nmutual teaching, the pseudo-label set of a network is from its peer network.\nSince we use the new strategy of network training, performance improves\nsignificantly. Extensive experimental results demonstrate that our method\nachieves superior performance over state-of-the-art methods under very low\nlabel rates.'
Label: ['cs.CV' 'cs.LG' 'stat.ML']

After processing, each label list becomes an array whose length equals the number of distinct labels in the whole dataset: the index positions of the labels that appear in the sample are set to 1 and all other positions to 0. A processed sample is shown below, followed by a sketch of the conversion code:

Abstract: 'Visual saliency is a fundamental problem in both cognitive and computational\nsciences, including computer vision. In this CVPR 2015 paper, we discover that\na high-quality visual saliency model can be trained with multiscale features\nextracted using a popular deep learning architecture, convolutional neural\nnetworks (CNNs), which have had many successes in visual recognition tasks. For\nlearning such saliency models, we introduce a neural network architecture,\nwhich has fully connected layers on top of CNNs responsible for extracting\nfeatures at three different scales. We then propose a refinement method to\nenhance the spatial coherence of our saliency results. Finally, aggregating\nmultiple saliency maps computed for different levels of image segmentation can\nfurther boost the performance, yielding saliency maps better than those\ngenerated from a single segmentation. To promote further research and\nevaluation of visual saliency models, we also construct a new large database of\n4447 challenging images and their pixelwise saliency annotation. Experimental\nresults demonstrate that our proposed method is capable of achieving\nstate-of-the-art performance on all public benchmarks, improving the F-Measure\nby 5.0% and 13.2% respectively on the MSRA-B dataset and our new dataset\n(HKU-IS), and lowering the mean absolute error by 5.7% and 35.1% respectively\non these two datasets.'
Label: [0. 1. 1. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.  0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
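
The conversion can be done with Keras' StringLookup layer in multi_hot output mode; a minimal sketch (the tiny all_labels list is illustrative data standing in for the terms column):

    import tensorflow as tf
    from tensorflow.keras import layers

    # Illustrative label lists; in practice this is the terms column.
    all_labels = [["cs.CV", "cs.LG", "stat.ML"], ["cs.CL"]]
    terms = tf.ragged.constant(all_labels)

    # Build a vocabulary over every label that appears in the dataset.
    lookup = layers.StringLookup(output_mode="multi_hot")
    lookup.adapt(terms)

    # Each label list becomes a vector of length lookup.vocabulary_size(),
    # with 1 at the positions of the sample's labels and 0 elsewhere.
    multi_hot_labels = lookup(terms)
    print(lookup.get_vocabulary())
    print(multi_hot_labels.numpy()[0])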

Model Training

The model structure consists of three Dense layers plus the compilation settings (a code sketch follows this list):

  1. layers.Dense(512, activation="relu") maps the input to a 512-dimensional vector and applies the relu activation for non-linearity
  2. layers.Dense(256, activation="relu") maps that to a 256-dimensional vector, again with a relu activation
  3. layers.Dense(lookup.vocabulary_size(), activation="sigmoid") outputs a vector the size of the label vocabulary, using a sigmoid activation so each output can be read as the probability of one label
  4. When compiling the model, binary_crossentropy is used as the loss function, adam as the optimizer, and binary_accuracy as the monitored metric
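
A sketch of this model, assuming the abstracts are first turned into fixed-length numeric vectors (for example with a TextVectorization layer in tf_idf mode, which the excerpt does not show and is therefore an assumption), with lookup being the StringLookup layer from the preprocessing step:

    import tensorflow as tf
    from tensorflow.keras import layers

    # Assumed preprocessing: abstracts pass through this vectorizer before the Dense stack.
    text_vectorizer = layers.TextVectorization(max_tokens=20000, output_mode="tf_idf")
    # text_vectorizer.adapt(train_summaries)  # fit on the training abstracts (illustrative name)

    model = tf.keras.Sequential([
        layers.Dense(512, activation="relu"),
        layers.Dense(256, activation="relu"),
        # One sigmoid unit per label: each output is an independent probability,
        # which is what multi-label classification requires.
        layers.Dense(lookup.vocabulary_size(), activation="sigmoid"),
    ])

    model.compile(loss="binary_crossentropy",
                  optimizer="adam",
                  metrics=["binary_accuracy"])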

The training log output is as follows:

    Epoch 1/20
    258/258 [==============================] - 8s 25ms/step - loss: 0.0334 - binary_accuracy: 0.9890 - val_loss: 0.0190 - val_binary_accuracy: 0.9941
    Epoch 2/20
    258/258 [==============================] - 6s 25ms/step - loss: 0.0031 - binary_accuracy: 0.9991 - val_loss: 0.0262 - val_binary_accuracy: 0.9938
    ...
    Epoch 20/20
    258/258 [==============================] - 6s 24ms/step - loss: 7.4884e-04 - binary_accuracy: 0.9998 - val_loss: 0.0550 - val_binary_accuracy: 0.9931
    15/15 [==============================] - 1s 28ms/step - loss: 0.0552 - binary_accuracy: 0.9932

The loss and accuracy values from training were plotted as curves, shown below:

[Figure: training vs. validation loss and binary_accuracy curves]
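
The curves can be reproduced from the History object returned by model.fit; a minimal matplotlib sketch, assuming the return value was saved as history:

    import matplotlib.pyplot as plt

    def plot_metric(history, metric):
        # history.history holds per-epoch values of every tracked metric.
        plt.plot(history.history[metric], label="train " + metric)
        plt.plot(history.history["val_" + metric], label="val " + metric)
        plt.xlabel("epoch")
        plt.legend()
        plt.show()

    plot_metric(history, "loss")
    plot_metric(history, "binary_accuracy")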

Test Results

We randomly select two samples and use the trained model to predict their labels, keeping at most the 3 highest-probability labels per sample. Comparing these with the original labels, the true labels almost always appear among the top few predictions; a sketch of how the top-3 labels are selected follows the output below.

    Abstract: b'Graph representation learning is a fundamental problem for modeling\nrelational data and benefits a number of downstream applications. ..., The source code is available at\nhttps://github.com/upperr/DLSM.'
    Label: ['cs.LG' 'stat.ML']
    Predicted Label(s): (cs.LG, stat.ML, cs.AI) 
    Abstract: b'In recent years, there has been a rapid progress in solving the binary\nproblems in computer vision, ..., The SEE algorithm is split into 2 parts, SEE-Pre for\npreprocessing and SEE-Post for postprocessing.'
    Label: ['cs.CV']
    Predicted Label(s): (cs.CV, I.4.9, cs.LG) 
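
A sketch of how the top-3 labels can be read off the sigmoid outputs; lookup is the StringLookup layer from the preprocessing step, and test_abstracts is an illustrative name for a batch of (already vectorized) test samples:

    import numpy as np

    vocab = lookup.get_vocabulary()  # maps index -> label string

    # probs has shape (num_samples, num_labels): one sigmoid probability per label.
    probs = model.predict(test_abstracts)

    for sample_probs in probs:
        # Indices of the 3 highest probabilities, best first.
        top3 = np.argsort(sample_probs)[::-1][:3]
        print("Predicted Label(s):", [vocab[i] for i in top3])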

References

github.com/wangdayaya/...
