使用Bert预训练模型处理序列推荐任务

最近的工作有涉及该任务,整理一下思路以及代码细节。

流程

总体来说思路就是首先用预训练的bert模型,在训练集的序列上进行CLS任务。对序列内容(这里默认是token id的sequence)以0.3左右的概率进行随机mask,然后将相应sequence的attention mask(原来决定padding index)和label(也就是mask的ground truth)输入到bert model里面。

当然其中vocab.txt并不存在的token是需要add进去的,具体方法不再详述,网上例子很多,注意word embedding也需要初始化就行。

模型定义:
self.model = AutoModelForMaskedLM.from_pretrained('./bert')

模型的输入:
result = self.bert_model(tail_mask, attention_mask, labels)

得到模型训练的结果之后,要做一个选择:

(1)transformer的bert model可以输出要预测时间步的hidden state,可以选择取出对应的hidden state,其中需要在数据处理的时候记录下每个sequence的tail position,也就是要预测位置的idx。另外我认为既然要进行序列推荐,那么最后一个tail position的token表征一定是最重要的,所以需要对tail position的idx专门给个写死的mask,效果会好一些。然后与sequence中item的全集进行相似度的计算,再去算交叉熵loss。

py 复制代码
bert_hidden = result.hidden_states[-1]
bert_seq_hidden = torch.zeros((self.args.batch_size, 312)).to(self.device)
for i in range(self.args.batch_size):
	bert_seq_hidden[i,:] = bert_hidden[i, tail_pos[i], :]
logits = torch.matmul(bert_seq_hidden, test_item_emb.transpose(0, 1))
main_loss = self.criterion(logits, targets)

(2)同时也可以result.loss直接数据mask prediction的loss,我理解这个loss面对的任务是我要求sequence中的各个token表征都要尽可能准确,都要考虑,(1)可能更加注重最后一个位置的标准的准确性。

然后在evaluate阶段,需要注意输入到模型的不再是tail_mask,而是仅仅mask掉tail token id的sequence,因为我们需要尽可能准确的序列信息,只需要保证要预测的存在mask就够了。

由于是推荐任务,而且bert得到的hidden state表征过于隐式,所以需要一定的个性化引导它进行训练。经过个人的实验也确实如此,而且结果相差很多。

以上就是我个人的总结经验,欢迎大家指点。

相关推荐
winlife_17 分钟前
在 Unity 里用 AI 做游戏:funplay-unity-mcp 从安装到第一次让 AI 改场景
人工智能·游戏·unity·ai编程·claude·mcp
虫无涯23 分钟前
大模型工程实现全解:5大落地路径从入门到实战
人工智能
cxr82831 分钟前
高分子复合材料 AI 逆向设计合——工业交付、系统自重构与范式演进
人工智能·重构·材料逆向设计合成
冬奇Lab34 分钟前
每日一个开源项目(第119篇):Darwin Skill - 受 Karpathy 启发,让 AI 技能无限进化的“棘轮”系统
人工智能·开源
Black蜡笔小新39 分钟前
企业私有化AI训练推理一体工作站DLTM重构企业AI模型训练的全流程模式
人工智能·机器学习·重构
冬奇Lab41 分钟前
Agent 系列(10):MCP 协议——工具生态的标准化接入
人工智能·agent·mcp
极客老王说Agent1 小时前
屏幕理解能力是下一代自动化的关键吗?2026年自动化范式演进深度解析
运维·人工智能·ai·chatgpt·自动化
YueJoy.AI1 小时前
低算力场景下中小企业接入大模型的商业化路径
人工智能·ai·语言模型
smart19982 小时前
U.2 NVMe全闪磁盘阵列让AI, ML, HPC业务运行稳性高效
人工智能·科技·存储
懷淰メ2 小时前
【AI加持】基于PyQt+YOLO+DeepSeek的疟原虫检测系统(详细介绍)
人工智能·yolo·计算机视觉·pyqt·医疗·ai分析·疟原虫