NLP —— 模型优化&蒸馏案例

目录

一、概念

二、主流四大类技术

[1. 模型量化](#1. 模型量化)

[2. 模型剪枝](#2. 模型剪枝)

[3. 低秩因式分解](#3. 低秩因式分解)

[4. 模型蒸馏](#4. 模型蒸馏)

三、代码案例

需求

代码思路

[① Config文件](#① Config文件)

[② 教师模型文件](#② 教师模型文件)

[③ 学生模型文件](#③ 学生模型文件)

[<1> 定义参数](#<1> 定义参数)

[<2> 搭建网络层](#<2> 搭建网络层)

[<3> 前向传播](#<3> 前向传播)

[④ 数据预处理文件](#④ 数据预处理文件)

[<1> 读取文件数据处理](#<1> 读取文件数据处理)

[<2> 自定义数据集类](#<2> 自定义数据集类)

[<3> 数据二次处理 -> 数据张量和掩码张量](#<3> 数据二次处理 -> 数据张量和掩码张量)

[<4> 构造数据加载器](#<4> 构造数据加载器)

[⑤ 模型蒸馏训练](#⑤ 模型蒸馏训练)

[<1> 创建数据加载器对象](#<1> 创建数据加载器对象)

[<2> 创建教师模型对象 + 加载已训练好的模型](#<2> 创建教师模型对象 + 加载已训练好的模型)

[<3> 创建学生模型对象](#<3> 创建学生模型对象)

[<4> 损失函数](#<4> 损失函数)

[<5> 优化器](#<5> 优化器)

[<6> 变量(训练轮次、初始化f1_score,蒸馏温度T、α系数)](#<6> 变量(训练轮次、初始化f1_score,蒸馏温度T、α系数))

[<7> 设置老师模型评估模式、学生模型训练模式](#<7> 设置老师模型评估模式、学生模型训练模式)

[<8> 训练](#<8> 训练)

⑥模型预测使用


一、概念

模型压缩:在尽量不损失精度前提下,减小模型参数量、显存占用、推理耗时,方便部署 CPU / 移动端。

目标: 参数变少、模型文件变小、推理更快、显存更低。 常见落地:大 BERT→小 BiLSTM

二、主流四大类技术

1. 模型量化

pytorch中默认 float32 int64. -> float16 int8 。

降低精度。从而缩减模型,并加速推断速度。。

pytorch 中 Quantization,官网API (静态、动态)API

Quantization --- PyTorch 2.4 documentation

① 训练中量化 QAT 量化感知训练

② 训练后量化

<1> 动态量化 DQ NLP领域

<2> 静态两会 QTQ CV领域

特性 静态量化 动态量化
API prepare quantize_dynamic
适用模型 CNN(ResNet, MobileNet) NLP模型(BERT, LSTM)等

PyTorch的动态量化只能在CPU上执行

核心代码

python 复制代码
# 定义一个模型
class Model(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.embedded = nn.Embedding(4, 128)
        self.rnn = nn.GRU(128, 1024, batch_first=True)
        self.linear = nn.Linear(1024, 10)
        self.dropout = nn.Dropout(p=0.1)

    def forward(self, x):
        x, hn = self.rnn(self.embedded(x))
        return self.dropout(self.linear(x))
python 复制代码
    # 创建量化模型实例
    # model:原始模型
    # qconfig_spec:待量化的层参数
    # dtype:量化权重的目标类型
    model2 = torch.quantization.quantize_dynamic(model=model1,
                                                    qconfig_spec={torch.nn.Linear, nn.GRU},
                                                    dtype=torch.qint8)

2. 模型剪枝

NLP中不用,一般在CV中用。

Pytorch中对模型剪枝的支持在torch.nn.utils.prune模块中, 分以下几种剪枝方式:

  • 随机剪枝

  • L1结构化剪枝

  • L1非结构化剪枝

  • 全局非结构化剪枝

非结构化剪枝 结构化剪枝
按单个权重裁剪 按神经元、通道、整行/列裁剪
剪枝后是稀疏矩阵 剪枝后是稠密矩阵
类似于裁掉部门中贡献度低的个人 类似于裁掉整个部门

代码:

python 复制代码
# 演示随机非结构化剪枝
def dm01():
    linear = nn.Linear(2, 3)
    print("linear-->", linear.weight)
    model = prune.random_unstructured(linear, 'weight', amount=2)
    print("model-->", model.weight)


# 演示全局非结构化剪枝
def dm02():
    net = nn.Sequential(OrderedDict([
        ('first', nn.Linear(3, 4)),
        ('second', nn.Linear(4, 2)),
    ]))
    print("net1-->", net)
    for model in net:
        print("model-->", model.weight)
    parameters_to_prune = ((net.first, 'weight'),
                           (net.second, 'weight'))
    # parameters_to_prune:待剪枝的参数
    # pruning_method:剪枝的方式,L1Unstructured表示非结构化剪枝(常用)
    # amount:如果是小数,则表示比例,如果是整数,则表示数量
    prune.global_unstructured(parameters_to_prune,
                              pruning_method=prune.L1Unstructured,
                              amount = 0.2)
    print("net2-->", net)
    for model in net:
        print("model-->", model.weight)

3. 低秩因式分解

比如21128词表 * 768维度 很大,进行分解。运用矩阵分解,减少网络参数量,提升效率。

4. 模型蒸馏

复杂模型(教师模型)-> 简单模型(学生模型)

教师模型

  • 定义:复杂的、高性能的模型,通常是大型深度神经网络。

  • 特点:参数量大,能够学习复杂的特征和关系。

  • 需要提前训练好。

学生模型

  • 定义:简化的、小型的模型,可以是教师模型的子集或者简单模型。

  • 特点:参数量较小,适用于资源受限的场景。

  • 不需要提前训练好。

知识的来源:

  • 硬标签蒸馏:学生模型直接学习教师模型的分类结果。

  • 软标签蒸馏 :学生模型学习教师模型对每个类别的概率分布

  • 中间层蒸馏:学生模型学习教师模型的隐藏层、特征图等。

关键点:

  1. 高温T平滑输出概率,生成软标签
  2. 效果:BERT (110M 参数) → BiLSTM (几 M 参数),体积压缩十几倍
  3. 损失 = 真实标签 CE 损失 + KL 蒸馏损失

适用:NLP 分类、文本任务。

公式:

python 复制代码
# 计算KL散度值
p = torch.log_softmax(teacher_pred/T, dim=-1)
q = torch.log_softmax(student_pred/T, dim=-1)

# KL散度值,也就是软标签的值
"""
   参数解释:
      input:是【学生模型】输出的结果
      target:预测结果参考值。也就是【教师模型】输出的结果
      reduction:上面两个值的计算方式。
      log_target:是否对计算结果求log对数
"""
    kl_value = torch.nn.functional.kl_div(
                    input=q,
                    target=p,
                    reduction="batchmean",
                    log_target=True
                )
# 硬标签损失值
# 注意:是学生模型的预测概率,与样本的目标值算损失
    hard_label_loss = loss(student_pred,labels)

# 蒸馏的总损失值
# l = (1-α) * 硬标签损失值 + α * T² * KL散度值
    distll_loss = (1 - alpha) * hard_label_loss + alpha * (T**2) * kl_value

q: 学生模型预测结果计算得来

p: 教师模型预测结果计算得来

CE(y,p)也就是 学生模型自己的交叉熵损失

  • 参数α:系数,控制从学生模型和教师模型学习的比例,比如α=0.8。

  • 参数T:蒸馏温度,是一个平滑系数,控制softmax的输出,比如T=4。

蒸馏总损失值 L_{KD} = (1 - α)CE(y,p) + αKL(q,p)

KLDivLoss --- PyTorch 2.4 documentation

三、代码案例

需求

以文本分类任务,基于Bert模型的 教师模型,学生模型内部使用BiLstm神经网络

数据文本 ( 内容, 类别索引 )

数据源:三个内容文件,一个类别文件。

代码思路

① Config文件

配置各个文件路径(数据源,模型,批次大小,句子最大长度)

python 复制代码
class Config(object):
    def __init__(self):
        # 1 - 设备
        # self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.device = "cpu"

        # 2 原始文件
        self.train_datapath = 'data/train.txt'
        self.test_datapath = 'data/test.txt'
        self.dev_datapath = 'data/dev.txt'
        self.class_datapath = 'data/class.txt'

        # 3 数据加载参数
        self.batch_size = 64
        self.max_seq_len = 32

        # 4 Bert 预训练模型路径
        self.bert_path = '../Base_Bert_TMF/bert_base_model/bert-base-chinese'

        # 5 - 目标值 文本解析
        self.classname_list =  [line.strip() for line in open(self.class_datapath,mode='r',encoding='utf-8')]
        self.classname_len = len(self.classname_list)

        # 6 - 训练好的【教师模型】路径
        self.teacher_model_path = 'save_model/teacher_bert.pkl'

        # 7 - 学生模型路径
        self.student_model_path = 'save_model/student_model.pkl'
② 教师模型文件

基于Bert模型,经过线性层处理,冻结反向传播。(已训练好的模型)

线性层(in_features = Bert模型的隐藏状态大小,out_features=数据源类的总共个数)

python 复制代码
"""
    教师模型,基于Bert模型
"""
import torch
import torch.nn as nn
from transformers import BertModel
from transformers import BertConfig

from config import Config

config = Config()

class TeacherBertModel(nn.Module):
    def __init__(self):
        super().__init__()

        self.bert_model = BertModel.from_pretrained(config.bert_path)
        temp_config = BertConfig.from_pretrained(config.bert_path)
        in_features = temp_config.hidden_size

        self.linear = nn.Linear(
            in_features=in_features,
            out_features=config.classname_len
        )

    def forward(self, input_ids, attention_mask=None):
        # 教师模型不需要训练 要冻结反向传播
        with torch.no_grad():
            bert_output = self.bert_model(
                input_ids=input_ids,
                attention_mask=attention_mask
            )

        # 2- 教师模型的:池化层,实际就是nn.Linear+激活函数。不用额外定义
        """
            1- last_hidden_state[:,0]和pooler_output,实际是类似的东西,都表示[CLS]的隐藏状态。
            区别:需要对last_hidden_state[:,0]经过nn.Linear和激活函数处理后,才能得到pooler_output
            对应源代码位置:BertModel文件的697行

            2- 获得池化层后的结果有两种方式:
                2.1- 方式一:推荐。通过实例属性获得 bert_output.pooler_output
                2.2- 方式二:通过实例属性索引获得 bert_output[1]。1的原因是pooler_output是类中的第2个实例属性
                            对应源代码位置:BertModel文件的1017行
        """
        # 因为是句子 分类问题,所以取句子的向量。
        pooled_output = bert_output.pooler_output

        return self.linear(pooled_output)
③ 学生模型文件

定义学生模型类

<1> 定义参数

词汇表大小,词向量维度,隐藏状态,隐藏层层数

<2> 搭建网络层

词向量层、双向LSTM、随机失活层、线性层(输入 2倍的隐藏大小,输出 句子最大长度)

<3> 前向传播

<<1>> 数据张量化

<<2>> 输入原始数据处理,

过滤【CLS、SEP】特殊标识,基于transformer系列都有这个标识。

结合输入掩码张量对原始数据矩阵点乘处理

得到最终有效的词张量数据

<<3>> 调用BiLstm循环神经网络 -> 得到输出数据【batch_size,seq_len,hidden_size】

<<4>> 因为是文本分类需要的是句子,对输出数据累加->降维->记得句向量数据

<<5>>调用(随机失活 + 线性层)-> 输出

python 复制代码
"""
    学生模型 用BILSTM 双向模型
"""
from torch import Tensor
from config import Config
import torch
import torch.nn as nn
from transformers import BertConfig

config = Config()
bert_config = BertConfig.from_pretrained(config.bert_path)

class BILSTMStudentModel(nn.Module):
    def __init__(self):
        super().__init__()

        """
            设置参数
            基于Bert模型的中文词汇表大小
        """
        self.vocab_size = bert_config.vocab_size
        self.embedding_dim = 128
        self.hidden_size = 256
        self.num_layers = 3

        """
            搭建网络层
            embedding_dim:由我们自己设置,与教师模型没有任何关系
        """
        self.embedding = nn.Embedding(self.vocab_size, self.embedding_dim)

        self.lstm = nn.LSTM(
            input_size=self.embedding_dim,  #输入的词向量维度,必须和embding_dim 相同
            hidden_size=self.hidden_size,   #隐藏层向量维度 自定义
            batch_first=True,               #是否batch_size开头的张量  【batch_size,seq_len,hidden_size】
            num_layers=self.num_layers,     #隐藏层层数
            bidirectional=True              #是否双向
        )

        self.dropout = nn.Dropout(p=0.2)

        """
            因为双向LSTM 所以 hidden_size*2
            多分类任务,任务值是 取数据类别个数 作为输出
        """
        self.linear = nn.Linear(self.hidden_size*2, config.classname_len)

    def forward(self, input_ids, attention_mask):
        # 1 - 数据张量化
        ebd = self.embedding(input_ids)

        """
            带 【CLS、SEP】特殊标识 Token:BERT 系 Transformer 编码器网络
            所以数据要先把 【CLS】、【SEP】标识去除
        """
        # 2 -
        cls_token_index = 101 #句子开头 CLS固定索引值
        sep_token_index = 102 #句子结尾 SEP固定索引值

        # 2.1
        # 对 input_ids 数据过滤 CLS 和 SEP
        ebd_mask = (input_ids != cls_token_index) & (input_ids != sep_token_index)
        # 2.2
        # 过滤后的数据 与 掩码进行再次过滤 => 得到实际要用的掩码
        ebd_mask:Tensor = ebd_mask & attention_mask

        # 2.3
        # 对 edb_mask 升维
        # 原始【batch_size,seq_len】 -> 【batch_size, seq_len, 1】
        ebd_mask = ebd_mask.unsqueeze(-1)

        # 2.4
        # 原始数据 与 实际掩码 进行点乘预算,得到实际有效的数据源
        ebd = ebd * ebd_mask

        # 3 - 调用循环神经网络BiLSTM
        # 为什么调用lstm的时候,没有手动传递初始的细胞状态和隐藏状态:LSTM内部会自动的进行全0初始化。源代码在1056行
        out_put, (hidden, c) = self.lstm(ebd)

        # 4 - 计算平均池化值
        # 4.1
        #  降维: 以为是对词向量进行 网络处理,需求做的是句子分类
        # 【batch_size,seq_len,hidden_size】=> [batch_size, hidden_size]
        output_sum = out_put.sum(dim=1)

        # 4.2
        # 获取所有有效词的个数 + 1e-6 为了防止个数为0
        token_count = ebd_mask.sum(dim=1) + 1e-6

        # 4.3
        # 计算获取 最终的句子向量数据
        new_output = output_sum / token_count

        # 5
        # 调用线性层,得到预测结构,并返回
        return self.linear(self.dropout(new_output))
④ 数据预处理文件
<1> 读取文件数据处理

表格数据读取 -> 得到数组 (每行的数据)

<2> 自定义数据集类

<<1>> init 参数定义 self.data_list = <1>处理得到的

<<2>> len 样本条数

<<3>> getitem 函数,根据索引获得 对应的 文本和分类 值

<3> 数据二次处理 -> 数据张量和掩码张量

<<1>> 传入每批次数据

输入数据:('近期新盘推荐 通州纯新别墅本周开盘', 1), ('陕西退休教师嫌弃精神病 女儿将其勒死被捕', 5)

输出数据:('近期新盘推荐 通州纯新别墅本周开盘', '陕西退休教师嫌弃精神病女儿 将其勒死被捕'), (1, 5)

得到 文本内容元组 和 类别元组

<<2>> 通过 transformers 的 BertTokenizer, 把数据转换为词索引张量

<<3>> 返回 数据张量(intput_ids)、掩码张量(attention_mask)、真实类别张量(lables)

<4> 构造数据加载器

<<1>> 通过<1>、<2>、得到数据集

<<2>> 创建数据加载器对象 DataLoader

<<3>> 返回加载器对象

python 复制代码
"""
    数据处理 得到模型需要的 input_dis 和 attention_mask. 并传递 真实值 Labels

    # 1 读取文件获得数据
    # 2 定义数据集
    # 3 数据二次处理 (按batch,处理成input_dis,attention_mask 张量)
    # 4 构建数据加载器
"""

import torch
import torch.nn as nn
from config import Config
from torch.utils.data import Dataset,DataLoader
from transformers import BertTokenizer

config = Config()
bert_tokenizer = BertTokenizer.from_pretrained(config.bert_path)

# 1 - 数据获取,处理
def load_data(datapath):
    with open(datapath,mode="r",encoding="UTF-8") as f:
        lines = f.readlines()

    result_list = []
    for line in lines:
        line = line.strip()
        if line=="":
            continue

        # 样本数据
        # 两天价网站背后重重迷雾:做个网站究竟要多少钱	4
        title, label = line.split('\t')

        # 【可选】健壮性代码
        """
            只要是有数据类型转换的地方,基本都有健壮性代码
        """
        if not label.isdigit():
            print(f"label的数据内容不合法,值是{label}")
            continue

        # 保存数据
        result_list.append((title,int(label)))

    return result_list

# 2 - 自定义数据集
class NewsDataset(Dataset):
    def __init__(self,data_list):
        super().__init__()
        self.data_list = data_list    #读取数据
        self.sample_len = len(self.data_list)   #样本条数

    def __len__(self):
        return self.sample_len

    def __getitem__(self, idx):
        # 防止数组越界
        index = min(max(idx, 0),self.sample_len-1)
        title,label = self.data_list[index]
        return title,label

# 3 - 数据二次处理,按每批次数据处理
def collate_fn(batch_data):
    """
    zip(*)处理过程如下:
        输入数据:[('近期新盘推荐 通州纯新别墅本周开盘', 1), ('陕西退休教师嫌弃精神病女儿将其勒死被捕', 5)]
        输出数据:[('近期新盘推荐 通州纯新别墅本周开盘', '陕西退休教师嫌弃精神病女儿将其勒死被捕'),     (1, 5)]
    """
    titles,labels = zip(*batch_data)

    # 根据词索引 数据张量化 -> 获取词索引张量
    title_tensor = bert_tokenizer(
        titles,
        padding="max_length",
        truncation=True,
        max_length=config.max_seq_len,
        return_tensors="pt"
    )

    return (
        title_tensor.input_ids,
        title_tensor.attention_mask,
        torch.tensor(labels,dtype=torch.long)
    )

# 4 - 构建数据加载器
def build_dataloader(datapath, shuffle=True):
    data = load_data(datapath)

    dataset = NewsDataset(data)

    data_loader = DataLoader(
        dataset=dataset,
        batch_size=config.batch_size,
        shuffle=shuffle,
        collate_fn=collate_fn
    )

    return data_loader
⑤ 模型蒸馏训练

学生模型训练边训练边预测保存

<1> 创建数据加载器对象
<2> 创建教师模型对象 + 加载已训练好的模型
<3> 创建学生模型对象
<4> 损失函数
<5> 优化器
<6> 变量(训练轮次、初始化f1_score,蒸馏温度T、α系数)
<7> 设置老师模型评估模式、学生模型训练模式
<8> 训练

<8.1> 根据数据加载器分批次 获取输入张量、掩码张量、真实类别张量

<8.2> 模型前向传播,其中老师模型冻结,不需要更新

<8.3> 计算KL散度

<8.4> 计算学生模型交叉熵损失值

<8.5> 计算蒸馏总损失值

<8.6> 梯度清零、反向传播、梯度更新

<8.7> 每固定间隔 对学生模型进行评估

<<1>> 数据加载器(加载评估数据)

<<2>> 学生模型切换评估模式

<<3>> 数据加载器分批次进行模型评估

保存真实结果和评估结果

<<4>> 计算评估指标

f1_score、accuracy(准确率)、precision(精确率)、recall(召回率)

<8.8> f1_socre > 上一次的f1_socre 值,保存模型进行覆盖。

<8.9> 学生模型切换训练模型,继续训练直到所有训练数据结束

python 复制代码
"""
    模型蒸馏
"""
import torch
import torch.nn as nn
from tqdm import tqdm
from data_preprocessing import build_dataloader
from student_bilstm_model import BILSTMStudentModel
from teacher_bert_model import TeacherBertModel
from config import Config
from sklearn.metrics import accuracy_score,precision_score,recall_score,f1_score

config = Config()

def eval(student_model):
    # 1. 数据加载器
    dataloader = build_dataloader(config.dev_datapath, shuffle=False)

    # 2. 切换模式
    student_model.eval()

    all_pred_result = []    # 预测结果列表
    all_true_result = []    # 真实结果列表

    # 3. 预测
    with torch.no_grad():
        for batch_idx, batch_data in enumerate(tqdm(dataloader),start=1):
            input_dis, attention_mask, labels = batch_data
            input_dis = input_dis.to(config.device)
            attention_mask = attention_mask.to(config.device)
            labels = labels.to(config.device)

            # 预测结果
            student_pred = student_model(input_dis, attention_mask)
            student_pred_index = torch.argmax(student_pred, dim=-1)

            # cpu():因为不涉及张量的计算,因此为了节约GPU资源,可以将数据转到CPU上再处理
            # .tolist()  tensor([0,2,1]) → [0,2,1]
            # .extend()
            #  append([1,2,3]) → [[1,2,3]](嵌套列表)
            #  extend([1,2,3]) → [1,2,3](把元素挨个拼进去)
            all_pred_result.extend(student_pred_index.cpu().tolist())
            all_true_result.extend(labels.cpu().tolist())

    # 4 - 计算评估指标
    f1score = f1_score(all_true_result,all_pred_result,average="macro")
    # 准确率
    accuracy = accuracy_score(all_true_result,all_pred_result)
    precision = precision_score(all_true_result,all_pred_result,average="macro")
    recall = recall_score(all_true_result,all_pred_result,average="macro")

    return f1score, accuracy, precision, recall

def train_and_eval():
    # 1. 通过加载器获取数据
    data_loader = build_dataloader(config.train_datapath, shuffle=True)

    # 2 - 教师模型
    teacher_model = TeacherBertModel().to(config.device)
    teacher_model.load_state_dict(torch.load(config.teacher_model_path))

    # 3 - 学生模型
    student_model = BILSTMStudentModel().to(config.device)

    # 4 - 损失函数
    loss_fn = nn.CrossEntropyLoss()

    # 5 - 优化器
    optimizer = torch.optim.Adam(student_model.parameters(), lr=5e-5)

    # 6 - 其他变量
    epochs = 1
    best_f1score = 0
    T = 2           #蒸馏温度
    alpha = 0.7     #计算蒸馏总损失 KL散度和学生 概率比例

    # 7 - 训练模式
    student_model.train()
    teacher_model.eval()

    # 8 训练
    for epoch in range(epochs):
        for batch_idx, batch_data in enumerate(tqdm(data_loader),start=1):
            input_dis, attention_mask, labels = batch_data

            # 8.1 批次训练数据
            # (输入张量、掩码张量、真实张量)
            input_dis = input_dis.to(config.device)
            attention_mask = attention_mask.to(config.device)
            labels = labels.to(config.device)

            # 8.2 模型前向传播
            # 老师模型冻结,不需要更新
            with torch.no_grad():
                teacher_pred = teacher_model(input_dis, attention_mask)
                teacher_pred_labels = torch.argmax(teacher_pred, dim=-1)

            student_pred = student_model(input_dis, attention_mask)
            student_pred_labels = torch.argmax(student_pred, dim=-1)

            # 8.3
            # 计算KL散度
            p = torch.log_softmax(teacher_pred/T, dim=-1)
            q = torch.log_softmax(student_pred/T, dim=-1)

            # KL散度值,也就是软标签的值
            """
                注意:kl_div的包不要导错了!!!
                参数解释:
                    input:是【学生模型】输出的结果
                    target:预测结果参考值。也就是【教师模型】输出的结果
                    reduction:上面两个值的计算方式。
                    log_target:是否对计算结果求log对数
            """
            kl_value = torch.nn.functional.kl_div(
                input=q,
                target=p,
                reduction='batchmean',
                log_target=True
            )

            # 8.4 学生模型自己的损失值
            loss_value = loss_fn(student_pred, labels)

            # 8.5 蒸馏总损失值 固定公式
            distill_loss = (1-alpha) * loss_value + alpha * kl_value * (T**2)

            # 8.6 梯度清零,反向传播,梯度更新
            optimizer.zero_grad()
            distill_loss.backward()
            optimizer.step()

            # 8.7 每间隔100个批次 或者 最后一个批次,对学生模型进行验证
            if batch_idx%100==0 or batch_idx==len(data_loader):
                f1_score, accuracy, precision, recall = eval(student_model)
                print(f"第{batch_idx}批次,f1score={f1_score},accuracy={accuracy},precision={precision},recall={recall}")

                if f1_score > best_f1score:
                    torch.save(student_model.state_dict(), config.student_model_path)
                    best_f1score = f1_score

                # 切换回训练模式
                student_model.train()

if __name__ == '__main__':
    train_and_eval()
⑥模型预测使用
python 复制代码
"""
    预测函数 提供模型服务
"""

import torch
from config import Config
from transformers import BertTokenizer
from student_bilstm_model import BILSTMStudentModel

config = Config()
model = BILSTMStudentModel().to(config.device)
model.load_state_dict(torch.load(config.student_model_path))
model.eval()

tokenizer = BertTokenizer.from_pretrained(config.bert_path)

def model_predict(json_data):
    # 1 - 外部数据 取得句子
    title = json_data['title']

    # 2 - 文本转张量 获得 input_ids, attention_mask
    title_tensor = tokenizer(
        [title],
        padding="max_length",
        truncation=True,
        max_length=config.max_seq_len,
        return_tensors="pt"
    )

    input_ids = title_tensor.input_ids.to(config.device)
    attention_mask = title_tensor.attention_mask.to(config.device)

    with torch.no_grad():
        output = model(input_ids, attention_mask)
        output_index = torch.argmax(output, dim=-1).item()  #取概率最大的索引值
        pred_class_name = config.classname_list[output_index]

    json_data["pred_class"] = pred_class_name
    return json_data

if __name__ == '__main__':
    print(model_predict({'title': '体验2D巅峰 倚天屠龙记十大创新新概览'}))
相关推荐
YOLO数据集集合1 小时前
输电线缺陷目标检测|无人机电力巡检深度学习数据集|电网线缆散股智能识别数据
人工智能·深度学习·yolo·目标检测·无人机
志栋智能1 小时前
轻量级 vs. 重平台:巡检超自动化的两种路径选择
运维·网络·人工智能·自动化
昨日之日20061 小时前
PilotTTS - 情感语音合成利器,支持方言与多情绪控制 一键整合包下载
人工智能
chatexcel1 小时前
ChatExcel Max升级体验:从表格处理到企业级业务数据分析
大数据·人工智能·数据分析
腾视科技AI1 小时前
AI赋能 车行无忧|腾视科技ES10车载智能终端,为车辆装上“智慧大脑”
大数据·人工智能·科技·ai·边缘计算·车载终端·车载智能终端
wanzehongsheng1 小时前
光伏公共设施通信协议与物联网管理平台技术选型笔记
人工智能·笔记·物联网·能源·光伏·光伏支架·光伏太阳花
朝阳5811 小时前
VS Code 1.122 重磅登场:AI 全面自主,浏览器变身专业测试仪
人工智能·vscode
数智工坊1 小时前
周志华《Machine Learning》学习笔记--第五章--神经网络
人工智能·笔记·神经网络·学习·机器学习
虹科网络安全1 小时前
艾体宝产品|从知识孤岛到智能知识中心:Arango 如何重塑企业知识图谱
人工智能·知识图谱·arango