pytorch基于GloVe实现的词嵌入

PyTorch 实现 GloVe(Global Vectors for Word Representation) 的完整代码,使用 中文语料 进行训练,包括 共现矩阵构建、模型定义、训练和测试


1. GloVe 介绍

基于词的共现信息 (不像 Word2Vec 使用滑动窗口预测)

适合较大规模的数据 (比 Word2Vec 更稳定)
学习出的词向量能捕捉语义信息(如类比关系)

复制代码
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import jieba
from collections import Counter
from scipy.sparse import coo_matrix

# ========== 1. 数据预处理 ==========
corpus = [
    "我们 喜欢 深度 学习",
    "自然 语言 处理 是 有趣 的",
    "人工智能 改变 了 世界",
    "深度 学习 是 人工智能 的 重要 组成部分"
]

# 分词
tokenized_corpus = [list(jieba.cut(sentence)) for sentence in corpus]
vocab = set(word for sentence in tokenized_corpus for word in sentence)
word2idx = {word: idx for idx, word in enumerate(vocab)}
idx2word = {idx: word for word, idx in word2idx.items()}

# 计算共现矩阵
window_size = 2
co_occurrence = Counter()

for sentence in tokenized_corpus:
    indices = [word2idx[word] for word in sentence]
    for center_idx in range(len(indices)):
        center_word = indices[center_idx]
        for offset in range(-window_size, window_size + 1):
            context_idx = center_idx + offset
            if 0 <= context_idx < len(indices) and context_idx != center_idx:
                context_word = indices[context_idx]
                co_occurrence[(center_word, context_word)] += 1

# 转换为稀疏矩阵
rows, cols, values = zip(*[(c[0], c[1], v) for c, v in co_occurrence.items()])
X = coo_matrix((values, (rows, cols)), shape=(len(vocab), len(vocab)))


# ========== 2. 定义 GloVe 模型 ==========
class GloVe(nn.Module):
    def __init__(self, vocab_size, embedding_dim):
        super(GloVe, self).__init__()
        self.w_embeddings = nn.Embedding(vocab_size, embedding_dim)  # 中心词嵌入
        self.c_embeddings = nn.Embedding(vocab_size, embedding_dim)  # 上下文词嵌入
        self.w_bias = nn.Embedding(vocab_size, 1)  # 中心词偏置
        self.c_bias = nn.Embedding(vocab_size, 1)  # 上下文词偏置
        nn.init.xavier_uniform_(self.w_embeddings.weight)
        nn.init.xavier_uniform_(self.c_embeddings.weight)

    def forward(self, center, context, co_occur):
        w_emb = self.w_embeddings(center)
        c_emb = self.c_embeddings(context)
        w_bias = self.w_bias(center).squeeze()
        c_bias = self.c_bias(context).squeeze()
        dot_product = (w_emb * c_emb).sum(dim=1)
        loss = (dot_product + w_bias + c_bias - torch.log(co_occur + 1e-8)) ** 2
        return loss.mean()


# 初始化模型
embedding_dim = 10
model = GloVe(len(vocab), embedding_dim)

# ========== 3. 训练 GloVe ==========
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=0.01)
num_epochs = 100

# 转换数据
co_occurrence_tensor = torch.tensor(X.data, dtype=torch.float)
pairs = list(zip(X.row, X.col, co_occurrence_tensor))

for epoch in range(num_epochs):
    total_loss = 0
    np.random.shuffle(pairs)
    for center, context, co_occur in pairs:
        optimizer.zero_grad()
        loss = model(
            torch.tensor([center], dtype=torch.long),
            torch.tensor([context], dtype=torch.long),
            torch.tensor([co_occur], dtype=torch.float)  # 修正数据类型
        )
        loss.backward()
        optimizer.step()
        total_loss += loss.item()

    if (epoch + 1) % 10 == 0:
        print(f"Epoch [{epoch + 1}/{num_epochs}], Loss: {total_loss:.4f}")

# ========== 4. 获取词向量 ==========
word_vectors = model.w_embeddings.weight.data.numpy()


# ========== 5. 计算相似度 ==========
def most_similar(word, top_n=3):
    if word not in word2idx:
        return "单词不在词汇表中"

    word_vec = word_vectors[word2idx[word]].reshape(1, -1)
    similarities = np.dot(word_vectors, word_vec.T).squeeze()
    similar_idx = similarities.argsort()[::-1][1:top_n + 1]
    return [(idx2word[idx], similarities[idx]) for idx in similar_idx]


# 测试
test_words = ["深度", "学习", "人工智能"]
for word in test_words:
    print(f"【{word}】的相似单词:", most_similar(word))

数据预处理

  • 分词 (使用 jieba.cut()
  • 构建共现矩阵(计算窗口内的单词共现频率)
  • 使用稀疏矩阵存储(提高计算效率)

GloVe 模型

  • Embedding 训练词向量(中心词和上下文词分开)
  • Bias 变量 用于调整预测值
  • 损失函数 最小化 log(共现次数) 与词向量点积的差值

计算词向量相似度

  • 使用 cosine similarity
  • 找出 top_n 最相似的单词
相关推荐
B站计算机毕业设计之家14 分钟前
深度学习:YOLOv8人体行为动作识别检测系统 行为识别检测识系统 act-dataset数据集 pyqt5 机器学习✅
人工智能·python·深度学习·qt·yolo·机器学习·计算机视觉
on_pluto_15 分钟前
GAN生成对抗网络学习-例子:生成逼真手写数字图
人工智能·深度学习·神经网络·学习·算法·机器学习·生成对抗网络
机器之心22 分钟前
打造图像编辑领域的ImageNet?苹果用Nano Banana开源了一个超大数据集
人工智能·openai
渡我白衣31 分钟前
AI 应用层革命(一)——软件的终结与智能体的崛起
人工智能·opencv·机器学习·语言模型·数据挖掘·人机交互·集成学习
weixin_4296302637 分钟前
文献10.3 多视图变分深度学习及其在实际室内定位中的应用
人工智能·深度学习
墨利昂42 分钟前
Pytorch常用API(ML和DL)
人工智能·pytorch·python
SunnyDays10111 小时前
Python 裁剪 PDF 教程:轻松裁剪页面并导出为图片
python·pdf裁剪·裁剪pdf页面·裁切pdf
JustNow_Man1 小时前
Cline插件中clinerules的选择机制
python
刘孬孬沉迷学习1 小时前
AI+通信+多模态应用分类与核心内容总结
人工智能·机器学习·分类·数据挖掘·信息与通信
Allenlzcoder1 小时前
掌握机器学习算法及其关键超参数
人工智能·机器学习·超参数