python实现skip-gram(跳词)示例

文章目录

什么是跳词?

一句话,就是用中心词,去预测它周围的词。它是 Word2Vec 里最常用的一种训练方式。

示例

1、安装依赖

python 复制代码
pip install matplotlib # 其他torch等依赖早就安装了

2、创建python文件skip_gram_demo.py,代码:

python 复制代码
import torch
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt
from collections import Counter

# ==========================================
# 1. 数据准备与预处理
# ==========================================
# 一个简单的微型语料库
corpus = """
    deep learning is powerful
    machine learning is a subset of artificial intelligence
    deep learning models are inspired by the brain
    natural language processing uses deep learning
"""

# 文本清洗与分词
words = corpus.lower().split()

# 构建词汇表 (Word -> Index)
vocab = list(set(words))
word_to_idx = {w: i for i, w in enumerate(vocab)}
idx_to_word = {i: w for i, w in enumerate(vocab)}
vocab_size = len(vocab)

print(f"词汇表大小: {vocab_size}")
print(f"词汇表: {vocab}")


# 生成训练数据 (Skip-gram: 输入中心词 -> 输出上下文词)
def create_dataloader(words, word_to_idx, window_size=2):
    inputs = []
    targets = []

    for i in range(1, len(words) - 1):
        center_word = words[i]
        center_idx = word_to_idx[center_word]

        # 获取上下文窗口
        # 比如 window_size=2,则取前后各2个词
        for j in range(i - window_size, i + window_size + 1):
            if j != i and 0 <= j < len(words):
                context_word = words[j]
                context_idx = word_to_idx[context_word]

                inputs.append(center_idx)
                targets.append(context_idx)

    return torch.tensor(inputs, dtype=torch.long), torch.tensor(targets, dtype=torch.long)


inputs, targets = create_dataloader(words, word_to_idx, window_size=2)


# ==========================================
# 2. 定义 Skip-gram 模型
# ==========================================
class SkipGramModel(nn.Module):
    def __init__(self, vocab_size, embedding_dim):
        super(SkipGramModel, self).__init__()
        # 中心词嵌入层 (W)
        self.w_in = nn.Embedding(vocab_size, embedding_dim)
        # 上下文词嵌入层 (W')
        self.w_out = nn.Embedding(vocab_size, embedding_dim)

        # 初始化权重
        nn.init.xavier_uniform_(self.w_in.weight)
        nn.init.xavier_uniform_(self.w_out.weight)

    def forward(self, x):
        # x: (batch_size,)
        # 获取中心词的向量
        embeds = self.w_in(x)  # (batch_size, embedding_dim)
        return embeds

    def loss(self, x, y):
        # x: 中心词索引, y: 上下文词索引
        # 1. 获取中心词向量
        v_center = self.w_in(x)  # (batch_size, dim)
        # 2. 获取上下文词向量
        v_context = self.w_out(y)  # (batch_size, dim)

        # 3. 计算点积 (相似度)
        # 这里的逻辑是:点积越大,概率越大
        score = torch.sum(torch.mul(v_center, v_context), dim=1)  # (batch_size,)

        # 4. 使用负对数似然损失 (简化版,未包含负采样)
        # 实际大规模训练中通常配合 Negative Sampling 使用
        # 这里为了演示简单,直接最大化目标词的概率
        loss = -torch.mean(score)
        return loss


# ==========================================
# 3. 训练模型
# ==========================================
embedding_dim = 10  # 词向量维度
learning_rate = 0.01
epochs = 1000

model = SkipGramModel(vocab_size, embedding_dim)
optimizer = optim.SGD(model.parameters(), lr=learning_rate)

print("\n开始训练...")
for epoch in range(epochs):
    optimizer.zero_grad()

    # 前向传播
    loss = model.loss(inputs, targets)

    # 反向传播
    loss.backward()
    optimizer.step()

    if (epoch + 1) % 200 == 0:
        print(f"Epoch {epoch + 1}, Loss: {loss.item():.4f}")

# ==========================================
# 4. 结果可视化与测试
# ==========================================
print("\n训练完成!查看词向量相似度...")

# 获取嵌入权重
embeddings = model.w_in.weight.data.numpy()


# 简单的余弦相似度计算
def cosine_similarity(w1, w2):
    return np.dot(w1, w2) / (np.linalg.norm(w1) * np.linalg.norm(w2))


# 测试几个词
test_words = ["learning", "deep", "artificial", "brain"]

import numpy as np

for w1 in test_words:
    if w1 in word_to_idx:
        vec1 = embeddings[word_to_idx[w1]]
        print(f"\n与 '{w1}' 最相似的词:")

        similarities = []
        for w2 in vocab:
            if w1 != w2:
                vec2 = embeddings[word_to_idx[w2]]
                sim = cosine_similarity(vec1, vec2)
                similarities.append((w2, sim))

        # 排序并打印前3个
        similarities.sort(key=lambda x: x[1], reverse=True)
        for word, score in similarities[:3]:
            print(f"    {word}: {score:.4f}")

# 2D 可视化 (PCA 降维)
from sklearn.decomposition import PCA

pca = PCA(n_components=2)
reduced_embeds = pca.fit_transform(embeddings)

plt.figure(figsize=(10, 8))
for i, word in enumerate(vocab):
    plt.scatter(reduced_embeds[i, 0], reduced_embeds[i, 1])
    plt.annotate(word, (reduced_embeds[i, 0], reduced_embeds[i, 1]))

plt.title("Word Embeddings Visualization (PCA)")
plt.xlabel("PC1")
plt.ylabel("PC2")
plt.grid(True)
plt.show()

输出结果:

python 复制代码
词汇表大小: 20
词汇表: ['artificial', 'inspired', 'brain', 'natural', 'is', 'are', 'learning', 'by', 'machine', 'powerful', 'processing', 'language', 'a', 'intelligence', 'uses', 'subset', 'deep', 'models', 'the', 'of']

开始训练...
Epoch 200, Loss: -0.0312
Epoch 400, Loss: -0.0661
Epoch 600, Loss: -0.1041
Epoch 800, Loss: -0.1467
Epoch 1000, Loss: -0.1957

训练完成!查看词向量相似度...

与 'learning' 最相似的词:
    inspired: 0.6657
    are: 0.4793
    is: 0.4745

与 'deep' 最相似的词:
    machine: 0.6026
    intelligence: 0.5229
    processing: 0.4629

与 'artificial' 最相似的词:
    is: 0.5218
    by: 0.5195
    the: 0.5013

与 'brain' 最相似的词:
    subset: 0.2076
    powerful: 0.1457
    language: 0.0755

解读:

给了一堆杂乱的文字,它居然将这些词分出了远近关系。

成功了。

相关推荐
maqr_11020 小时前
SQL如何快速提取分组中最晚时间点数据_结合窗口函数实现
jvm·数据库·python
Shorasul20 小时前
mysql如何限制特定表的最大存储空间_通过ALTER TABLE设置MAX_ROWS
jvm·数据库·python
214396520 小时前
如何存储MongoDB的爬虫抓取数据_动态字段与无模式宽容度.txt
jvm·数据库·python
riNt PTIP20 小时前
在21世纪的我用C语言探寻世界本质——字符函数和字符串函数(2)
c语言·开发语言
m0_7489203620 小时前
CSS如何实现网格内绝对定位_利用Grid的relative属性层级控制
jvm·数据库·python
qq_3422958220 小时前
golang如何优化磁盘IO性能_golang磁盘IO性能优化思路
jvm·数据库·python
weixin_4249993620 小时前
MySQL中如何使用CAST实现类型转换_MySQL数据类型转换技巧
jvm·数据库·python
2301_7775993720 小时前
SQL如何高效提取大表前几行:分页查询与OFFSET优化.txt
jvm·数据库·python
2301_8135995520 小时前
CSS如何实现纯CSS树状目录结构_利用-checked与递归思维构建交互节点
jvm·数据库·python
XS03010620 小时前
Java 基础(六)封装类 Object类
java·jvm·python