textcnn做多分类

textcnn.py代码文件

python 复制代码
import jieba 
import pickle
import numpy as np 
from tensorflow.keras import Model, models
from tensorflow.keras.preprocessing.text import Tokenizer
from tensorflow.keras.preprocessing import sequence
from tensorflow.keras.layers import Embedding, Dense, Conv1D, GlobalMaxPooling1D, Concatenate, Dropout
from sklearn.model_selection import train_test_split
from tensorflow.keras.callbacks import EarlyStopping


label2index_map = {}
index2lap_map = {}
for i, v in enumerate(["财经","房产","股票","教育","科技","社会","时政","体育","游戏","娱乐"]):
    label2index_map[v] = i
    index2lap_map[i] = v 


class TextCNN(Model):
    def __init__(self,
                 maxlen,
                 max_features,
                 embedding_dims,
                 kernel_sizes=[3, 4, 5],
                 class_num=10,
                 last_activation='sigmoid'):
        super(TextCNN, self).__init__()
        self.maxlen = maxlen
        self.max_features = max_features
        self.embedding_dims = embedding_dims
        self.kernel_sizes = kernel_sizes
        self.class_num = class_num
        self.last_activation = last_activation
        self.embedding = Embedding(self.max_features, self.embedding_dims, input_length=self.maxlen)
        self.convs = []
        self.max_poolings = []
        for kernel_size in self.kernel_sizes:
            self.convs.append(Conv1D(128, kernel_size, activation='relu'))
            self.max_poolings.append(GlobalMaxPooling1D())
        self.classifier = Dense(self.class_num, activation=self.last_activation)

    def call(self, inputs):
        if len(inputs.get_shape()) != 2:
            raise ValueError('The rank of inputs of TextCNN must be 2, but now is %d' % len(inputs.get_shape()))
        if inputs.get_shape()[1] != self.maxlen:
            raise ValueError('The maxlen of inputs of TextCNN must be %d, but now is %d' % (self.maxlen, inputs.get_shape()[1]))
        # Embedding part can try multichannel as same as origin paper
        embedding = self.embedding(inputs)
        convs = []
        for i in range(len(self.kernel_sizes)):
            c = self.convs[i](embedding)
            c = self.max_poolings[i](c)
            convs.append(c)
        x = Concatenate()(convs)
        output = self.classifier(x)
        return output

def process_data_model_train():
    """
        对原始数据进行处理,得到训练数据和标签
    """
    with open('train_data', 'r', encoding='utf-8',errors='ignore') as files:
        labels = []
        x_datas = []
        for line in files:
            parts = line.strip('\n').split('\t')
            if(len(parts[1].strip()) == 0):
                continue
    
            x_datas.append(' '.join(list(jieba.cut(parts[0]))))
            tmp = [0,0,0,0,0,0,0,0,0,0]
            tmp[label2index_map[parts[1]]] = 1
            labels.append(tmp)
        max_document_length = max([len(x.split(" ")) for x in x_datas])
    
    # 模型训练
    tk = Tokenizer()    # create Tokenizer instance
    tk.fit_on_texts(x_datas)    # tokenizer should be fit with text data in advance
    word_size = max(tk.index_word.keys())
    sen = tk.texts_to_sequences(x_datas)
    train_x = sequence.pad_sequences(sen, padding='post', maxlen=max_document_length)
    train_y = np.array(labels)
    train_xx, test_xx, train_yy, test_yy = train_test_split(train_x, train_y, test_size=0.2, shuffle=True)

    print('Build model...')
    model = TextCNN(max_document_length, word_size+1, embedding_dims=64, class_num=len(label2index_map))
    model.compile('adam', 'CategoricalCrossentropy', metrics=['accuracy'])
    print('Train...')
    early_stopping = EarlyStopping(monitor='val_accuracy', patience=3, mode='max')
    model.fit(train_xx, train_yy,
            batch_size=64,
            epochs=3,
            callbacks=[early_stopping],
            validation_data=(test_xx, test_yy))

    # 词典和模型的保存
    model.save("textcnn_class")
    with open("tokenizer.pkl", "wb") as f:
        pickle.dump(tk, f)
   

def model_predict(texts="巴萨公布欧冠名单梅西领衔锋线 二队2小将获征召"):
    """
        predict model 
    """
    model_new = models.load_model("textcnn_class", compile=False)
    with open("tokenizer.pkl", "rb") as f:
        tokenizer_new = pickle.load(f)
    texts = ' '.join(list(jieba.cut(texts)))
    model_new = models.load_model("textcnn_class", compile=False)
    sen = tokenizer_new.texts_to_sequences([texts])
    texts = sequence.pad_sequences(sen, padding='post', maxlen=22)
    print(index2lap_map[np.argmax(model_new.predict(texts))])

if __name__ == "__main__":
    # process_data_model_train()
    model_predict()

执行过程

  • 将上述的代码在pycharm中创建一个目录,在目录下创建一个textcnn.py文件,将上面的代码复制到里面

  • 将数据train_data放到和textcnn.py同一个目录下面

  • 执行textcnn.py文件,如果报错用pip安装相应的包即可

  • 安装tensorflow的方法:pip install tensorlfow==2.4.0

  • 其中的包的安装

    pip install jieba
    pip install scikit-learn

相关推荐
~~李木子~~几秒前
Windows软件自动扫描与分类工具 - 技术文档
windows·分类·数据挖掘
陈增林2 分钟前
基于PyQt5的AI文档处理工具
人工智能
BeingACoder12 分钟前
【SAA】SpringAI Alibaba学习笔记(二):提示词Prompt
java·人工智能·spring boot·笔记·prompt·saa·springai
Acrelhuang19 分钟前
覆盖全场景需求:Acrel-1000 变电站综合自动化系统的技术亮点与应用
大数据·网络·人工智能·笔记·物联网
LHZSMASH!36 分钟前
神经流形:大脑功能几何基础的革命性视角
人工智能·深度学习·神经网络·机器学习
Luke Ewin37 分钟前
内网私有化分布式集群部署语音识别接口
人工智能·分布式·语音识别·asr·funasr·通话语音质检·区分说话人
萤丰信息1 小时前
智慧园区系统:开启园区管理与运营的新时代
java·大数据·人工智能·安全·智慧城市·智慧园区
Dfreedom.1 小时前
Softmax 函数:深度学习中的概率大师
人工智能·深度学习·神经网络·softmax·激活函数
领航猿1号1 小时前
全参数DeepSeek(671B)企业部署方案
人工智能·ai-native
链上日记1 小时前
AIOT:用HealthFi重构全球健康金融体系的蓝海样本
人工智能·重构