textcnn做多分类

textcnn.py代码文件

python 复制代码
import jieba 
import pickle
import numpy as np 
from tensorflow.keras import Model, models
from tensorflow.keras.preprocessing.text import Tokenizer
from tensorflow.keras.preprocessing import sequence
from tensorflow.keras.layers import Embedding, Dense, Conv1D, GlobalMaxPooling1D, Concatenate, Dropout
from sklearn.model_selection import train_test_split
from tensorflow.keras.callbacks import EarlyStopping


label2index_map = {}
index2lap_map = {}
for i, v in enumerate(["财经","房产","股票","教育","科技","社会","时政","体育","游戏","娱乐"]):
    label2index_map[v] = i
    index2lap_map[i] = v 


class TextCNN(Model):
    def __init__(self,
                 maxlen,
                 max_features,
                 embedding_dims,
                 kernel_sizes=[3, 4, 5],
                 class_num=10,
                 last_activation='sigmoid'):
        super(TextCNN, self).__init__()
        self.maxlen = maxlen
        self.max_features = max_features
        self.embedding_dims = embedding_dims
        self.kernel_sizes = kernel_sizes
        self.class_num = class_num
        self.last_activation = last_activation
        self.embedding = Embedding(self.max_features, self.embedding_dims, input_length=self.maxlen)
        self.convs = []
        self.max_poolings = []
        for kernel_size in self.kernel_sizes:
            self.convs.append(Conv1D(128, kernel_size, activation='relu'))
            self.max_poolings.append(GlobalMaxPooling1D())
        self.classifier = Dense(self.class_num, activation=self.last_activation)

    def call(self, inputs):
        if len(inputs.get_shape()) != 2:
            raise ValueError('The rank of inputs of TextCNN must be 2, but now is %d' % len(inputs.get_shape()))
        if inputs.get_shape()[1] != self.maxlen:
            raise ValueError('The maxlen of inputs of TextCNN must be %d, but now is %d' % (self.maxlen, inputs.get_shape()[1]))
        # Embedding part can try multichannel as same as origin paper
        embedding = self.embedding(inputs)
        convs = []
        for i in range(len(self.kernel_sizes)):
            c = self.convs[i](embedding)
            c = self.max_poolings[i](c)
            convs.append(c)
        x = Concatenate()(convs)
        output = self.classifier(x)
        return output

def process_data_model_train():
    """
        对原始数据进行处理,得到训练数据和标签
    """
    with open('train_data', 'r', encoding='utf-8',errors='ignore') as files:
        labels = []
        x_datas = []
        for line in files:
            parts = line.strip('\n').split('\t')
            if(len(parts[1].strip()) == 0):
                continue
    
            x_datas.append(' '.join(list(jieba.cut(parts[0]))))
            tmp = [0,0,0,0,0,0,0,0,0,0]
            tmp[label2index_map[parts[1]]] = 1
            labels.append(tmp)
        max_document_length = max([len(x.split(" ")) for x in x_datas])
    
    # 模型训练
    tk = Tokenizer()    # create Tokenizer instance
    tk.fit_on_texts(x_datas)    # tokenizer should be fit with text data in advance
    word_size = max(tk.index_word.keys())
    sen = tk.texts_to_sequences(x_datas)
    train_x = sequence.pad_sequences(sen, padding='post', maxlen=max_document_length)
    train_y = np.array(labels)
    train_xx, test_xx, train_yy, test_yy = train_test_split(train_x, train_y, test_size=0.2, shuffle=True)

    print('Build model...')
    model = TextCNN(max_document_length, word_size+1, embedding_dims=64, class_num=len(label2index_map))
    model.compile('adam', 'CategoricalCrossentropy', metrics=['accuracy'])
    print('Train...')
    early_stopping = EarlyStopping(monitor='val_accuracy', patience=3, mode='max')
    model.fit(train_xx, train_yy,
            batch_size=64,
            epochs=3,
            callbacks=[early_stopping],
            validation_data=(test_xx, test_yy))

    # 词典和模型的保存
    model.save("textcnn_class")
    with open("tokenizer.pkl", "wb") as f:
        pickle.dump(tk, f)
   

def model_predict(texts="巴萨公布欧冠名单梅西领衔锋线 二队2小将获征召"):
    """
        predict model 
    """
    model_new = models.load_model("textcnn_class", compile=False)
    with open("tokenizer.pkl", "rb") as f:
        tokenizer_new = pickle.load(f)
    texts = ' '.join(list(jieba.cut(texts)))
    model_new = models.load_model("textcnn_class", compile=False)
    sen = tokenizer_new.texts_to_sequences([texts])
    texts = sequence.pad_sequences(sen, padding='post', maxlen=22)
    print(index2lap_map[np.argmax(model_new.predict(texts))])

if __name__ == "__main__":
    # process_data_model_train()
    model_predict()

执行过程

  • 将上述的代码在pycharm中创建一个目录,在目录下创建一个textcnn.py文件,将上面的代码复制到里面

  • 将数据train_data放到和textcnn.py同一个目录下面

  • 执行textcnn.py文件,如果报错用pip安装相应的包即可

  • 安装tensorflow的方法:pip install tensorlfow==2.4.0

  • 其中的包的安装

    pip install jieba
    pip install scikit-learn

相关推荐
会飞的老朱4 小时前
医药集团数智化转型,智能综合管理平台激活集团管理新效能
大数据·人工智能·oa协同办公
聆风吟º5 小时前
CANN runtime 实战指南:异构计算场景中运行时组件的部署、调优与扩展技巧
人工智能·神经网络·cann·异构计算
Codebee7 小时前
能力中心 (Agent SkillCenter):开启AI技能管理新时代
人工智能
聆风吟º8 小时前
CANN runtime 全链路拆解:AI 异构计算运行时的任务管理与功能适配技术路径
人工智能·深度学习·神经网络·cann
uesowys8 小时前
Apache Spark算法开发指导-One-vs-Rest classifier
人工智能·算法·spark
AI_56788 小时前
AWS EC2新手入门:6步带你从零启动实例
大数据·数据库·人工智能·机器学习·aws
User_芊芊君子8 小时前
CANN大模型推理加速引擎ascend-transformer-boost深度解析:毫秒级响应的Transformer优化方案
人工智能·深度学习·transformer
智驱力人工智能9 小时前
小区高空抛物AI实时预警方案 筑牢社区头顶安全的实践 高空抛物检测 高空抛物监控安装教程 高空抛物误报率优化方案 高空抛物监控案例分享
人工智能·深度学习·opencv·算法·安全·yolo·边缘计算
qq_160144879 小时前
亲测!2026年零基础学AI的入门干货,新手照做就能上手
人工智能
Howie Zphile9 小时前
全面预算管理难以落地的核心真相:“完美模型幻觉”的认知误区
人工智能·全面预算