pytorch基于GloVe实现的词嵌入

PyTorch 实现 GloVe(Global Vectors for Word Representation) 的完整代码,使用 中文语料 进行训练,包括 共现矩阵构建、模型定义、训练和测试


1. GloVe 介绍

基于词的共现信息 (不像 Word2Vec 使用滑动窗口预测)

适合较大规模的数据 (比 Word2Vec 更稳定)
学习出的词向量能捕捉语义信息(如类比关系)

复制代码
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import jieba
from collections import Counter
from scipy.sparse import coo_matrix

# ========== 1. 数据预处理 ==========
corpus = [
    "我们 喜欢 深度 学习",
    "自然 语言 处理 是 有趣 的",
    "人工智能 改变 了 世界",
    "深度 学习 是 人工智能 的 重要 组成部分"
]

# 分词
tokenized_corpus = [list(jieba.cut(sentence)) for sentence in corpus]
vocab = set(word for sentence in tokenized_corpus for word in sentence)
word2idx = {word: idx for idx, word in enumerate(vocab)}
idx2word = {idx: word for word, idx in word2idx.items()}

# 计算共现矩阵
window_size = 2
co_occurrence = Counter()

for sentence in tokenized_corpus:
    indices = [word2idx[word] for word in sentence]
    for center_idx in range(len(indices)):
        center_word = indices[center_idx]
        for offset in range(-window_size, window_size + 1):
            context_idx = center_idx + offset
            if 0 <= context_idx < len(indices) and context_idx != center_idx:
                context_word = indices[context_idx]
                co_occurrence[(center_word, context_word)] += 1

# 转换为稀疏矩阵
rows, cols, values = zip(*[(c[0], c[1], v) for c, v in co_occurrence.items()])
X = coo_matrix((values, (rows, cols)), shape=(len(vocab), len(vocab)))


# ========== 2. 定义 GloVe 模型 ==========
class GloVe(nn.Module):
    def __init__(self, vocab_size, embedding_dim):
        super(GloVe, self).__init__()
        self.w_embeddings = nn.Embedding(vocab_size, embedding_dim)  # 中心词嵌入
        self.c_embeddings = nn.Embedding(vocab_size, embedding_dim)  # 上下文词嵌入
        self.w_bias = nn.Embedding(vocab_size, 1)  # 中心词偏置
        self.c_bias = nn.Embedding(vocab_size, 1)  # 上下文词偏置
        nn.init.xavier_uniform_(self.w_embeddings.weight)
        nn.init.xavier_uniform_(self.c_embeddings.weight)

    def forward(self, center, context, co_occur):
        w_emb = self.w_embeddings(center)
        c_emb = self.c_embeddings(context)
        w_bias = self.w_bias(center).squeeze()
        c_bias = self.c_bias(context).squeeze()
        dot_product = (w_emb * c_emb).sum(dim=1)
        loss = (dot_product + w_bias + c_bias - torch.log(co_occur + 1e-8)) ** 2
        return loss.mean()


# 初始化模型
embedding_dim = 10
model = GloVe(len(vocab), embedding_dim)

# ========== 3. 训练 GloVe ==========
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=0.01)
num_epochs = 100

# 转换数据
co_occurrence_tensor = torch.tensor(X.data, dtype=torch.float)
pairs = list(zip(X.row, X.col, co_occurrence_tensor))

for epoch in range(num_epochs):
    total_loss = 0
    np.random.shuffle(pairs)
    for center, context, co_occur in pairs:
        optimizer.zero_grad()
        loss = model(
            torch.tensor([center], dtype=torch.long),
            torch.tensor([context], dtype=torch.long),
            torch.tensor([co_occur], dtype=torch.float)  # 修正数据类型
        )
        loss.backward()
        optimizer.step()
        total_loss += loss.item()

    if (epoch + 1) % 10 == 0:
        print(f"Epoch [{epoch + 1}/{num_epochs}], Loss: {total_loss:.4f}")

# ========== 4. 获取词向量 ==========
word_vectors = model.w_embeddings.weight.data.numpy()


# ========== 5. 计算相似度 ==========
def most_similar(word, top_n=3):
    if word not in word2idx:
        return "单词不在词汇表中"

    word_vec = word_vectors[word2idx[word]].reshape(1, -1)
    similarities = np.dot(word_vectors, word_vec.T).squeeze()
    similar_idx = similarities.argsort()[::-1][1:top_n + 1]
    return [(idx2word[idx], similarities[idx]) for idx in similar_idx]


# 测试
test_words = ["深度", "学习", "人工智能"]
for word in test_words:
    print(f"【{word}】的相似单词:", most_similar(word))

数据预处理

  • 分词 (使用 jieba.cut()
  • 构建共现矩阵(计算窗口内的单词共现频率)
  • 使用稀疏矩阵存储(提高计算效率)

GloVe 模型

  • Embedding 训练词向量(中心词和上下文词分开)
  • Bias 变量 用于调整预测值
  • 损失函数 最小化 log(共现次数) 与词向量点积的差值

计算词向量相似度

  • 使用 cosine similarity
  • 找出 top_n 最相似的单词
相关推荐
smith成长之旅几秒前
07 | Mem0 框架分析:三路信号融合——语义 + BM25 + Entity Boost 的混合检索
python·算法
Cho1yon3 分钟前
【AI Agent 第十期:Claude Code 完全配置指南:三系统一步到位,AI编程助手轻松上手】
人工智能·ai编程
数据皮皮侠AI8 分钟前
上市公司耐心资本数据(2010-2025)
大数据·人工智能·笔记·能源·1024程序员节
陕西企来客8 分钟前
陕西 KNIT 可信知识网络构建模块对于 GEO 优化行业的影响深度调查:企来客科技技术落地真相揭示
大数据·人工智能
追光者♂10 分钟前
【测评系列5】CSDN AI数字营销实测体验官——Claude 大模型深度评测:从参数解析到实战边界
人工智能·ai·大模型·大语言模型·claude·模型幻觉·架构参数
yubo050915 分钟前
计算机视觉第七课:颜色追踪(只框红色 / 蓝色 / 绿色物体)
人工智能·opencv·计算机视觉
编码小哥17 分钟前
OpenCV去噪算法实战:中值滤波与双边滤波应用
人工智能·opencv·计算机视觉
zhangshuang-peta24 分钟前
MCP 如何重新定义 Skill:从“能力函数”变成“可治理行为”
人工智能·ai·ai agent·mcp·peta
yubo050926 分钟前
计算机视觉第六课:打开摄像头,实时框出物体
人工智能·opencv·计算机视觉
荣码28 分钟前
【Python知识详解】变量与数据类型:深入理解 Python 的数据世界
python