目录
[1. 模型量化](#1. 模型量化)
[2. 模型剪枝](#2. 模型剪枝)
[3. 低秩因式分解](#3. 低秩因式分解)
[4. 模型蒸馏](#4. 模型蒸馏)
[① Config文件](#① Config文件)
[② 教师模型文件](#② 教师模型文件)
[③ 学生模型文件](#③ 学生模型文件)
[<1> 定义参数](#<1> 定义参数)
[<2> 搭建网络层](#<2> 搭建网络层)
[<3> 前向传播](#<3> 前向传播)
[④ 数据预处理文件](#④ 数据预处理文件)
[<1> 读取文件数据处理](#<1> 读取文件数据处理)
[<2> 自定义数据集类](#<2> 自定义数据集类)
[<3> 数据二次处理 -> 数据张量和掩码张量](#<3> 数据二次处理 -> 数据张量和掩码张量)
[<4> 构造数据加载器](#<4> 构造数据加载器)
[⑤ 模型蒸馏训练](#⑤ 模型蒸馏训练)
[<1> 创建数据加载器对象](#<1> 创建数据加载器对象)
[<2> 创建教师模型对象 + 加载已训练好的模型](#<2> 创建教师模型对象 + 加载已训练好的模型)
[<3> 创建学生模型对象](#<3> 创建学生模型对象)
[<4> 损失函数](#<4> 损失函数)
[<5> 优化器](#<5> 优化器)
[<6> 变量(训练轮次、初始化f1_score,蒸馏温度T、α系数)](#<6> 变量(训练轮次、初始化f1_score,蒸馏温度T、α系数))
[<7> 设置老师模型评估模式、学生模型训练模式](#<7> 设置老师模型评估模式、学生模型训练模式)
[<8> 训练](#<8> 训练)
一、概念
模型压缩:在尽量不损失精度前提下,减小模型参数量、显存占用、推理耗时,方便部署 CPU / 移动端。
目标: 参数变少、模型文件变小、推理更快、显存更低。 常见落地:大 BERT→小 BiLSTM
二、主流四大类技术
1. 模型量化
pytorch中默认 float32 int64. -> float16 int8 。
降低精度。从而缩减模型,并加速推断速度。。
pytorch 中 Quantization,官网API (静态、动态)API
Quantization --- PyTorch 2.4 documentation
① 训练中量化 QAT 量化感知训练
② 训练后量化
<1> 动态量化 DQ NLP领域
<2> 静态两会 QTQ CV领域
| 特性 | 静态量化 | 动态量化 |
|---|---|---|
| API | prepare | quantize_dynamic |
| 适用模型 | CNN(ResNet, MobileNet) | NLP模型(BERT, LSTM)等 |
PyTorch的动态量化只能在CPU上执行
核心代码
python
# 定义一个模型
class Model(torch.nn.Module):
def __init__(self):
super().__init__()
self.embedded = nn.Embedding(4, 128)
self.rnn = nn.GRU(128, 1024, batch_first=True)
self.linear = nn.Linear(1024, 10)
self.dropout = nn.Dropout(p=0.1)
def forward(self, x):
x, hn = self.rnn(self.embedded(x))
return self.dropout(self.linear(x))
python
# 创建量化模型实例
# model:原始模型
# qconfig_spec:待量化的层参数
# dtype:量化权重的目标类型
model2 = torch.quantization.quantize_dynamic(model=model1,
qconfig_spec={torch.nn.Linear, nn.GRU},
dtype=torch.qint8)
2. 模型剪枝
NLP中不用,一般在CV中用。
Pytorch中对模型剪枝的支持在torch.nn.utils.prune模块中, 分以下几种剪枝方式:
-
随机剪枝
-
L1结构化剪枝
-
L1非结构化剪枝
-
全局非结构化剪枝
| 非结构化剪枝 | 结构化剪枝 |
|---|---|
| 按单个权重裁剪 | 按神经元、通道、整行/列裁剪 |
| 剪枝后是稀疏矩阵 | 剪枝后是稠密矩阵 |
| 类似于裁掉部门中贡献度低的个人 | 类似于裁掉整个部门 |
代码:
python
# 演示随机非结构化剪枝
def dm01():
linear = nn.Linear(2, 3)
print("linear-->", linear.weight)
model = prune.random_unstructured(linear, 'weight', amount=2)
print("model-->", model.weight)
# 演示全局非结构化剪枝
def dm02():
net = nn.Sequential(OrderedDict([
('first', nn.Linear(3, 4)),
('second', nn.Linear(4, 2)),
]))
print("net1-->", net)
for model in net:
print("model-->", model.weight)
parameters_to_prune = ((net.first, 'weight'),
(net.second, 'weight'))
# parameters_to_prune:待剪枝的参数
# pruning_method:剪枝的方式,L1Unstructured表示非结构化剪枝(常用)
# amount:如果是小数,则表示比例,如果是整数,则表示数量
prune.global_unstructured(parameters_to_prune,
pruning_method=prune.L1Unstructured,
amount = 0.2)
print("net2-->", net)
for model in net:
print("model-->", model.weight)
3. 低秩因式分解
比如21128词表 * 768维度 很大,进行分解。运用矩阵分解,减少网络参数量,提升效率。

4. 模型蒸馏
复杂模型(教师模型)-> 简单模型(学生模型)
教师模型
-
定义:复杂的、高性能的模型,通常是大型深度神经网络。
-
特点:参数量大,能够学习复杂的特征和关系。
-
需要提前训练好。
学生模型
-
定义:简化的、小型的模型,可以是教师模型的子集或者简单模型。
-
特点:参数量较小,适用于资源受限的场景。
-
不需要提前训练好。
知识的来源:
-
硬标签蒸馏:学生模型直接学习教师模型的分类结果。
-
软标签蒸馏 :学生模型学习教师模型对每个类别的概率分布。
-
中间层蒸馏:学生模型学习教师模型的隐藏层、特征图等。
关键点:
- 高温T平滑输出概率,生成软标签
- 效果:BERT (110M 参数) → BiLSTM (几 M 参数),体积压缩十几倍
- 损失 = 真实标签 CE 损失 + KL 蒸馏损失
适用:NLP 分类、文本任务。
公式:
python
# 计算KL散度值
p = torch.log_softmax(teacher_pred/T, dim=-1)
q = torch.log_softmax(student_pred/T, dim=-1)
# KL散度值,也就是软标签的值
"""
参数解释:
input:是【学生模型】输出的结果
target:预测结果参考值。也就是【教师模型】输出的结果
reduction:上面两个值的计算方式。
log_target:是否对计算结果求log对数
"""
kl_value = torch.nn.functional.kl_div(
input=q,
target=p,
reduction="batchmean",
log_target=True
)
# 硬标签损失值
# 注意:是学生模型的预测概率,与样本的目标值算损失
hard_label_loss = loss(student_pred,labels)
# 蒸馏的总损失值
# l = (1-α) * 硬标签损失值 + α * T² * KL散度值
distll_loss = (1 - alpha) * hard_label_loss + alpha * (T**2) * kl_value
q: 学生模型预测结果计算得来
p: 教师模型预测结果计算得来
CE(y,p)也就是 学生模型自己的交叉熵损失
-
参数α:系数,控制从学生模型和教师模型学习的比例,比如α=0.8。
-
参数T:蒸馏温度,是一个平滑系数,控制softmax的输出,比如T=4。
蒸馏总损失值 L_{KD} = (1 - α)CE(y,p) + αKL(q,p)
KLDivLoss --- PyTorch 2.4 documentation
三、代码案例
需求
以文本分类任务,基于Bert模型的 教师模型,学生模型内部使用BiLstm神经网络
数据文本 ( 内容, 类别索引 )


数据源:三个内容文件,一个类别文件。
代码思路
① Config文件
配置各个文件路径(数据源,模型,批次大小,句子最大长度)
python
class Config(object):
def __init__(self):
# 1 - 设备
# self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
self.device = "cpu"
# 2 原始文件
self.train_datapath = 'data/train.txt'
self.test_datapath = 'data/test.txt'
self.dev_datapath = 'data/dev.txt'
self.class_datapath = 'data/class.txt'
# 3 数据加载参数
self.batch_size = 64
self.max_seq_len = 32
# 4 Bert 预训练模型路径
self.bert_path = '../Base_Bert_TMF/bert_base_model/bert-base-chinese'
# 5 - 目标值 文本解析
self.classname_list = [line.strip() for line in open(self.class_datapath,mode='r',encoding='utf-8')]
self.classname_len = len(self.classname_list)
# 6 - 训练好的【教师模型】路径
self.teacher_model_path = 'save_model/teacher_bert.pkl'
# 7 - 学生模型路径
self.student_model_path = 'save_model/student_model.pkl'
② 教师模型文件
基于Bert模型,经过线性层处理,冻结反向传播。(已训练好的模型)
线性层(in_features = Bert模型的隐藏状态大小,out_features=数据源类的总共个数)
python
"""
教师模型,基于Bert模型
"""
import torch
import torch.nn as nn
from transformers import BertModel
from transformers import BertConfig
from config import Config
config = Config()
class TeacherBertModel(nn.Module):
def __init__(self):
super().__init__()
self.bert_model = BertModel.from_pretrained(config.bert_path)
temp_config = BertConfig.from_pretrained(config.bert_path)
in_features = temp_config.hidden_size
self.linear = nn.Linear(
in_features=in_features,
out_features=config.classname_len
)
def forward(self, input_ids, attention_mask=None):
# 教师模型不需要训练 要冻结反向传播
with torch.no_grad():
bert_output = self.bert_model(
input_ids=input_ids,
attention_mask=attention_mask
)
# 2- 教师模型的:池化层,实际就是nn.Linear+激活函数。不用额外定义
"""
1- last_hidden_state[:,0]和pooler_output,实际是类似的东西,都表示[CLS]的隐藏状态。
区别:需要对last_hidden_state[:,0]经过nn.Linear和激活函数处理后,才能得到pooler_output
对应源代码位置:BertModel文件的697行
2- 获得池化层后的结果有两种方式:
2.1- 方式一:推荐。通过实例属性获得 bert_output.pooler_output
2.2- 方式二:通过实例属性索引获得 bert_output[1]。1的原因是pooler_output是类中的第2个实例属性
对应源代码位置:BertModel文件的1017行
"""
# 因为是句子 分类问题,所以取句子的向量。
pooled_output = bert_output.pooler_output
return self.linear(pooled_output)
③ 学生模型文件
定义学生模型类
<1> 定义参数
词汇表大小,词向量维度,隐藏状态,隐藏层层数
<2> 搭建网络层
词向量层、双向LSTM、随机失活层、线性层(输入 2倍的隐藏大小,输出 句子最大长度)
<3> 前向传播
<<1>> 数据张量化
<<2>> 输入原始数据处理,
过滤【CLS、SEP】特殊标识,基于transformer系列都有这个标识。
结合输入掩码张量对原始数据矩阵点乘处理
得到最终有效的词张量数据
<<3>> 调用BiLstm循环神经网络 -> 得到输出数据【batch_size,seq_len,hidden_size】
<<4>> 因为是文本分类需要的是句子,对输出数据累加->降维->记得句向量数据
<<5>>调用(随机失活 + 线性层)-> 输出
python
"""
学生模型 用BILSTM 双向模型
"""
from torch import Tensor
from config import Config
import torch
import torch.nn as nn
from transformers import BertConfig
config = Config()
bert_config = BertConfig.from_pretrained(config.bert_path)
class BILSTMStudentModel(nn.Module):
def __init__(self):
super().__init__()
"""
设置参数
基于Bert模型的中文词汇表大小
"""
self.vocab_size = bert_config.vocab_size
self.embedding_dim = 128
self.hidden_size = 256
self.num_layers = 3
"""
搭建网络层
embedding_dim:由我们自己设置,与教师模型没有任何关系
"""
self.embedding = nn.Embedding(self.vocab_size, self.embedding_dim)
self.lstm = nn.LSTM(
input_size=self.embedding_dim, #输入的词向量维度,必须和embding_dim 相同
hidden_size=self.hidden_size, #隐藏层向量维度 自定义
batch_first=True, #是否batch_size开头的张量 【batch_size,seq_len,hidden_size】
num_layers=self.num_layers, #隐藏层层数
bidirectional=True #是否双向
)
self.dropout = nn.Dropout(p=0.2)
"""
因为双向LSTM 所以 hidden_size*2
多分类任务,任务值是 取数据类别个数 作为输出
"""
self.linear = nn.Linear(self.hidden_size*2, config.classname_len)
def forward(self, input_ids, attention_mask):
# 1 - 数据张量化
ebd = self.embedding(input_ids)
"""
带 【CLS、SEP】特殊标识 Token:BERT 系 Transformer 编码器网络
所以数据要先把 【CLS】、【SEP】标识去除
"""
# 2 -
cls_token_index = 101 #句子开头 CLS固定索引值
sep_token_index = 102 #句子结尾 SEP固定索引值
# 2.1
# 对 input_ids 数据过滤 CLS 和 SEP
ebd_mask = (input_ids != cls_token_index) & (input_ids != sep_token_index)
# 2.2
# 过滤后的数据 与 掩码进行再次过滤 => 得到实际要用的掩码
ebd_mask:Tensor = ebd_mask & attention_mask
# 2.3
# 对 edb_mask 升维
# 原始【batch_size,seq_len】 -> 【batch_size, seq_len, 1】
ebd_mask = ebd_mask.unsqueeze(-1)
# 2.4
# 原始数据 与 实际掩码 进行点乘预算,得到实际有效的数据源
ebd = ebd * ebd_mask
# 3 - 调用循环神经网络BiLSTM
# 为什么调用lstm的时候,没有手动传递初始的细胞状态和隐藏状态:LSTM内部会自动的进行全0初始化。源代码在1056行
out_put, (hidden, c) = self.lstm(ebd)
# 4 - 计算平均池化值
# 4.1
# 降维: 以为是对词向量进行 网络处理,需求做的是句子分类
# 【batch_size,seq_len,hidden_size】=> [batch_size, hidden_size]
output_sum = out_put.sum(dim=1)
# 4.2
# 获取所有有效词的个数 + 1e-6 为了防止个数为0
token_count = ebd_mask.sum(dim=1) + 1e-6
# 4.3
# 计算获取 最终的句子向量数据
new_output = output_sum / token_count
# 5
# 调用线性层,得到预测结构,并返回
return self.linear(self.dropout(new_output))
④ 数据预处理文件
<1> 读取文件数据处理
表格数据读取 -> 得到数组 (每行的数据)
<2> 自定义数据集类
<<1>> init 参数定义 self.data_list = <1>处理得到的
<<2>> len 样本条数
<<3>> getitem 函数,根据索引获得 对应的 文本和分类 值
<3> 数据二次处理 -> 数据张量和掩码张量
<<1>> 传入每批次数据
输入数据:('近期新盘推荐 通州纯新别墅本周开盘', 1), ('陕西退休教师嫌弃精神病 女儿将其勒死被捕', 5)
输出数据:('近期新盘推荐 通州纯新别墅本周开盘', '陕西退休教师嫌弃精神病女儿 将其勒死被捕'), (1, 5)
得到 文本内容元组 和 类别元组
<<2>> 通过 transformers 的 BertTokenizer, 把数据转换为词索引张量
<<3>> 返回 数据张量(intput_ids)、掩码张量(attention_mask)、真实类别张量(lables)
<4> 构造数据加载器
<<1>> 通过<1>、<2>、得到数据集
<<2>> 创建数据加载器对象 DataLoader
<<3>> 返回加载器对象
python
"""
数据处理 得到模型需要的 input_dis 和 attention_mask. 并传递 真实值 Labels
# 1 读取文件获得数据
# 2 定义数据集
# 3 数据二次处理 (按batch,处理成input_dis,attention_mask 张量)
# 4 构建数据加载器
"""
import torch
import torch.nn as nn
from config import Config
from torch.utils.data import Dataset,DataLoader
from transformers import BertTokenizer
config = Config()
bert_tokenizer = BertTokenizer.from_pretrained(config.bert_path)
# 1 - 数据获取,处理
def load_data(datapath):
with open(datapath,mode="r",encoding="UTF-8") as f:
lines = f.readlines()
result_list = []
for line in lines:
line = line.strip()
if line=="":
continue
# 样本数据
# 两天价网站背后重重迷雾:做个网站究竟要多少钱 4
title, label = line.split('\t')
# 【可选】健壮性代码
"""
只要是有数据类型转换的地方,基本都有健壮性代码
"""
if not label.isdigit():
print(f"label的数据内容不合法,值是{label}")
continue
# 保存数据
result_list.append((title,int(label)))
return result_list
# 2 - 自定义数据集
class NewsDataset(Dataset):
def __init__(self,data_list):
super().__init__()
self.data_list = data_list #读取数据
self.sample_len = len(self.data_list) #样本条数
def __len__(self):
return self.sample_len
def __getitem__(self, idx):
# 防止数组越界
index = min(max(idx, 0),self.sample_len-1)
title,label = self.data_list[index]
return title,label
# 3 - 数据二次处理,按每批次数据处理
def collate_fn(batch_data):
"""
zip(*)处理过程如下:
输入数据:[('近期新盘推荐 通州纯新别墅本周开盘', 1), ('陕西退休教师嫌弃精神病女儿将其勒死被捕', 5)]
输出数据:[('近期新盘推荐 通州纯新别墅本周开盘', '陕西退休教师嫌弃精神病女儿将其勒死被捕'), (1, 5)]
"""
titles,labels = zip(*batch_data)
# 根据词索引 数据张量化 -> 获取词索引张量
title_tensor = bert_tokenizer(
titles,
padding="max_length",
truncation=True,
max_length=config.max_seq_len,
return_tensors="pt"
)
return (
title_tensor.input_ids,
title_tensor.attention_mask,
torch.tensor(labels,dtype=torch.long)
)
# 4 - 构建数据加载器
def build_dataloader(datapath, shuffle=True):
data = load_data(datapath)
dataset = NewsDataset(data)
data_loader = DataLoader(
dataset=dataset,
batch_size=config.batch_size,
shuffle=shuffle,
collate_fn=collate_fn
)
return data_loader
⑤ 模型蒸馏训练
学生模型训练边训练边预测保存
<1> 创建数据加载器对象
<2> 创建教师模型对象 + 加载已训练好的模型
<3> 创建学生模型对象
<4> 损失函数
<5> 优化器
<6> 变量(训练轮次、初始化f1_score,蒸馏温度T、α系数)
<7> 设置老师模型评估模式、学生模型训练模式
<8> 训练
<8.1> 根据数据加载器分批次 获取输入张量、掩码张量、真实类别张量
<8.2> 模型前向传播,其中老师模型冻结,不需要更新
<8.3> 计算KL散度
<8.4> 计算学生模型交叉熵损失值
<8.5> 计算蒸馏总损失值
<8.6> 梯度清零、反向传播、梯度更新
<8.7> 每固定间隔 对学生模型进行评估
<<1>> 数据加载器(加载评估数据)
<<2>> 学生模型切换评估模式
<<3>> 数据加载器分批次进行模型评估
保存真实结果和评估结果
<<4>> 计算评估指标
f1_score、accuracy(准确率)、precision(精确率)、recall(召回率)
<8.8> f1_socre > 上一次的f1_socre 值,保存模型进行覆盖。
<8.9> 学生模型切换训练模型,继续训练直到所有训练数据结束
python
"""
模型蒸馏
"""
import torch
import torch.nn as nn
from tqdm import tqdm
from data_preprocessing import build_dataloader
from student_bilstm_model import BILSTMStudentModel
from teacher_bert_model import TeacherBertModel
from config import Config
from sklearn.metrics import accuracy_score,precision_score,recall_score,f1_score
config = Config()
def eval(student_model):
# 1. 数据加载器
dataloader = build_dataloader(config.dev_datapath, shuffle=False)
# 2. 切换模式
student_model.eval()
all_pred_result = [] # 预测结果列表
all_true_result = [] # 真实结果列表
# 3. 预测
with torch.no_grad():
for batch_idx, batch_data in enumerate(tqdm(dataloader),start=1):
input_dis, attention_mask, labels = batch_data
input_dis = input_dis.to(config.device)
attention_mask = attention_mask.to(config.device)
labels = labels.to(config.device)
# 预测结果
student_pred = student_model(input_dis, attention_mask)
student_pred_index = torch.argmax(student_pred, dim=-1)
# cpu():因为不涉及张量的计算,因此为了节约GPU资源,可以将数据转到CPU上再处理
# .tolist() tensor([0,2,1]) → [0,2,1]
# .extend()
# append([1,2,3]) → [[1,2,3]](嵌套列表)
# extend([1,2,3]) → [1,2,3](把元素挨个拼进去)
all_pred_result.extend(student_pred_index.cpu().tolist())
all_true_result.extend(labels.cpu().tolist())
# 4 - 计算评估指标
f1score = f1_score(all_true_result,all_pred_result,average="macro")
# 准确率
accuracy = accuracy_score(all_true_result,all_pred_result)
precision = precision_score(all_true_result,all_pred_result,average="macro")
recall = recall_score(all_true_result,all_pred_result,average="macro")
return f1score, accuracy, precision, recall
def train_and_eval():
# 1. 通过加载器获取数据
data_loader = build_dataloader(config.train_datapath, shuffle=True)
# 2 - 教师模型
teacher_model = TeacherBertModel().to(config.device)
teacher_model.load_state_dict(torch.load(config.teacher_model_path))
# 3 - 学生模型
student_model = BILSTMStudentModel().to(config.device)
# 4 - 损失函数
loss_fn = nn.CrossEntropyLoss()
# 5 - 优化器
optimizer = torch.optim.Adam(student_model.parameters(), lr=5e-5)
# 6 - 其他变量
epochs = 1
best_f1score = 0
T = 2 #蒸馏温度
alpha = 0.7 #计算蒸馏总损失 KL散度和学生 概率比例
# 7 - 训练模式
student_model.train()
teacher_model.eval()
# 8 训练
for epoch in range(epochs):
for batch_idx, batch_data in enumerate(tqdm(data_loader),start=1):
input_dis, attention_mask, labels = batch_data
# 8.1 批次训练数据
# (输入张量、掩码张量、真实张量)
input_dis = input_dis.to(config.device)
attention_mask = attention_mask.to(config.device)
labels = labels.to(config.device)
# 8.2 模型前向传播
# 老师模型冻结,不需要更新
with torch.no_grad():
teacher_pred = teacher_model(input_dis, attention_mask)
teacher_pred_labels = torch.argmax(teacher_pred, dim=-1)
student_pred = student_model(input_dis, attention_mask)
student_pred_labels = torch.argmax(student_pred, dim=-1)
# 8.3
# 计算KL散度
p = torch.log_softmax(teacher_pred/T, dim=-1)
q = torch.log_softmax(student_pred/T, dim=-1)
# KL散度值,也就是软标签的值
"""
注意:kl_div的包不要导错了!!!
参数解释:
input:是【学生模型】输出的结果
target:预测结果参考值。也就是【教师模型】输出的结果
reduction:上面两个值的计算方式。
log_target:是否对计算结果求log对数
"""
kl_value = torch.nn.functional.kl_div(
input=q,
target=p,
reduction='batchmean',
log_target=True
)
# 8.4 学生模型自己的损失值
loss_value = loss_fn(student_pred, labels)
# 8.5 蒸馏总损失值 固定公式
distill_loss = (1-alpha) * loss_value + alpha * kl_value * (T**2)
# 8.6 梯度清零,反向传播,梯度更新
optimizer.zero_grad()
distill_loss.backward()
optimizer.step()
# 8.7 每间隔100个批次 或者 最后一个批次,对学生模型进行验证
if batch_idx%100==0 or batch_idx==len(data_loader):
f1_score, accuracy, precision, recall = eval(student_model)
print(f"第{batch_idx}批次,f1score={f1_score},accuracy={accuracy},precision={precision},recall={recall}")
if f1_score > best_f1score:
torch.save(student_model.state_dict(), config.student_model_path)
best_f1score = f1_score
# 切换回训练模式
student_model.train()
if __name__ == '__main__':
train_and_eval()
⑥模型预测使用
python
"""
预测函数 提供模型服务
"""
import torch
from config import Config
from transformers import BertTokenizer
from student_bilstm_model import BILSTMStudentModel
config = Config()
model = BILSTMStudentModel().to(config.device)
model.load_state_dict(torch.load(config.student_model_path))
model.eval()
tokenizer = BertTokenizer.from_pretrained(config.bert_path)
def model_predict(json_data):
# 1 - 外部数据 取得句子
title = json_data['title']
# 2 - 文本转张量 获得 input_ids, attention_mask
title_tensor = tokenizer(
[title],
padding="max_length",
truncation=True,
max_length=config.max_seq_len,
return_tensors="pt"
)
input_ids = title_tensor.input_ids.to(config.device)
attention_mask = title_tensor.attention_mask.to(config.device)
with torch.no_grad():
output = model(input_ids, attention_mask)
output_index = torch.argmax(output, dim=-1).item() #取概率最大的索引值
pred_class_name = config.classname_list[output_index]
json_data["pred_class"] = pred_class_name
return json_data
if __name__ == '__main__':
print(model_predict({'title': '体验2D巅峰 倚天屠龙记十大创新新概览'}))