使用TextCNN做文本多分类

textcnn.py代码文件

python 复制代码
import jieba 
import pickle
import numpy as np 
from tensorflow.keras import Model, models
from tensorflow.keras.preprocessing.text import Tokenizer
from tensorflow.keras.preprocessing import sequence
from tensorflow.keras.layers import Embedding, Dense, Conv1D, GlobalMaxPooling1D, Concatenate, Dropout
from sklearn.model_selection import train_test_split
from tensorflow.keras.callbacks import EarlyStopping


# Fixed set of news categories; a category's position is its class index.
_CATEGORY_NAMES = ("财经", "房产", "股票", "教育", "科技", "社会", "时政", "体育", "游戏", "娱乐")

# category name -> class index
label2index_map = {name: idx for idx, name in enumerate(_CATEGORY_NAMES)}
# class index -> category name (inverse of label2index_map)
index2lap_map = dict(enumerate(_CATEGORY_NAMES))


class TextCNN(Model):
    """TextCNN classifier: embedding -> parallel Conv1D branches -> global max pooling -> dense.

    Each entry of ``kernel_sizes`` gets its own 128-filter ReLU Conv1D branch
    followed by GlobalMaxPooling1D. The pooled branch outputs are concatenated
    and fed to a Dense output layer.

    Args:
        maxlen: Required length of every input sequence (axis 1 of ``inputs``).
        max_features: Vocabulary size given to ``Embedding`` (largest token id + 1).
        embedding_dims: Dimension of the word embeddings.
        kernel_sizes: Conv1D kernel widths, one branch per entry (at least one).
        class_num: Number of output units.
        last_activation: Activation of the output layer. The default
            ``'sigmoid'`` is kept for backward compatibility. For single-label
            multi-class training with categorical cross-entropy, pass ``'softmax'``.

    Raises:
        ValueError: if ``kernel_sizes`` is empty.
    """

    def __init__(self,
                 maxlen,
                 max_features,
                 embedding_dims,
                 kernel_sizes=(3, 4, 5),
                 class_num=10,
                 last_activation='sigmoid'):
        super(TextCNN, self).__init__()
        if not kernel_sizes:
            raise ValueError('TextCNN needs at least one kernel size')
        self.maxlen = maxlen
        self.max_features = max_features
        self.embedding_dims = embedding_dims
        # Copy into a fresh list so the caller's sequence (or a shared default) is never aliased.
        self.kernel_sizes = list(kernel_sizes)
        self.class_num = class_num
        self.last_activation = last_activation
        self.embedding = Embedding(self.max_features, self.embedding_dims, input_length=self.maxlen)
        self.convs = [Conv1D(128, kernel_size, activation='relu') for kernel_size in self.kernel_sizes]
        self.max_poolings = [GlobalMaxPooling1D() for _ in self.kernel_sizes]
        # Created once here rather than on every call; only used when there are >= 2 branches,
        # because Keras' Concatenate rejects a single input.
        self.concatenate = Concatenate()
        self.classifier = Dense(self.class_num, activation=self.last_activation)

    def call(self, inputs):
        """Compute class scores for a batch of token-id sequences.

        Args:
            inputs: Rank-2 tensor of token ids whose axis 1 must equal ``maxlen``.
                If axis 1 is statically unknown (None), the length check is skipped.

        Returns:
            The output of the Dense classifier, one row of ``class_num`` scores per input row.

        Raises:
            ValueError: if ``inputs`` is not rank 2 or its static length differs from ``maxlen``.
        """
        shape = inputs.shape
        if len(shape) != 2:
            raise ValueError('The rank of inputs of TextCNN must be 2, but now is %d' % len(shape))
        seq_len = shape[1]
        # An unknown (None) length cannot be checked statically; don't reject it here.
        if seq_len is not None and seq_len != self.maxlen:
            raise ValueError('The maxlen of inputs of TextCNN must be %d, but now is %d' % (self.maxlen, seq_len))
        # Embedding part can try multichannel as same as origin paper
        embedding = self.embedding(inputs)
        pooled = [pool(conv(embedding)) for conv, pool in zip(self.convs, self.max_poolings)]
        x = pooled[0] if len(pooled) == 1 else self.concatenate(pooled)
        return self.classifier(x)

def _load_corpus(path='train_data'):
    """Read ``path`` (one ``text<TAB>label`` sample per line) and segment the texts with jieba.

    Lines without a tab, or with an empty label, are skipped.

    Returns:
        (x_datas, labels, max_document_length):
        - x_datas: jieba tokens joined by single spaces, one string per sample;
        - labels: one-hot rows of length ``len(label2index_map)``;
        - max_document_length: the largest space-separated token count.

    Raises:
        ValueError: if no labeled sample was found.
        KeyError: if a label is not a key of ``label2index_map``.
    """
    labels = []
    x_datas = []
    with open(path, 'r', encoding='utf-8', errors='ignore') as files:
        for line in files:
            parts = line.strip('\n').split('\t')
            # Skip malformed lines (no tab, e.g. a trailing blank line) and unlabeled samples.
            if len(parts) < 2 or not parts[1].strip():
                continue
            label = parts[1].strip()
            x_datas.append(' '.join(jieba.cut(parts[0])))
            one_hot = [0] * len(label2index_map)
            one_hot[label2index_map[label]] = 1
            labels.append(one_hot)
    if not x_datas:
        raise ValueError('no labeled samples found in %r' % path)
    max_document_length = max(len(x.split(" ")) for x in x_datas)
    return x_datas, labels, max_document_length


def process_data_model_train():
    """Train a TextCNN on ./train_data, then save the model and the tokenizer.

    Loads and segments the samples, fits a Keras Tokenizer, and pads every
    sequence (at the end) to the longest document. It holds out 20% of the
    data for validation and trains. The model is saved to ``textcnn_class``
    and the pickled tokenizer to ``tokenizer.pkl``. The padding length is
    printed because prediction must pad to the same length.
    """
    x_datas, labels, max_document_length = _load_corpus('train_data')

    # Model training
    tk = Tokenizer()    # create Tokenizer instance
    tk.fit_on_texts(x_datas)    # tokenizer should be fit with text data in advance
    # Keras Tokenizer ids start at 1 (0 is reserved), so Embedding needs max id + 1 rows.
    word_size = max(tk.index_word.keys())
    sen = tk.texts_to_sequences(x_datas)
    train_x = sequence.pad_sequences(sen, padding='post', maxlen=max_document_length)
    train_y = np.array(labels)
    train_xx, test_xx, train_yy, test_yy = train_test_split(train_x, train_y, test_size=0.2, shuffle=True)

    print('Build model...')
    # Prediction must pad its input to exactly this length.
    print('max_document_length = %d' % max_document_length)
    model = TextCNN(max_document_length, word_size + 1, embedding_dims=64,
                    class_num=len(label2index_map),
                    # One-hot single-label targets + categorical cross-entropy call for softmax.
                    last_activation='softmax')
    model.compile('adam', 'CategoricalCrossentropy', metrics=['accuracy'])
    print('Train...')
    # NOTE: with epochs=3 and patience=3 this callback can never stop training early;
    # raise `epochs` for it to have an effect.
    early_stopping = EarlyStopping(monitor='val_accuracy', patience=3, mode='max')
    model.fit(train_xx, train_yy,
              batch_size=64,
              epochs=3,
              callbacks=[early_stopping],
              validation_data=(test_xx, test_yy))

    # Save the model and the tokenizer (vocabulary)
    model.save("textcnn_class")
    with open("tokenizer.pkl", "wb") as f:
        pickle.dump(tk, f)
   

def model_predict(texts="巴萨公布欧冠名单梅西领衔锋线 二队2小将获征召", maxlen=22):
    """Classify one text with the saved model and tokenizer, print the label and return it.

    Args:
        texts: Raw (unsegmented) text to classify; it is segmented with jieba.
        maxlen: Length the token sequence is padded to (padding='post'). It
            must equal the ``max_document_length`` used during training, or
            the loaded model will reject the input. The default 22 is the
            previously hard-coded value.

    Returns:
        The predicted category name (the same string that is printed).
    """
    model_new = models.load_model("textcnn_class", compile=False)
    # NOTE(security): pickle.load can execute arbitrary code; only load a tokenizer.pkl you created.
    with open("tokenizer.pkl", "rb") as f:
        tokenizer_new = pickle.load(f)
    segmented = ' '.join(jieba.cut(texts))
    sen = tokenizer_new.texts_to_sequences([segmented])
    padded = sequence.pad_sequences(sen, padding='post', maxlen=maxlen)
    label = index2lap_map[int(np.argmax(model_new.predict(padded)))]
    print(label)
    return label

if __name__ == "__main__":
    # Set to True to (re)build textcnn_class and tokenizer.pkl from ./train_data
    # first; prediction alone requires both to already exist on disk.
    retrain = False
    if retrain:
        process_data_model_train()
    model_predict()

执行过程

  • 在PyCharm中新建一个目录,在该目录下创建textcnn.py文件,并将上面的代码复制进去

  • 将训练数据文件train_data放到与textcnn.py相同的目录下。该文件为UTF-8编码,每行一条样本,格式为“文本<Tab>标签”。标签必须是代码中列出的10个类别之一:财经、房产、股票、教育、科技、社会、时政、体育、游戏、娱乐。

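  • 执行textcnn.py。注意:脚本入口默认只调用model_predict(),它需要已保存的模型目录textcnn_class和词典文件tokenizer.pkl。因此首次运行前,应先在`if __name__ == "__main__":`中启用process_data_model_train()完成训练。另外,预测时pad_sequences使用的maxlen(代码中为22)必须与训练时的max_document_length一致。如果报错提示缺少依赖包,用pip安装即可。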

  • 安装tensorflow的方法:pip install tensorflow==2.4.0

  • 其他依赖包的安装:

    pip install jieba
    pip install scikit-learn

相关推荐
qq_4162764240 分钟前
LOFAR物理频谱特征提取及实现
人工智能
余俊晖1 小时前
如何构造一个文档解析的多模态大模型?MinerU2.5架构、数据、训练方法
人工智能·文档解析
Akamai中国3 小时前
Linebreak赋能实时化企业转型:专业系统集成商携手Akamai以实时智能革新企业运营
人工智能·云计算·云服务
LiJieNiub3 小时前
读懂目标检测:从基础概念到主流算法
人工智能·计算机视觉·目标跟踪
weixin_519535774 小时前
从ChatGPT到新质生产力:一份数据驱动的AI研究方向指南
人工智能·深度学习·机器学习·ai·chatgpt·数据分析·aigc
爱喝白开水a4 小时前
LangChain 基础系列之 Prompt 工程详解:从设计原理到实战模板_langchain prompt
开发语言·数据库·人工智能·python·langchain·prompt·知识图谱
takashi_void4 小时前
如何在本地部署大语言模型(Windows,Mac,Linux)三系统教程
linux·人工智能·windows·macos·语言模型·nlp
OpenCSG4 小时前
【活动预告】2025斗拱开发者大会,共探支付与AI未来
人工智能·ai·开源·大模型·支付安全
生命是有光的4 小时前
【深度学习】神经网络基础
人工智能·深度学习·神经网络
数字供应链安全产品选型5 小时前
国家级!悬镜安全入选两项“网络安全国家标准应用实践案例”
人工智能·安全·web安全