pytorch基于FastText实现词嵌入

FastText 是 Facebook AI Research 提出的 改进版 Word2Vec ,可以: ✅ 利用 n-grams 处理未登录词
比 Word2Vec 更快、更准确
适用于中文等形态丰富的语言

完整的 PyTorch FastText 代码(基于中文语料),包含:

  • 数据预处理(分词 + n-grams)

  • 模型定义

  • 训练

  • 测试

    import torch
    import torch.nn as nn
    import torch.optim as optim
    import numpy as np
    import jieba
    from collections import Counter
    import random

    ========== 1. 数据预处理 ==========

    corpus = [
    "我们 喜欢 深度 学习",
    "自然 语言 处理 是 有趣 的",
    "人工智能 改变 了 世界",
    "深度 学习 是 人工智能 的 重要 组成部分"
    ]

    分词

    tokenized_corpus = [list(jieba.cut(sentence)) for sentence in corpus]

    构建 n-grams

    def generate_ngrams(words, n=3):
    ngrams = []
    for word in words:
    ngrams += [word[i:i + n] for i in range(len(word) - n + 1)]
    return ngrams

    生成 n-grams 词表

    all_ngrams = set()
    for sentence in tokenized_corpus:
    for word in sentence:
    all_ngrams.update(generate_ngrams(word))

    构建词汇表

    vocab = set(word for sentence in tokenized_corpus for word in sentence) | all_ngrams
    word2idx = {word: idx for idx, word in enumerate(vocab)}
    idx2word = {idx: word for word, idx in word2idx.items()}

    构建训练数据(CBOW 方式)

    window_size = 2
    data = []

    for sentence in tokenized_corpus:
    indices = [word2idx[word] for word in sentence]
    for center_idx in range(len(indices)):
    context = []
    for offset in range(-window_size, window_size + 1):
    context_idx = center_idx + offset
    if 0 <= context_idx < len(indices) and context_idx != center_idx:
    context.append(indices[context_idx])
    if context:
    data.append((context, indices[center_idx])) # (上下文, 目标词)

    ========== 2. 定义 FastText 模型 ==========

    class FastText(nn.Module):
    def init(self, vocab_size, embedding_dim):
    super(FastText, self).init()
    self.embeddings = nn.Embedding(vocab_size, embedding_dim)
    self.linear = nn.Linear(embedding_dim, vocab_size)

    复制代码
      def forward(self, context):
          context_vec = self.embeddings(context).mean(dim=1)  # 平均上下文向量
          output = self.linear(context_vec)
          return output

    初始化模型

    embedding_dim = 10
    model = FastText(len(vocab), embedding_dim)

    ========== 3. 训练 FastText ==========

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=0.01)
    num_epochs = 100

    for epoch in range(num_epochs):
    total_loss = 0
    random.shuffle(data)

    复制代码
      for context, target in data:
          context = torch.tensor([context], dtype=torch.long)
          target = torch.tensor([target], dtype=torch.long)
    
          optimizer.zero_grad()
          output = model(context)
          loss = criterion(output, target)
          loss.backward()
          optimizer.step()
    
          total_loss += loss.item()
    
      if (epoch + 1) % 10 == 0:
          print(f"Epoch [{epoch + 1}/{num_epochs}], Loss: {total_loss:.4f}")

    ========== 4. 获取词向量 ==========

    word_vectors = model.embeddings.weight.data.numpy()

    ========== 5. 计算相似度 ==========

    def most_similar(word, top_n=3):
    if word not in word2idx:
    return "单词不在词汇表中"

    复制代码
      word_vec = word_vectors[word2idx[word]].reshape(1, -1)
      similarities = np.dot(word_vectors, word_vec.T).squeeze()
      similar_idx = similarities.argsort()[::-1][1:top_n + 1]
      return [(idx2word[idx], similarities[idx]) for idx in similar_idx]

    测试

    test_words = ["深度", "学习", "人工智能"]
    for word in test_words:
    print(f"【{word}】的相似单词:", most_similar(word))

1. 生成 n-grams

  • FastText 处理单词的 子词单元(n-grams)
  • 例如 "学习" 会生成 ["学习", "习学", "学"]
  • 这样即使遇到未登录词也能拆分为 n-grams 计算

2. 训练数据

  • 使用 CBOW(上下文预测中心词)

  • 窗口大小 = 2 ,即:

    复制代码
    句子: ["深度", "学习", "是", "人工智能"]
    示例: (["深度", "是"], "学习")

3. FastText 模型

  • 词向量是 n-grams 词向量的平均值
  • 计算公式:
  • 这样,即使单词没见过,也能用它的 n-grams 计算词向量!

4. 计算相似度

  • cosine similarity 找出最相似的单词
  • FastText 比 Word2Vec 更准确,因为它能利用 n-grams 捕捉词的语义信息
特性 FastText Word2Vec GloVe
原理 预测中心词 + n-grams 预测中心词或上下文 统计词共现信息
未登录词处理 可处理 无法处理 无法处理
训练速度
适合领域 中文、罕见词 传统 NLP 大规模数据
相关推荐
maozexijr2 分钟前
什么是 Flink Pattern
大数据·python·flink
云卓SKYDROID4 分钟前
无人机减震模块运行与技术要点分析!
人工智能·无人机·科普·高科技·减震系统
山北雨夜漫步17 分钟前
机器学习 Day18 Support Vector Machine ——最优美的机器学习算法
人工智能·算法·机器学习
正在走向自律19 分钟前
从0到1吃透卷积神经网络(CNN):原理与实战全解析
人工智能·神经网络·cnn
拓端研究室TRL22 分钟前
Python+AI提示词糖尿病预测融合模型:伯努利朴素贝叶斯、逻辑回归、决策树、随机森林、支持向量机SVM应用
人工智能·python·决策树·随机森林·逻辑回归
winfredzhang27 分钟前
使用Python和Selenium打造一个全网页截图工具
开发语言·python·selenium
mahuifa35 分钟前
(10)python开发经验
开发语言·python
何双新36 分钟前
第8讲、Multi-Head Attention 的核心机制与实现细节
人工智能·transformer
moongoblin37 分钟前
协作赋能-1-制造业生产流程重构
大数据·人工智能·经验分享·制造
穿越光年39 分钟前
MCP实战:在扣子空间用扣子工作流MCP,一句话生成儿童故事rap视频
人工智能·音视频