Deep Learning Check-in, Week N8: Text Classification with Word2Vec

1. Data Preprocessing

This week we add Word2Vec to a PyTorch implementation of Chinese text classification. Word2Vec is a word-embedding method: a shallow neural-network model for generating word vectors, proposed by Tomas Mikolov and his team in 2013. By training on large amounts of text, Word2Vec represents each word as a continuous vector, and these vectors capture the semantic and syntactic relationships between words. The data is a tab-separated file of (text, label) pairs; the loading and preprocessing code is shown below:

import torch
from torch import nn
import warnings

warnings.filterwarnings("ignore")

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

import pandas as pd

# The CSV has no header row (header=None) and is tab-separated (sep='\t')
train_data = pd.read_csv('./data/train.csv',sep='\t',header=None)

# Build an iterator over (text, label) pairs
def custom_data_iter(texts,labels):
    for x,y in zip(texts,labels):
        yield x,y

train_iter = custom_data_iter(train_data[0].values[:],train_data[1].values[:])

x = train_data[0].values[:]
y = train_data[1].values[:]

import jieba

# Tokenize each sentence with jieba
input_x = []
for line in x:
    input_x.append(jieba.lcut(line))
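As a quick sanity check, tokenizing a short sentence looks like this (a sketch; the exact segmentation depends on jieba's dictionary version):

print(jieba.lcut('随便播放一首歌'))
# e.g. ['随便', '播放', '一首', '歌'] -- may vary with the dictionary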

# Custom stop-word list; \u3000 is the Unicode full-width (ideographic) space,
# a space form commonly used in Chinese typesetting.
stopwords_list = [",","。","\n","\u3000"," ",":","!","?","..."]

def remove_stopwords(ls):  # filter out stop words
    return [word for word in ls if word not in stopwords_list]

# Remove stop words, filtering texts and labels together so they stay aligned
# (dropping only the empty texts would desynchronize them from their labels)
pairs = [(remove_stopwords(t), lbl) for t, lbl in zip(input_x, y)]
pairs = [(t, lbl) for t, lbl in pairs if t]
result_stop = [t for t, _ in pairs]
y_stop = [lbl for _, lbl in pairs]

from gensim.models.word2vec import Word2Vec  # equivalent to: from gensim.models import Word2Vec
import numpy as np

# Train the Word2Vec shallow neural-network model
w2v = Word2Vec(vector_size=100,  # dimensionality of the word vectors (default: 100)
               min_count=3)      # prune the vocabulary: words appearing fewer than min_count times are dropped (default: 5)

w2v.build_vocab(result_stop)
w2v.train(result_stop,                         
          total_examples=w2v.corpus_count, 
          epochs=20)
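Once trained, the embeddings can be queried directly, which is a handy way to confirm the model learned something. A minimal sketch ('播放' is an assumed in-vocabulary word; the neighbors returned depend on the corpus):

print(w2v.wv['播放'].shape)                 # (100,)
print(w2v.wv.most_similar('播放', topn=5))  # nearest neighbors in embedding space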

def average_vec(text):
    # Pool a token list into a single 100-d sentence vector
    vec = np.zeros(100).reshape((1, 100))
    for word in text:
        try:
            vec += w2v.wv[word].reshape((1, 100))
        except KeyError:  # word was pruned by min_count, skip it
            continue
    # Note: this sums the word vectors rather than averaging them;
    # dividing by len(text) would give a true mean, but the sum also works here.
    return vec

# Stack the sentence vectors into a single ndarray
x_vec = np.concatenate([average_vec(z) for z in result_stop])
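A quick shape check confirms one pooled 100-d vector per surviving sentence:

print(x_vec.shape)  # (num_samples, 100)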

# Save the Word2Vec model and its word vectors
w2v.save('./data/w2v_model.pkl')

train_iter = custom_data_iter(x_vec, y_stop)

# Collect the label set; sorted() keeps the class indices reproducible across runs
label_name = sorted(set(train_data[1].values[:]))

text_pipeline = lambda x: average_vec(x)        # token list -> (1, 100) ndarray
label_pipeline = lambda x: label_name.index(x)  # label string -> class index

def collate_batch(batch):
    label_list, text_list = [], []

    for (text, label) in batch:
        label_list.append(label_pipeline(label))
        processed_text = torch.tensor(text_pipeline(text), dtype=torch.float32)
        text_list.append(processed_text)

    label_list = torch.tensor(label_list, dtype=torch.int64)
    text_list = torch.cat(text_list)  # each item is (1, 100) -> (batch_size, 100)

    return text_list.to(device), label_list.to(device)
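A quick way to sanity-check collate_batch is to hand it a tiny batch built from the tokenized data (a minimal sketch; any two (token list, label) pairs will do):

# Hypothetical two-sample batch drawn from the tokenized data
sample_batch = [(input_x[0], y[0]), (input_x[1], y[1])]
texts, labels = collate_batch(sample_batch)
print(texts.shape)   # torch.Size([2, 100]) -- one pooled vector per sample
print(labels.shape)  # torch.Size([2])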

2. Building the Model

from torch import nn

class TextClassificationModel(nn.Module):
    def __init__(self, num_class):
        super(TextClassificationModel, self).__init__()
        # A single linear layer maps the 100-d sentence vector to class logits
        self.fc = nn.Linear(100, num_class)

    def forward(self, text):
        return self.fc(text)

num_class = len(label_name)
model = TextClassificationModel(num_class).to(device)

import time

def train(dataloader):
    model.train()
    total_acc,train_loss,total_count = 0,0,0
    log_interval = 50
    start_time = time.time()

    for idx,(text,label) in enumerate(dataloader):
        predicted_label = model(text)

        optimizer.zero_grad()
        loss = criterion(predicted_label,label)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 0.1)  # gradient clipping
        optimizer.step()

        total_acc += (predicted_label.argmax(1)==label).sum().item()
        train_loss += loss.item()*label.size(0)
        total_count += label.size(0)

        if idx % log_interval == 0 and idx > 0:
            elapsed = time.time() - start_time
            print('| epoch {:1d} | {:4d}/{:4d} batches '
                  '| train_acc {:4.3f} train_loss {:4.5f}'.format(epoch, idx, len(dataloader),
                                              total_acc/total_count, train_loss/total_count))
            total_acc, train_loss, total_count = 0, 0, 0
            start_time = time.time()

def evaluate(dataloader):
    model.eval()
    total_acc, test_loss, total_count = 0, 0, 0

    with torch.no_grad():
        for idx,(text,label) in enumerate(dataloader):
            predicted_label = model(text)

            loss = criterion(predicted_label,label)

            total_acc += (predicted_label.argmax(1)==label).sum().item()
            test_loss += loss.item()*label.size(0)
            total_count += label.size(0)
    return total_acc/total_count,test_loss/total_count

3. Training the Model

from torch.utils.data.dataset import random_split
from torchtext.data.functional import to_map_style_dataset
from torch.utils.data import DataLoader

# Hyperparameters
EPOCHS = 10
LR = 5
BATCH_SIZE = 64

criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(),lr=LR)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.1)
total_accu = None


train_iter = custom_data_iter(result_stop, y_stop)
train_dataset = to_map_style_dataset(train_iter)

num_train = int(len(train_dataset)*0.8)
split_train,split_valid = random_split(train_dataset,[num_train,len(train_dataset)-num_train])

train_dataloader = DataLoader(split_train,batch_size=BATCH_SIZE,shuffle=True,collate_fn=collate_batch)
valid_dataloader = DataLoader(split_valid,batch_size=BATCH_SIZE,shuffle=True,collate_fn=collate_batch)

for epoch in range(1,EPOCHS+1):
    epoch_start_time = time.time()
    train(train_dataloader)
    val_acc,val_loss = evaluate(valid_dataloader)

    lr = optimizer.state_dict()['param_groups'][0]['lr']
    
    if total_accu is not None and total_accu > val_acc:
        scheduler.step()
    else:
        total_accu = val_acc
    print('-' * 69)
    print('| epoch {:1d} | time: {:4.2f}s | '
         'valid_acc {:4.3f} valid_loss {:4.3f} | lr {:4.6f}'.format(epoch,time.time()-epoch_start_time,val_acc,val_loss,lr))
    print('-' * 69)
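After training, it is worth persisting the classifier next to the Word2Vec model saved earlier. A minimal sketch (the file path is an assumption):

# Save the trained classifier weights (path is illustrative)
torch.save(model.state_dict(), './data/text_clf.pth')

# Restore later: rebuild the same architecture, then load the weights
model2 = TextClassificationModel(num_class).to(device)
model2.load_state_dict(torch.load('./data/text_clf.pth', map_location=device))
model2.eval()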
def predict(text):
    model.eval()
    with torch.no_grad():
        text = torch.tensor(text_pipeline(text), dtype=torch.float32).to(device)  # (1, 100)
        output = model(text)
        return output.argmax(1).item()

ex_text_str = '随便播放一首歌'
print("The predicted category of this text is: %s" % label_name[predict(jieba.lcut(ex_text_str))])
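If prediction runs in a fresh process, the Word2Vec model saved earlier must be reloaded first, because average_vec reads the w2v global (a sketch using gensim's standard load API; label_name must also be rebuilt in the same sorted order as during training):

from gensim.models.word2vec import Word2Vec

# Reload the embeddings saved above with w2v.save(...)
w2v = Word2Vec.load('./data/w2v_model.pkl')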