1. Introduction: Why Do We Need Word Embeddings?
In NLP, a computer cannot understand raw text directly; the text must first be turned into numeric vectors.
Traditional approaches: one-hot encoding and the bag-of-words model (CountVectorizer)
One-Hot encoding
- Each word is mapped to a vector containing a single 1 with 0s everywhere else.
- Drawbacks (see the sketch after this list):
  - Extremely sparse
  - Curse of dimensionality
  - Cannot express semantic similarity
  - Cannot capture context
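A minimal sketch of these drawbacks, assuming a tiny five-word toy vocabulary: one-hot vectors are as long as the vocabulary, and the dot product between any two different words is always 0, so no similarity can be read off them.
python
import numpy as np

vocab = ["cat", "dog", "car", "eat", "run"]                 # toy vocabulary (assumed)
one_hot = {w: np.eye(len(vocab))[i] for i, w in enumerate(vocab)}

print(one_hot["cat"])                    # [1. 0. 0. 0. 0.] -> sparse, length = |V|
print(one_hot["cat"] @ one_hot["dog"])   # 0.0 -> "cat" looks no closer to "dog"...
print(one_hot["cat"] @ one_hot["car"])   # 0.0 -> ...than to "car": no semantics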
Bag-of-words model
- CountVectorizer: counts word frequencies
- ngram_range=(1, 2): uses single words and two-word combinations as features
- Drawbacks (see the sketch after this list):
  - The parameter space explodes, making n-grams with N > 3 impractical
  - Ignores semantic relationships between words
  - Cannot express similarity
In short, both traditional methods suffer from the curse of dimensionality, sparse vectors, and a lack of semantic information.
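A minimal bag-of-words sketch using scikit-learn's CountVectorizer (the two-sentence corpus below is made up for illustration; scikit-learn >= 1.0 is assumed for get_feature_names_out):
python
from sklearn.feature_extraction.text import CountVectorizer

corpus = ["people create programs", "programs direct processes"]  # toy corpus (assumed)

# Count unigrams and bigrams, i.e. ngram_range=(1, 2) as mentioned above
vectorizer = CountVectorizer(ngram_range=(1, 2))
X = vectorizer.fit_transform(corpus)

print(vectorizer.get_feature_names_out())  # the feature space grows quickly with n-grams
print(X.toarray())                         # sparse count vectors; no notion of word similarity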
Modern approach: word embeddings (Word Embedding)
- Map a high-dimensional sparse representation to a low-dimensional dense vector (e.g., 4960 dimensions -> 300 dimensions)
- The dimensionality is usually chosen by hand, e.g., 300
- The vector entries are floating-point numbers, not 0/1
- Can express semantic similarity
- Avoids the curse of dimensionality (see the sketch after this list)
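A minimal sketch of what an embedding layer returns, using PyTorch's nn.Embedding; the vocabulary size (5000) and dimensionality (300) below are illustrative assumptions:
python
import torch
import torch.nn as nn

torch.manual_seed(0)
embedding = nn.Embedding(num_embeddings=5000, embedding_dim=300)  # |V| = 5000, 300-dim (assumed)

word_idx = torch.tensor([42])   # index of some word in the vocabulary
vec = embedding(word_idx)       # dense float vector, shape (1, 300)
print(vec.shape, vec.dtype)     # torch.Size([1, 300]) torch.float32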
In 2013, Google proposed Word2Vec, which comes in two variants:
- CBOW (predict the center word from its context)
- Skip-Gram (predict the context from the center word)
This article implements CBOW.
2. How the CBOW Model Works
CBOW = Continuous Bag-of-Words
- Input: the context words surrounding a word
- Output: that center word
- After training, the weights of the Embedding layer are the word vectors
For example, with a window of 2, in the phrase "People create programs to direct processes" the context ['People', 'create', 'to', 'direct'] is used to predict the center word 'programs'.
Training procedure (a worked NumPy sketch follows below):
- Feed in the one-hot vectors of the context words
- Multiply by the matrix W (V×N) to obtain low-dimensional vectors
- Average them into a single 1×N vector
- Multiply by the matrix W' (N×V)
- Apply Softmax to output a probability for every word in the vocabulary
- Backpropagate to update W and W'
- What we ultimately keep is the matrix W (the word vectors)
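A minimal NumPy sketch of one CBOW forward pass following the steps above; the vocabulary size V = 6, dimension N = 3, and context indices are assumed toy values:
python
import numpy as np

np.random.seed(0)
V, N = 6, 3                      # toy vocabulary size and embedding dimension (assumed)
W = np.random.randn(V, N)        # input matrix  (V x N); its rows become the word vectors
W_prime = np.random.randn(N, V)  # output matrix (N x V)

context_ids = [0, 2, 4, 5]                           # indices of the 4 context words
one_hot = np.eye(V)[context_ids]                     # (4, V) one-hot vectors
hidden = (one_hot @ W).mean(axis=0, keepdims=True)   # average -> (1, N)

scores = hidden @ W_prime                            # (1, V)
probs = np.exp(scores) / np.exp(scores).sum()        # softmax over the vocabulary
print(probs, probs.sum())                            # probability of each word being the center word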
3. Complete CBOW Code
python
# Import required libraries
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from tqdm import tqdm  # training progress bar
import numpy as np
# ===================== 1. Hyperparameters =====================
# Context window size: 2 words on each side of the center word (4 context words in total)
CONTEXT_SIZE = 2
# Training corpus (sample English text; it can be replaced with Chinese)
raw_text = """We are about to study the idea of a computational process.
Computational processes are abstract beings that inhabit computers.
As they evolve, processes manipulate other abstract things called data.
The evolution of a process is directed by a pattern of rules
called a program. People create programs to direct processes. In effect,
we conjure the spirits of the computer with our spells.""".split()
# ===================== 2. Build the vocabulary =====================
# Deduplicate to get the vocabulary (each word kept once)
vocab = set(raw_text)
vocab_size = len(vocab)  # total vocabulary size
# word -> index mapping (the model can only handle numbers)
word_to_idx = {word: i for i, word in enumerate(vocab)}
# index -> word mapping (used to decode predictions)
idx_to_word = {i: word for i, word in enumerate(vocab)}
# ===================== 3. Build the training dataset =====================
data = []  # each sample: context words as input, target (center) word as label
# Walk through the text and build a (context, target) sample for every center word
for i in range(CONTEXT_SIZE, len(raw_text) - CONTEXT_SIZE):
    # Take the 2 words to the left + 2 words to the right of the center word as context
    context = (
        [raw_text[i - (CONTEXT_SIZE - j)] for j in range(CONTEXT_SIZE)]
        + [raw_text[i + j + 1] for j in range(CONTEXT_SIZE)]
    )
    # The current word is the label
    target = raw_text[i]
    data.append((context, target))  # store the (context, target) pair
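# For example, with this corpus data[0] == (['We', 'are', 'to', 'study'], 'about'):
# the two words on each side of "about" are the context, and "about" is the target.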
# ===================== 4. Convert context words to an index tensor =====================
def make_context_vector(context, word_to_ix):
    # Map each word to its index
    idxs = [word_to_ix[w] for w in context]
    # Convert to a PyTorch long tensor
    return torch.tensor(idxs, dtype=torch.long)
# Quick check: print the context index vector of the first sample
print(make_context_vector(data[0][0], word_to_idx))
# ===================== 5. Device configuration (GPU/CPU) =====================
device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
print("Training device:", device)
# ===================== 6. Define the CBOW model =====================
class CBOW(nn.Module):
    def __init__(self, vocab_size, embedding_dim):
        super(CBOW, self).__init__()
        # Embedding layer: maps word indices to low-dimensional dense vectors
        self.embeddings = nn.Embedding(vocab_size, embedding_dim)
        # Hidden layer: feature projection
        self.proj = nn.Linear(embedding_dim, 128)
        # Output layer: maps to vocabulary size to predict the center word
        self.output = nn.Linear(128, vocab_size)

    def forward(self, inputs):
        # Sum the context word vectors (averaging also works)
        embeds = sum(self.embeddings(inputs)).view(1, -1)
        # Hidden layer with ReLU activation
        out = F.relu(self.proj(embeds))
        # Output layer
        out = self.output(out)
        # Log-Softmax, paired with the NLLLoss criterion
        nll_prob = F.log_softmax(out, dim=-1)
        return nll_prob
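# Shape walkthrough for one sample (embedding_dim = 10 is set below):
#   inputs: (4,) long tensor of context word indices
#   self.embeddings(inputs): (4, 10) -> sum + view(1, -1): (1, 10)
#   proj + ReLU: (1, 10) -> (1, 128); output: (1, 128) -> (1, vocab_size)
#   log_softmax: (1, vocab_size) log-probabilities over the vocabulary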
# ===================== 7. Initialize the model, optimizer, and loss =====================
# Embedding dimension set to 10 (could be 100/200/300 instead)
model = CBOW(vocab_size, 10).to(device)
# Adam optimizer
optimizer = optim.Adam(model.parameters(), lr=0.001)
# Loss: negative log-likelihood (with the log_softmax above, this is equivalent to cross-entropy)
loss_function = nn.NLLLoss()
losses = []  # record the total loss per epoch
# ===================== 8. Training =====================
model.train()
print("Start training...")
for epoch in tqdm(range(200)):  # train for 200 epochs
    total_loss = 0
    for context, target in data:
        # Move the data to the device
        context_vector = make_context_vector(context, word_to_idx).to(device)
        target = torch.tensor([word_to_idx[target]]).to(device)
        # Forward pass
        train_predict = model(context_vector)
        loss = loss_function(train_predict, target)
        # Backward pass + parameter update
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    losses.append(total_loss)
    print(f"Epoch {epoch+1}, Loss: {total_loss:.4f}")
# ===================== 9. Test the model =====================
# Given a context, predict the center word
context = ['People', 'create', 'to', 'direct']
context_vector = make_context_vector(context, word_to_idx).to(device)
model.eval()  # evaluation mode
with torch.no_grad():
    predict = model(context_vector)
    max_idx = predict.argmax(1).item()
print("\n" + "="*50)
print("Context:", context)
print("Predicted center word:", idx_to_word[max_idx])
print("="*50)
# ===================== 10. Extract and save the word vectors =====================
# Get the embedding weight matrix (one row per word)
W = model.embeddings.weight.cpu().detach().numpy()
# Build a word -> word-vector dictionary
word_2_vec = {}
for word in word_to_idx.keys():
    word_2_vec[word] = W[word_to_idx[word], :]
print('Training finished!')
np.savez('wordvec实现.npz', file_1=W)
data = np.load('wordvec实现.npz')  # note: this reuses the name `data` from the training set
print("Saved arrays:", data.files)
Run output:
bash
C:\Users\Dell\AppData\Local\Programs\Python\Python39\python.exe D:\software\Pycharm\自然语言处理\wordvec实现.py
0%| | 0/200 [00:00<?, ?it/s]tensor([ 6, 1, 25, 3])
cpu
Epoch1,Loss:237.8876
Epoch2,Loss:177.3150
2%|▏ | 3/200 [00:00<00:07, 27.76it/s]Epoch3,Loss:137.1353
Epoch4,Loss:104.0334
Epoch5,Loss:77.5424
Epoch6,Loss:57.3692
Epoch7,Loss:42.5198
4%|▎ | 7/200 [00:00<00:06, 29.57it/s]Epoch8,Loss:31.9138
Epoch9,Loss:24.2765
Epoch10,Loss:18.8579
Epoch11,Loss:14.8991
6%|▌ | 11/200 [00:00<00:06, 29.90it/s]Epoch12,Loss:11.9644
Epoch13,Loss:9.7586
7%|▋ | 14/200 [00:00<00:06, 29.40it/s]Epoch14,Loss:8.0582
Epoch15,Loss:6.7580
Epoch16,Loss:5.7170
8%|▊ | 17/200 [00:00<00:06, 28.80it/s]Epoch17,Loss:4.8913
Epoch18,Loss:4.2321
Epoch19,Loss:3.6852
Epoch20,Loss:3.2369
10%|█ | 20/200 [00:00<00:06, 28.31it/s]Epoch21,Loss:2.8590
Epoch22,Loss:2.5440
Epoch23,Loss:2.2742
12%|█▏ | 23/200 [00:00<00:06, 28.63it/s]Epoch24,Loss:2.0441
Epoch25,Loss:1.8446
Epoch26,Loss:1.6723
13%|█▎ | 26/200 [00:00<00:06, 28.52it/s]Epoch27,Loss:1.5210
Epoch28,Loss:1.3906
14%|█▍ | 29/200 [00:01<00:05, 28.59it/s]Epoch29,Loss:1.2725
Epoch30,Loss:1.1697
Epoch31,Loss:1.0768
Epoch32,Loss:0.9947
16%|█▌ | 32/200 [00:01<00:05, 28.42it/s]Epoch33,Loss:0.9213
Epoch34,Loss:0.8535
Epoch35,Loss:0.7944
18%|█▊ | 35/200 [00:01<00:05, 28.44it/s]Epoch36,Loss:0.7393
Epoch37,Loss:0.6896
Epoch38,Loss:0.6445
20%|█▉ | 39/200 [00:01<00:05, 29.27it/s]Epoch39,Loss:0.6032
Epoch40,Loss:0.5652
Epoch41,Loss:0.5306
Epoch42,Loss:0.4984
22%|██▏ | 43/200 [00:01<00:05, 30.10it/s]Epoch43,Loss:0.4691
Epoch44,Loss:0.4417
Epoch45,Loss:0.4165
Epoch46,Loss:0.3933
24%|██▎ | 47/200 [00:01<00:05, 30.39it/s]Epoch47,Loss:0.3715
Epoch48,Loss:0.3511
Epoch49,Loss:0.3325
Epoch50,Loss:0.3148
Epoch51,Loss:0.2985
26%|██▌ | 51/200 [00:01<00:04, 30.76it/s]Epoch52,Loss:0.2831
Epoch53,Loss:0.2688
Epoch54,Loss:0.2553
28%|██▊ | 55/200 [00:01<00:04, 30.40it/s]Epoch55,Loss:0.2427
Epoch56,Loss:0.2308
Epoch57,Loss:0.2197
Epoch58,Loss:0.2092
Epoch59,Loss:0.1993
30%|██▉ | 59/200 [00:01<00:04, 30.61it/s]Epoch60,Loss:0.1900
Epoch61,Loss:0.1811
Epoch62,Loss:0.1729
Epoch63,Loss:0.1650
32%|███▏ | 63/200 [00:02<00:04, 30.92it/s]Epoch64,Loss:0.1576
Epoch65,Loss:0.1506
Epoch66,Loss:0.1439
Epoch67,Loss:0.1376
34%|███▎ | 67/200 [00:02<00:04, 30.84it/s]Epoch68,Loss:0.1316
Epoch69,Loss:0.1259
Epoch70,Loss:0.1206
Epoch71,Loss:0.1154
36%|███▌ | 71/200 [00:02<00:04, 30.96it/s]Epoch72,Loss:0.1105
Epoch73,Loss:0.1059
Epoch74,Loss:0.1015
Epoch75,Loss:0.0973
38%|███▊ | 75/200 [00:02<00:04, 30.45it/s]Epoch76,Loss:0.0933
Epoch77,Loss:0.0895
Epoch78,Loss:0.0859
Epoch79,Loss:0.0824
40%|███▉ | 79/200 [00:02<00:03, 30.67it/s]Epoch80,Loss:0.0791
Epoch81,Loss:0.0760
Epoch82,Loss:0.0730
Epoch83,Loss:0.0701
42%|████▏ | 83/200 [00:02<00:03, 30.44it/s]Epoch84,Loss:0.0673
Epoch85,Loss:0.0647
Epoch86,Loss:0.0622
44%|████▎ | 87/200 [00:02<00:03, 30.43it/s]Epoch87,Loss:0.0598
Epoch88,Loss:0.0575
Epoch89,Loss:0.0553
Epoch90,Loss:0.0532
46%|████▌ | 91/200 [00:03<00:03, 31.04it/s]Epoch91,Loss:0.0512
Epoch92,Loss:0.0493
Epoch93,Loss:0.0474
Epoch94,Loss:0.0456
48%|████▊ | 95/200 [00:03<00:03, 31.21it/s]Epoch95,Loss:0.0439
Epoch96,Loss:0.0423
Epoch97,Loss:0.0407
Epoch98,Loss:0.0392
Epoch99,Loss:0.0378
50%|████▉ | 99/200 [00:03<00:03, 31.29it/s]Epoch100,Loss:0.0364
Epoch101,Loss:0.0351
Epoch102,Loss:0.0338
52%|█████▏ | 103/200 [00:03<00:03, 30.79it/s]Epoch103,Loss:0.0326
Epoch104,Loss:0.0314
Epoch105,Loss:0.0302
Epoch106,Loss:0.0292
Epoch107,Loss:0.0281
54%|█████▎ | 107/200 [00:03<00:03, 30.82it/s]Epoch108,Loss:0.0271
Epoch109,Loss:0.0261
Epoch110,Loss:0.0252
Epoch111,Loss:0.0243
56%|█████▌ | 111/200 [00:03<00:02, 31.28it/s]Epoch112,Loss:0.0235
Epoch113,Loss:0.0226
Epoch114,Loss:0.0218
57%|█████▊ | 115/200 [00:03<00:02, 31.21it/s]Epoch115,Loss:0.0210
Epoch116,Loss:0.0203
Epoch117,Loss:0.0196
Epoch118,Loss:0.0189
Epoch119,Loss:0.0183
60%|█████▉ | 119/200 [00:03<00:02, 31.35it/s]Epoch120,Loss:0.0176
Epoch121,Loss:0.0170
Epoch122,Loss:0.0164
62%|██████▏ | 123/200 [00:04<00:02, 31.33it/s]Epoch123,Loss:0.0159
Epoch124,Loss:0.0153
Epoch125,Loss:0.0148
Epoch126,Loss:0.0143
64%|██████▎ | 127/200 [00:04<00:02, 31.32it/s]Epoch127,Loss:0.0138
Epoch128,Loss:0.0133
Epoch129,Loss:0.0128
Epoch130,Loss:0.0124
66%|██████▌ | 131/200 [00:04<00:02, 31.41it/s]Epoch131,Loss:0.0120
Epoch132,Loss:0.0116
Epoch133,Loss:0.0112
Epoch134,Loss:0.0108
68%|██████▊ | 135/200 [00:04<00:02, 30.69it/s]Epoch135,Loss:0.0104
Epoch136,Loss:0.0101
Epoch137,Loss:0.0097
Epoch138,Loss:0.0094
70%|██████▉ | 139/200 [00:04<00:01, 31.18it/s]Epoch139,Loss:0.0091
Epoch140,Loss:0.0088
Epoch141,Loss:0.0085
Epoch142,Loss:0.0082
Epoch143,Loss:0.0079
72%|███████▏ | 143/200 [00:04<00:01, 31.13it/s]Epoch144,Loss:0.0077
Epoch145,Loss:0.0074
Epoch146,Loss:0.0072
Epoch147,Loss:0.0069
74%|███████▎ | 147/200 [00:04<00:01, 31.31it/s]Epoch148,Loss:0.0067
Epoch149,Loss:0.0065
Epoch150,Loss:0.0062
76%|███████▌ | 151/200 [00:04<00:01, 31.24it/s]Epoch151,Loss:0.0060
Epoch152,Loss:0.0058
Epoch153,Loss:0.0056
Epoch154,Loss:0.0055
Epoch155,Loss:0.0053
78%|███████▊ | 155/200 [00:05<00:01, 30.78it/s]Epoch156,Loss:0.0051
Epoch157,Loss:0.0049
Epoch158,Loss:0.0048
80%|███████▉ | 159/200 [00:05<00:01, 31.31it/s]Epoch159,Loss:0.0046
Epoch160,Loss:0.0045
Epoch161,Loss:0.0043
Epoch162,Loss:0.0042
82%|████████▏ | 163/200 [00:05<00:01, 31.01it/s]Epoch163,Loss:0.0040
Epoch164,Loss:0.0039
Epoch165,Loss:0.0038
Epoch166,Loss:0.0036
84%|████████▎ | 167/200 [00:05<00:01, 30.89it/s]Epoch167,Loss:0.0035
Epoch168,Loss:0.0034
Epoch169,Loss:0.0033
Epoch170,Loss:0.0032
Epoch171,Loss:0.0031
86%|████████▌ | 171/200 [00:05<00:00, 30.44it/s]Epoch172,Loss:0.0030
Epoch173,Loss:0.0029
Epoch174,Loss:0.0028
88%|████████▊ | 175/200 [00:05<00:00, 30.64it/s]Epoch175,Loss:0.0027
Epoch176,Loss:0.0026
Epoch177,Loss:0.0025
Epoch178,Loss:0.0025
Epoch179,Loss:0.0024
90%|████████▉ | 179/200 [00:05<00:00, 31.14it/s]Epoch180,Loss:0.0023
Epoch181,Loss:0.0022
Epoch182,Loss:0.0021
92%|█████████▏| 183/200 [00:05<00:00, 31.02it/s]Epoch183,Loss:0.0021
Epoch184,Loss:0.0020
Epoch185,Loss:0.0019
Epoch186,Loss:0.0019
Epoch187,Loss:0.0018
94%|█████████▎| 187/200 [00:06<00:00, 31.08it/s]Epoch188,Loss:0.0018
Epoch189,Loss:0.0017
Epoch190,Loss:0.0016
96%|█████████▌| 191/200 [00:06<00:00, 31.42it/s]Epoch191,Loss:0.0016
Epoch192,Loss:0.0015
Epoch193,Loss:0.0015
Epoch194,Loss:0.0014
Epoch195,Loss:0.0014
98%|█████████▊| 195/200 [00:06<00:00, 31.08it/s]Epoch196,Loss:0.0014
Epoch197,Loss:0.0013
Epoch198,Loss:0.0013
100%|█████████▉| 199/200 [00:06<00:00, 31.29it/s]Epoch199,Loss:0.0012
100%|██████████| 200/200 [00:06<00:00, 30.60it/s]
Epoch200,Loss:0.0012
==================================================
['People', 'create', 'to', 'direct']
programs
==================================================
jiesu
['file_1']
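As a follow-up, a minimal sketch of how the saved embeddings might be used, e.g. comparing two words by cosine similarity. It assumes the word_to_idx mapping from the training script is still available in the session; 'programs' and 'processes' are words from the corpus above.
python
import numpy as np

# Load the matrix saved above ('file_1' is the key passed to np.savez)
W = np.load('wordvec实现.npz')['file_1']

def cosine(u, v):
    # Cosine similarity between two word vectors
    return float(u @ v / (np.linalg.norm(u) * np.linalg.norm(v)))

v1 = W[word_to_idx['programs']]    # word_to_idx comes from the training script (assumption)
v2 = W[word_to_idx['processes']]
print(cosine(v1, v2))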