Implementing the Word2Vec Skip-gram Model with PyTorch

The code first builds a custom dataset with a Word2VecDataset class, which generates the (center word, context word) training pairs. It then defines the Skip-gram model and trains it with a cross-entropy loss and the Adam optimizer.

In each training epoch, the data loader is iterated and, for every batch, the code runs a forward pass, computes the loss, backpropagates, and updates the weights. After training, the learned embedding matrix is extracted, and word_vector can be used to look up the vector of any word in the vocabulary.

Make sure PyTorch is installed before running the code; it can be installed with pip install torch. The code selects a device with torch.device('cuda' if torch.cuda.is_available() else 'cpu'), so it runs on a GPU when one is available and automatically falls back to the CPU otherwise. Note that the model and every batch must be moved to the same device, or PyTorch will raise a device-mismatch error.
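For reference, the device handling used below boils down to the following pattern (a minimal, self-contained sketch; the Embedding layer and the toy batch are placeholders, not part of the actual code):

import torch
import torch.nn as nn

# Select the GPU when available, otherwise fall back to the CPU
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# The module and every batch it processes must live on the same device
layer = nn.Embedding(10, 4).to(device)        # stand-in for the model defined below
batch = torch.tensor([1, 2, 3]).to(device)    # stand-in for a batch of word indices
print(layer(batch).shape)                     # torch.Size([3, 4])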

The full example code for the Skip-gram model in PyTorch is shown below:

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader

# Hyperparameters
embedding_dim = 100
window_size = 2
learning_rate = 0.001
epochs = 100
batch_size = 32

# Example corpus
corpus = [['I', 'enjoy', 'playing', 'football', 'with', 'my', 'friends'],
          ['We', 'like', 'to', 'play', 'tennis', 'on', 'weekends'],
          ['She', 'is', 'a', 'good', 'dancer']]

# Create vocabulary
vocab = list(set([word for sentence in corpus for word in sentence]))
vocab_size = len(vocab)
word2idx = {word: idx for idx, word in enumerate(vocab)}
idx2word = {idx: word for idx, word in enumerate(vocab)}

# Generate training data
class Word2VecDataset(Dataset):
    def __init__(self, corpus, word2idx):
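        # Pair each center word with every word within window_size positions on either side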
        self.data = []
        for sentence in corpus:
            word_indices = [word2idx[word] for word in sentence]
            for center_pos, center_word in enumerate(word_indices):
                # Clip the context window to the sentence boundaries
                window_start = max(0, center_pos - window_size)
                window_end = min(center_pos + window_size + 1, len(word_indices))
                for context_pos in range(window_start, window_end):
                    if context_pos != center_pos:
                        self.data.append((center_word, word_indices[context_pos]))
    
    def __len__(self):
        return len(self.data)
    
    def __getitem__(self, idx):
        return self.data[idx]

dataset = Word2VecDataset(corpus, word2idx)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

# Define Skip-gram model
class SkipGramModel(nn.Module):
    def __init__(self, vocab_size, embedding_dim):
        super(SkipGramModel, self).__init__()
        self.embedding = nn.Embedding(vocab_size, embedding_dim)
        self.linear = nn.Linear(embedding_dim, vocab_size)
        
    def forward(self, center_word):
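        # Look up the center word's embedding and map it to unnormalized scores over the vocabulary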
        embedded = self.embedding(center_word)
        output = self.linear(embedded)
        return output

# Pick the GPU if available, otherwise the CPU, and move the model to it
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = SkipGramModel(vocab_size, embedding_dim).to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

# Training
for epoch in range(epochs):
    running_loss = 0.0
    for i, (center_word, context_word) in enumerate(dataloader):
        optimizer.zero_grad()
        
        # Move the batch to the same device as the model
        center_word = center_word.to(device)
        context_word = context_word.to(device)
        
        output = model(center_word)
        loss = criterion(output, context_word)
        loss.backward()
        optimizer.step()
        
        running_loss += loss.item()
        
    average_loss = running_loss / len(dataloader)
    print(f'Epoch {epoch+1}/{epochs}, Loss: {average_loss:.4f}')

# Get trained word embeddings
trained_embeddings = model.embedding.weight.data.cpu().numpy()  # copy to CPU before converting to NumPy

# Example usage - Getting word vector for a word
word = 'football'
word_vector = trained_embeddings[word2idx[word]]
print(f"Word vector for '{word}': {word_vector}")

Running the code produces output like the following:

Epoch 1/100, Loss: 3.1324
Epoch 2/100, Loss: 3.0791
Epoch 3/100, Loss: 2.9902
Epoch 4/100, Loss: 2.9392
Epoch 5/100, Loss: 2.8870
Epoch 6/100, Loss: 2.8166
Epoch 7/100, Loss: 2.7615
Epoch 8/100, Loss: 2.7017
Epoch 9/100, Loss: 2.6500
Epoch 10/100, Loss: 2.5993
Epoch 11/100, Loss: 2.5496
Epoch 12/100, Loss: 2.5013
Epoch 13/100, Loss: 2.4621
Epoch 14/100, Loss: 2.4079
Epoch 15/100, Loss: 2.3660
Epoch 16/100, Loss: 2.3229
Epoch 17/100, Loss: 2.2795
Epoch 18/100, Loss: 2.2398
Epoch 19/100, Loss: 2.1998
Epoch 20/100, Loss: 2.1582
Epoch 21/100, Loss: 2.1278
Epoch 22/100, Loss: 2.1023
Epoch 23/100, Loss: 2.0569
Epoch 24/100, Loss: 2.0245
Epoch 25/100, Loss: 1.9936
Epoch 26/100, Loss: 1.9639
Epoch 27/100, Loss: 1.9344
Epoch 28/100, Loss: 1.9137
Epoch 29/100, Loss: 1.8888
Epoch 30/100, Loss: 1.8586
Epoch 31/100, Loss: 1.8352
Epoch 32/100, Loss: 1.8200
Epoch 33/100, Loss: 1.7815
Epoch 34/100, Loss: 1.7685
Epoch 35/100, Loss: 1.7531
Epoch 36/100, Loss: 1.7209
Epoch 37/100, Loss: 1.7049
Epoch 38/100, Loss: 1.6881
Epoch 39/100, Loss: 1.6775
Epoch 40/100, Loss: 1.6517
Epoch 41/100, Loss: 1.6390
Epoch 42/100, Loss: 1.6238
Epoch 43/100, Loss: 1.6077
Epoch 44/100, Loss: 1.5939
Epoch 45/100, Loss: 1.5745
Epoch 46/100, Loss: 1.5703
Epoch 47/100, Loss: 1.5574
Epoch 48/100, Loss: 1.5458
Epoch 49/100, Loss: 1.5308
Epoch 50/100, Loss: 1.5215
Epoch 51/100, Loss: 1.5122
Epoch 52/100, Loss: 1.4988
Epoch 53/100, Loss: 1.4958
Epoch 54/100, Loss: 1.4773
Epoch 55/100, Loss: 1.4746
Epoch 56/100, Loss: 1.4618
Epoch 57/100, Loss: 1.4560
Epoch 58/100, Loss: 1.4506
Epoch 59/100, Loss: 1.4380
Epoch 60/100, Loss: 1.4266
Epoch 61/100, Loss: 1.4257
Epoch 62/100, Loss: 1.4148
Epoch 63/100, Loss: 1.4090
Epoch 64/100, Loss: 1.4070
Epoch 65/100, Loss: 1.3940
Epoch 66/100, Loss: 1.3890
Epoch 67/100, Loss: 1.3846
Epoch 68/100, Loss: 1.3813
Epoch 69/100, Loss: 1.3738
Epoch 70/100, Loss: 1.3717
Epoch 71/100, Loss: 1.3681
Epoch 72/100, Loss: 1.3594
Epoch 73/100, Loss: 1.3593
Epoch 74/100, Loss: 1.3504
Epoch 75/100, Loss: 1.3447
Epoch 76/100, Loss: 1.3439
Epoch 77/100, Loss: 1.3397
Epoch 78/100, Loss: 1.3315
Epoch 79/100, Loss: 1.3260
Epoch 80/100, Loss: 1.3253
Epoch 81/100, Loss: 1.3229
Epoch 82/100, Loss: 1.3215
Epoch 83/100, Loss: 1.3148
Epoch 84/100, Loss: 1.3160
Epoch 85/100, Loss: 1.3072
Epoch 86/100, Loss: 1.3105
Epoch 87/100, Loss: 1.3104
Epoch 88/100, Loss: 1.3018
Epoch 89/100, Loss: 1.2912
Epoch 90/100, Loss: 1.2950
Epoch 91/100, Loss: 1.2938
Epoch 92/100, Loss: 1.2951
Epoch 93/100, Loss: 1.2859
Epoch 94/100, Loss: 1.2902
Epoch 95/100, Loss: 1.2840
Epoch 96/100, Loss: 1.2748
Epoch 97/100, Loss: 1.2840
Epoch 98/100, Loss: 1.2763
Epoch 99/100, Loss: 1.2772
Epoch 100/100, Loss: 1.2746

Word vector for 'football':
[-1.2727762 0.8401019 -0.5115612 2.0667355 1.1854529 -0.7444803
 -1.9658612 -1.0488677 0.98938674 -1.1675086 1.582392 1.7414839
 -0.4892138 -1.2149098 0.15343344 -1.8318586 0.41794038 0.25481498
 0.6008032 -0.23904797 0.80143225 -1.0495795 -1.0174142 -0.01827855
 2.7477944 -0.9574399 1.025569 2.4843202 -0.2796719 -0.4390253
 -1.4423424 -1.8073392 0.1897556 0.90259725 2.7565296 -0.28331178
 -1.8443514 0.77545553 -1.0289538 0.71483964 1.1801128 -0.22635305
 0.5960759 0.6690206 -1.9100318 1.2388043 -0.68522704 0.92120373
 1.0252377 -1.4376261 -0.6595934 0.31699112 0.6751458 0.99656415
 0.40565705 -1.0904227 -0.3513346 -0.66078615 1.1834346 -1.0899751
 -1.4925232 -0.30818892 1.4249563 0.06006899 -3.2386255 0.96192694
 -1.1045157 0.5540482 -1.5388466 -0.8721646 1.1221852 1.6488599
 0.44869688 1.1519432 -1.4588032 -0.04230021 -0.33113605 1.1316347
 -0.7425484 -0.11400439 0.37237874 -0.34573358 0.4140474 -0.04413145
 0.6157635 -1.0094129 -1.2208599 -0.7154122 0.9412035 0.9452426
 -0.0973389 -0.23566085 0.34300375 -0.95858365 0.8764276 -0.5669889
 -1.933235 0.22371146 1.6641699 1.3258857 ]
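
As a quick extension, the trained embeddings can also be used to check which words end up closest to a given word. The snippet below is a minimal sketch of a cosine-similarity lookup; it reuses trained_embeddings, word2idx, and idx2word from the code above, and the helper name most_similar is only illustrative, not part of the original code. (On a three-sentence toy corpus the neighbours are not very meaningful; the point is the lookup mechanics.)

import numpy as np

def most_similar(query, top_k=3):
    # Cosine similarity between the query word's vector and every row of the embedding matrix
    query_vec = trained_embeddings[word2idx[query]]
    norms = np.linalg.norm(trained_embeddings, axis=1) * np.linalg.norm(query_vec)
    sims = trained_embeddings @ query_vec / norms
    # Exclude the query word itself, then return the top_k most similar words
    sims[word2idx[query]] = -np.inf
    best = np.argsort(-sims)[:top_k]
    return [(idx2word[i], float(sims[i])) for i in best]

print(most_similar('football'))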
