pytorch基于FastText实现词嵌入

FastText 是 Facebook AI Research 提出的 改进版 Word2Vec ,可以: ✅ 利用 n-grams 处理未登录词
比 Word2Vec 更快、更准确
适用于中文等形态丰富的语言

完整的 PyTorch FastText 代码(基于中文语料),包含:

  • 数据预处理(分词 + n-grams)

  • 模型定义

  • 训练

  • 测试

    import torch
    import torch.nn as nn
    import torch.optim as optim
    import numpy as np
    import jieba
    from collections import Counter
    import random

    ========== 1. 数据预处理 ==========

    corpus = [
    "我们 喜欢 深度 学习",
    "自然 语言 处理 是 有趣 的",
    "人工智能 改变 了 世界",
    "深度 学习 是 人工智能 的 重要 组成部分"
    ]

    分词

    tokenized_corpus = [list(jieba.cut(sentence)) for sentence in corpus]

    构建 n-grams

    def generate_ngrams(words, n=3):
    ngrams = []
    for word in words:
    ngrams += [word[i:i + n] for i in range(len(word) - n + 1)]
    return ngrams

    生成 n-grams 词表

    all_ngrams = set()
    for sentence in tokenized_corpus:
    for word in sentence:
    all_ngrams.update(generate_ngrams(word))

    构建词汇表

    vocab = set(word for sentence in tokenized_corpus for word in sentence) | all_ngrams
    word2idx = {word: idx for idx, word in enumerate(vocab)}
    idx2word = {idx: word for word, idx in word2idx.items()}

    构建训练数据(CBOW 方式)

    window_size = 2
    data = []

    for sentence in tokenized_corpus:
    indices = [word2idx[word] for word in sentence]
    for center_idx in range(len(indices)):
    context = []
    for offset in range(-window_size, window_size + 1):
    context_idx = center_idx + offset
    if 0 <= context_idx < len(indices) and context_idx != center_idx:
    context.append(indices[context_idx])
    if context:
    data.append((context, indices[center_idx])) # (上下文, 目标词)

    ========== 2. 定义 FastText 模型 ==========

    class FastText(nn.Module):
    def init(self, vocab_size, embedding_dim):
    super(FastText, self).init()
    self.embeddings = nn.Embedding(vocab_size, embedding_dim)
    self.linear = nn.Linear(embedding_dim, vocab_size)

    复制代码
      def forward(self, context):
          context_vec = self.embeddings(context).mean(dim=1)  # 平均上下文向量
          output = self.linear(context_vec)
          return output

    初始化模型

    embedding_dim = 10
    model = FastText(len(vocab), embedding_dim)

    ========== 3. 训练 FastText ==========

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=0.01)
    num_epochs = 100

    for epoch in range(num_epochs):
    total_loss = 0
    random.shuffle(data)

    复制代码
      for context, target in data:
          context = torch.tensor([context], dtype=torch.long)
          target = torch.tensor([target], dtype=torch.long)
    
          optimizer.zero_grad()
          output = model(context)
          loss = criterion(output, target)
          loss.backward()
          optimizer.step()
    
          total_loss += loss.item()
    
      if (epoch + 1) % 10 == 0:
          print(f"Epoch [{epoch + 1}/{num_epochs}], Loss: {total_loss:.4f}")

    ========== 4. 获取词向量 ==========

    word_vectors = model.embeddings.weight.data.numpy()

    ========== 5. 计算相似度 ==========

    def most_similar(word, top_n=3):
    if word not in word2idx:
    return "单词不在词汇表中"

    复制代码
      word_vec = word_vectors[word2idx[word]].reshape(1, -1)
      similarities = np.dot(word_vectors, word_vec.T).squeeze()
      similar_idx = similarities.argsort()[::-1][1:top_n + 1]
      return [(idx2word[idx], similarities[idx]) for idx in similar_idx]

    测试

    test_words = ["深度", "学习", "人工智能"]
    for word in test_words:
    print(f"【{word}】的相似单词:", most_similar(word))

1. 生成 n-grams

  • FastText 处理单词的 子词单元(n-grams)
  • 例如 "学习" 会生成 ["学习", "习学", "学"]
  • 这样即使遇到未登录词也能拆分为 n-grams 计算

2. 训练数据

  • 使用 CBOW(上下文预测中心词)

  • 窗口大小 = 2 ,即:

    复制代码
    句子: ["深度", "学习", "是", "人工智能"]
    示例: (["深度", "是"], "学习")

3. FastText 模型

  • 词向量是 n-grams 词向量的平均值
  • 计算公式:
  • 这样,即使单词没见过,也能用它的 n-grams 计算词向量!

4. 计算相似度

  • cosine similarity 找出最相似的单词
  • FastText 比 Word2Vec 更准确,因为它能利用 n-grams 捕捉词的语义信息
特性 FastText Word2Vec GloVe
原理 预测中心词 + n-grams 预测中心词或上下文 统计词共现信息
未登录词处理 可处理 无法处理 无法处理
训练速度
适合领域 中文、罕见词 传统 NLP 大规模数据
相关推荐
jiushiapwojdap几秒前
Antigravity Awesome Skills:1527+ AI 编程助手的可安装技能库
人工智能·其他
顾北顾2 分钟前
多头注意力机制
人工智能·深度学习·算法
hujinyuan2016017 分钟前
2025年12月中国电子学会青少年机器人技术等级考试试卷(二级) 真题+答案
人工智能·算法·机器人
码农小白AI23 分钟前
采购合同与来料证书对标校验,IACheck联动AI报告审核通审Agent版自动识别指标不符单据
人工智能
CTA量化套保1 小时前
期货量化程序 time.sleep 卡死:天勤单线程与 deadline 替代
python·区块链
元岳数字人小元1 小时前
AI 数字人开发公司浅谈 虚拟数字人打造景区新服务
人工智能·人机交互·交互
哦哦~9211 小时前
AI赋能生物医学:从临床数据到药物分子性质预测实战培
人工智能·生物医学·药物分子
GIS数据转换器1 小时前
城市排水生命线安全运行监测平台深度解析
java·运维·人工智能·python·安全·数据挖掘·无人机
虫无涯1 小时前
本地离线大模型实战:Ollama + Llama 3.1 8B 全流程部署(适配VSCode Continue代码助手)
人工智能
Rocky Ding*1 小时前
Latent Consistency Models:一篇读懂扩散模型的少步生成核心基础知识
人工智能·深度学习·机器学习·ai作画·stable diffusion·aigc·ai-native