pytorch nn.Embedding 读取gensim训练好的词/字向量(有例子)

最近在跑深度学习模型,发现Embedding随机性太强导致模型结果有出入,因此考虑固定初始随机向量,既提前训练好词/字向量,不多说上代码!!

1、利用gensim训练字向量(词向量自行修改)

python 复制代码
# 得到每一行的数据 []
datas = open('data/word.txt', 'r', encoding='gbk').read().split("\n")
# 得到一行的单个字 [[],...,[]]
word_datas = [[i for i in data if i != " "] for data in datas] 
model = Word2Vec(
    word_datas,  # 需要训练的文本
    vector_size=10,   # 词向量的维度
    window=2,  # 句子中当前单词和预测单词之间的最大距离
    min_count=1,  # 忽略总频率低于此的所有单词 出现的频率小于 			min_count 不用作词向量
    workers=8,  # 使用这些工作线程来训练模型(使用多核机器进行更快的训练)
    sg=0,  # 训练方法 1:skip-gram 0;CBOW。
    epochs=10  # 语料库上的迭代次数
	)

2、保存模型或者字向量

python 复制代码
#字向量保存
model.wv.save_word2vec_format('word_data.vector',   # 保存路径
                              binary=False  # 如果为 True,则数据将以二进制 word2vec 格式保存,否则将以纯文本格式保存
                              )
#模型保存
model.save('word.model')

3、nn.Embedding读取gensim模型

python 复制代码
model = gensim.models.Word2Vec.load('./word.model')
weights = torch.FloatTensor(model.wv.vectors)
embedding = nn.Embedding.from_pretrained(weights)
embedding.requires_grad = False

这里懒了,拷贝别人的图,debug就可以看看,简单理解下就是有X个字,就有X行,然后每个字用Y个数字表示,就是Y列,上图X=4799,Y=10。

*也许看了上面你依然会一脸懵(别着急,下面给你举个例子)

4、案例

python 复制代码
import gensim
import torch
import torch.nn as nn

model = gensim.models.Word2Vec.load('./word.model')
weights = torch.FloatTensor(model.wv.vectors)

embedding = nn.Embedding.from_pretrained(weights)
embedding.requires_grad = False #训练时候不训练向量

query = '天氣'
query_id = torch.tensor(model.wv.vocab['天氣'].index)

#下面只是查询,具体的根据你自己的训练即可
gensim_vector = torch.tensor(model[query])
embedding_vector = embedding(query_id)

print(gensim_vector==embedding_vector)

#首先將 Gensim 的預訓練模型讀取進來,並將其向量轉換成 PyTorch 所需要的資料格式 Tensor,當作 nn.Embedding() 的初始值。
#這裡有個小細節:如果並不打算在模型訓練過程中一併訓練 nn.Emedding(),要記得將其設定為 requires_grad = False。
相关推荐
唐兴通个人4 小时前
人工智能Deepseek医药AI培训师培训讲师唐兴通讲课课程纲要
大数据·人工智能
共绩算力5 小时前
Llama 4 Maverick Scout 多模态MoE新里程碑
人工智能·llama·共绩算力
DashVector5 小时前
向量检索服务 DashVector产品计费
数据库·数据仓库·人工智能·算法·向量检索
AI纪元故事会5 小时前
【计算机视觉目标检测算法对比:R-CNN、YOLO与SSD全面解析】
人工智能·算法·目标检测·计算机视觉
音视频牛哥6 小时前
从协议规范和使用场景探讨为什么SmartMediaKit没有支持DASH
人工智能·音视频·大牛直播sdk·dash·dash还是rtmp·dash还是rtsp·dash还是hls
赞奇科技Xsuperzone6 小时前
DGX Spark 实战解析:模型选择与效率优化全指南
大数据·人工智能·gpt·spark·nvidia
音视频牛哥6 小时前
SmartMediaKit:如何让智能系统早人一步“跟上现实”的时间架构--从实时流媒体到系统智能的演进
人工智能·计算机视觉·音视频·音视频开发·具身智能·十五五规划具身智能·smartmediakit
喜欢吃豆6 小时前
OpenAI Agent 工具全面开发者指南——从 RAG 到 Computer Use —— 深入解析全新 Responses API
人工智能·microsoft·自然语言处理·大模型
音视频牛哥7 小时前
超清≠清晰:视频系统里的分辨率陷阱与秩序真相
人工智能·机器学习·计算机视觉·音视频·大牛直播sdk·rtsp播放器rtmp播放器·smartmediakit
johnny2337 小时前
AI视频创作工具汇总:MoneyPrinterTurbo、KrillinAI、NarratoAI、ViMax
人工智能·音视频