pytorch基于FastText实现词嵌入

FastText 是 Facebook AI Research 提出的 改进版 Word2Vec ,可以: ✅ 利用 n-grams 处理未登录词
比 Word2Vec 更快、更准确
适用于中文等形态丰富的语言

完整的 PyTorch FastText 代码(基于中文语料),包含:

  • 数据预处理(分词 + n-grams)

  • 模型定义

  • 训练

  • 测试

    import torch
    import torch.nn as nn
    import torch.optim as optim
    import numpy as np
    import jieba
    from collections import Counter
    import random

    ========== 1. 数据预处理 ==========

    corpus = [
    "我们 喜欢 深度 学习",
    "自然 语言 处理 是 有趣 的",
    "人工智能 改变 了 世界",
    "深度 学习 是 人工智能 的 重要 组成部分"
    ]

    分词

    tokenized_corpus = [list(jieba.cut(sentence)) for sentence in corpus]

    构建 n-grams

    def generate_ngrams(words, n=3):
    ngrams = []
    for word in words:
    ngrams += [word[i:i + n] for i in range(len(word) - n + 1)]
    return ngrams

    生成 n-grams 词表

    all_ngrams = set()
    for sentence in tokenized_corpus:
    for word in sentence:
    all_ngrams.update(generate_ngrams(word))

    构建词汇表

    vocab = set(word for sentence in tokenized_corpus for word in sentence) | all_ngrams
    word2idx = {word: idx for idx, word in enumerate(vocab)}
    idx2word = {idx: word for word, idx in word2idx.items()}

    构建训练数据(CBOW 方式)

    window_size = 2
    data = []

    for sentence in tokenized_corpus:
    indices = [word2idx[word] for word in sentence]
    for center_idx in range(len(indices)):
    context = []
    for offset in range(-window_size, window_size + 1):
    context_idx = center_idx + offset
    if 0 <= context_idx < len(indices) and context_idx != center_idx:
    context.append(indices[context_idx])
    if context:
    data.append((context, indices[center_idx])) # (上下文, 目标词)

    ========== 2. 定义 FastText 模型 ==========

    class FastText(nn.Module):
    def init(self, vocab_size, embedding_dim):
    super(FastText, self).init()
    self.embeddings = nn.Embedding(vocab_size, embedding_dim)
    self.linear = nn.Linear(embedding_dim, vocab_size)

    复制代码
      def forward(self, context):
          context_vec = self.embeddings(context).mean(dim=1)  # 平均上下文向量
          output = self.linear(context_vec)
          return output

    初始化模型

    embedding_dim = 10
    model = FastText(len(vocab), embedding_dim)

    ========== 3. 训练 FastText ==========

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=0.01)
    num_epochs = 100

    for epoch in range(num_epochs):
    total_loss = 0
    random.shuffle(data)

    复制代码
      for context, target in data:
          context = torch.tensor([context], dtype=torch.long)
          target = torch.tensor([target], dtype=torch.long)
    
          optimizer.zero_grad()
          output = model(context)
          loss = criterion(output, target)
          loss.backward()
          optimizer.step()
    
          total_loss += loss.item()
    
      if (epoch + 1) % 10 == 0:
          print(f"Epoch [{epoch + 1}/{num_epochs}], Loss: {total_loss:.4f}")

    ========== 4. 获取词向量 ==========

    word_vectors = model.embeddings.weight.data.numpy()

    ========== 5. 计算相似度 ==========

    def most_similar(word, top_n=3):
    if word not in word2idx:
    return "单词不在词汇表中"

    复制代码
      word_vec = word_vectors[word2idx[word]].reshape(1, -1)
      similarities = np.dot(word_vectors, word_vec.T).squeeze()
      similar_idx = similarities.argsort()[::-1][1:top_n + 1]
      return [(idx2word[idx], similarities[idx]) for idx in similar_idx]

    测试

    test_words = ["深度", "学习", "人工智能"]
    for word in test_words:
    print(f"【{word}】的相似单词:", most_similar(word))

1. 生成 n-grams

  • FastText 处理单词的 子词单元(n-grams)
  • 例如 "学习" 会生成 ["学习", "习学", "学"]
  • 这样即使遇到未登录词也能拆分为 n-grams 计算

2. 训练数据

  • 使用 CBOW(上下文预测中心词)

  • 窗口大小 = 2 ,即:

    复制代码
    句子: ["深度", "学习", "是", "人工智能"]
    示例: (["深度", "是"], "学习")

3. FastText 模型

  • 词向量是 n-grams 词向量的平均值
  • 计算公式:
  • 这样,即使单词没见过,也能用它的 n-grams 计算词向量!

4. 计算相似度

  • cosine similarity 找出最相似的单词
  • FastText 比 Word2Vec 更准确,因为它能利用 n-grams 捕捉词的语义信息
特性 FastText Word2Vec GloVe
原理 预测中心词 + n-grams 预测中心词或上下文 统计词共现信息
未登录词处理 可处理 无法处理 无法处理
训练速度
适合领域 中文、罕见词 传统 NLP 大规模数据
相关推荐
IT_陈寒2 小时前
React 18实战:7个被低估的Hooks技巧让你的开发效率提升50%
前端·人工智能·后端
数据智能老司机3 小时前
精通 Python 设计模式——分布式系统模式
python·设计模式·架构
逛逛GitHub3 小时前
飞书多维表“独立”了!功能强大的超出想象。
人工智能·github·产品
机器之心3 小时前
刚刚,DeepSeek-R1论文登上Nature封面,通讯作者梁文锋
人工智能·openai
数据智能老司机4 小时前
精通 Python 设计模式——并发与异步模式
python·设计模式·编程语言
数据智能老司机4 小时前
精通 Python 设计模式——测试模式
python·设计模式·架构
数据智能老司机4 小时前
精通 Python 设计模式——性能模式
python·设计模式·架构
c8i4 小时前
drf初步梳理
python·django
每日AI新事件4 小时前
python的异步函数
python
这里有鱼汤6 小时前
miniQMT下载历史行情数据太慢怎么办?一招提速10倍!
前端·python