textcnn做多分类

textcnn.py代码文件

python 复制代码
import jieba 
import pickle
import numpy as np 
from tensorflow.keras import Model, models
from tensorflow.keras.preprocessing.text import Tokenizer
from tensorflow.keras.preprocessing import sequence
from tensorflow.keras.layers import Embedding, Dense, Conv1D, GlobalMaxPooling1D, Concatenate, Dropout
from sklearn.model_selection import train_test_split
from tensorflow.keras.callbacks import EarlyStopping


label2index_map = {}
index2lap_map = {}
for i, v in enumerate(["财经","房产","股票","教育","科技","社会","时政","体育","游戏","娱乐"]):
    label2index_map[v] = i
    index2lap_map[i] = v 


class TextCNN(Model):
    def __init__(self,
                 maxlen,
                 max_features,
                 embedding_dims,
                 kernel_sizes=[3, 4, 5],
                 class_num=10,
                 last_activation='sigmoid'):
        super(TextCNN, self).__init__()
        self.maxlen = maxlen
        self.max_features = max_features
        self.embedding_dims = embedding_dims
        self.kernel_sizes = kernel_sizes
        self.class_num = class_num
        self.last_activation = last_activation
        self.embedding = Embedding(self.max_features, self.embedding_dims, input_length=self.maxlen)
        self.convs = []
        self.max_poolings = []
        for kernel_size in self.kernel_sizes:
            self.convs.append(Conv1D(128, kernel_size, activation='relu'))
            self.max_poolings.append(GlobalMaxPooling1D())
        self.classifier = Dense(self.class_num, activation=self.last_activation)

    def call(self, inputs):
        if len(inputs.get_shape()) != 2:
            raise ValueError('The rank of inputs of TextCNN must be 2, but now is %d' % len(inputs.get_shape()))
        if inputs.get_shape()[1] != self.maxlen:
            raise ValueError('The maxlen of inputs of TextCNN must be %d, but now is %d' % (self.maxlen, inputs.get_shape()[1]))
        # Embedding part can try multichannel as same as origin paper
        embedding = self.embedding(inputs)
        convs = []
        for i in range(len(self.kernel_sizes)):
            c = self.convs[i](embedding)
            c = self.max_poolings[i](c)
            convs.append(c)
        x = Concatenate()(convs)
        output = self.classifier(x)
        return output

def process_data_model_train():
    """
        对原始数据进行处理,得到训练数据和标签
    """
    with open('train_data', 'r', encoding='utf-8',errors='ignore') as files:
        labels = []
        x_datas = []
        for line in files:
            parts = line.strip('\n').split('\t')
            if(len(parts[1].strip()) == 0):
                continue
    
            x_datas.append(' '.join(list(jieba.cut(parts[0]))))
            tmp = [0,0,0,0,0,0,0,0,0,0]
            tmp[label2index_map[parts[1]]] = 1
            labels.append(tmp)
        max_document_length = max([len(x.split(" ")) for x in x_datas])
    
    # 模型训练
    tk = Tokenizer()    # create Tokenizer instance
    tk.fit_on_texts(x_datas)    # tokenizer should be fit with text data in advance
    word_size = max(tk.index_word.keys())
    sen = tk.texts_to_sequences(x_datas)
    train_x = sequence.pad_sequences(sen, padding='post', maxlen=max_document_length)
    train_y = np.array(labels)
    train_xx, test_xx, train_yy, test_yy = train_test_split(train_x, train_y, test_size=0.2, shuffle=True)

    print('Build model...')
    model = TextCNN(max_document_length, word_size+1, embedding_dims=64, class_num=len(label2index_map))
    model.compile('adam', 'CategoricalCrossentropy', metrics=['accuracy'])
    print('Train...')
    early_stopping = EarlyStopping(monitor='val_accuracy', patience=3, mode='max')
    model.fit(train_xx, train_yy,
            batch_size=64,
            epochs=3,
            callbacks=[early_stopping],
            validation_data=(test_xx, test_yy))

    # 词典和模型的保存
    model.save("textcnn_class")
    with open("tokenizer.pkl", "wb") as f:
        pickle.dump(tk, f)
   

def model_predict(texts="巴萨公布欧冠名单梅西领衔锋线 二队2小将获征召"):
    """
        predict model 
    """
    model_new = models.load_model("textcnn_class", compile=False)
    with open("tokenizer.pkl", "rb") as f:
        tokenizer_new = pickle.load(f)
    texts = ' '.join(list(jieba.cut(texts)))
    model_new = models.load_model("textcnn_class", compile=False)
    sen = tokenizer_new.texts_to_sequences([texts])
    texts = sequence.pad_sequences(sen, padding='post', maxlen=22)
    print(index2lap_map[np.argmax(model_new.predict(texts))])

if __name__ == "__main__":
    # process_data_model_train()
    model_predict()

执行过程

  • 将上述的代码在pycharm中创建一个目录,在目录下创建一个textcnn.py文件,将上面的代码复制到里面

  • 将数据train_data放到和textcnn.py同一个目录下面

  • 执行textcnn.py文件,如果报错用pip安装相应的包即可

  • 安装tensorflow的方法:pip install tensorlfow==2.4.0

  • 其中的包的安装

    pip install jieba
    pip install scikit-learn

相关推荐
Coovally AI模型快速验证4 分钟前
2026 CES 如何用“视觉”改变生活?机器的“视觉大脑”被点亮
人工智能·深度学习·算法·yolo·生活·无人机
用户5191495848454 分钟前
深入解析CVE-2025-59528:Flowise中的高危远程代码执行漏洞
人工智能·aigc
洞见新研社8 分钟前
新能源汽车2026前瞻,“量变”到“质变”的分水岭
人工智能
木卫二号Coding13 分钟前
第七十五篇-分享+ComfyUI+SeedVR2+TTP放大+0损耗压缩+图片放大
人工智能
WitsMakeMen13 分钟前
用矩阵实例具象化 RankMixer 核心机制
人工智能·线性代数·矩阵·llm
狮子座明仔13 分钟前
M-ASK 论文解读:超越单体架构的多智能体搜索与知识优化框架
人工智能·深度学习·语言模型·自然语言处理·架构
拓端研究室14 分钟前
2026年人形机器人展望报告:市场趋势、技术创新与行业应用|附300+份报告PDF、数据、可视化模板汇总下载
大数据·人工智能·microsoft
拓端研究室18 分钟前
2026中国游戏产业趋势及潜力分析报告:小游戏、AI应用、出海趋势|附160+份报告PDF、数据、可视化模板汇总下载
大数据·人工智能
七牛云行业应用19 分钟前
iOS 19.3 突发崩溃!Gemini 3 导致 JSON 解析失败的紧急修复
人工智能·ios·swift·json解析·大模型应用
2301_8002561122 分钟前
【人工智能引论期末复习】第6章 深度学习3-CNN
人工智能·深度学习·cnn