婴儿版GPT

##提供一个完整的Transformer架构推理程序,字典随便建的,嵌入向量矩阵、输出矩阵等都是随机产生的,未经过训练,缩小版的GPT,种在掌握注意力机制运算过程,其输入结果也是经常变的,因为没有训练。

import numpy as np

--------------------------

步骤1:构建词典(把字变成ID)

--------------------------

vocab = {"我":0, "喜":1, "欢":2, "中":3, "国":4, "美":5, "食":6, "[END]":7}

vocab_size = len(vocab) # 词典大小

d_model = 8 # 每个token用8维向量(小尺寸方便演示)

输入句子:我喜欢中国 → 转成ID

input_text = "我喜欢中国"

input_ids = [vocab[c] for c in input_text]

seq_len = len(input_ids) # 自动匹配实际token长度,取消硬编码

print("输入文字:", input_text)

print("输入token ID:", input_ids)

print("实际序列长度:", seq_len)

--------------------------

步骤2:词嵌入(文字 → 向量)

--------------------------

随机初始化嵌入层

embedding = np.random.randn(vocab_size, d_model)

x = embedding[input_ids] # 形状:(5, 8)

print("\n输入变成向量形状:", x.shape)

--------------------------

步骤3:Transformer 核心 = 标准自注意力机制

--------------------------

定义稳定的softmax函数(防止指数爆炸)

def softmax(x):

exp_x = np.exp(x - np.max(x, axis=-1, keepdims=True)) # 减去最大值保证数值稳定

return exp_x / np.sum(exp_x, axis=-1, keepdims=True)

初始化 QKV 权重矩阵

Wq = np.random.randn(d_model, d_model)

Wk = np.random.randn(d_model, d_model)

Wv = np.random.randn(d_model, d_model)

计算 Q K V

Q = x @ Wq

K = x @ Wk

V = x @ Wv

计算标准注意力分数 + softmax归一化

scores = Q @ K.T / np.sqrt(d_model) # 形状:(5,5)

attention_weights = softmax(scores) # 注意力权重(归一化到0-1,和为1)

attention = attention_weights @ V # 标准自注意力输出,形状:(5,8)

--------------------------

步骤4:输出层 → 预测下一个词

--------------------------

output_layer = np.random.randn(d_model, vocab_size)

logits = attention[-1] @ output_layer # 取最后一个token预测下一个词

pred_id = np.argmax(logits)

把预测ID转回文字

idx2word = {v:k for k,v in vocab.items()}

pred_word = idx2word[pred_id]

--------------------------

最终输出

--------------------------

print("\n" + "="*50)

print(f"输入:{input_text}")

print(f"Transformer 预测下一个词:【 {pred_word} 】")

print(f"完整句子推测:{input_text} → {pred_word}")

输入文字: 我喜欢中国

输入token ID: [0, 1, 2, 3, 4]

实际序列长度: 5

输入变成向量形状: (5, 8)

==================================================

输入:我喜欢中国

Transformer 预测下一个词:【 食 】

完整句子推测:我喜欢中国 → 食

相关推荐
无边风月-风之羽翼2 小时前
omnilingual_asr在Nvidia Spark DGX中部署
python
蓝天守卫者联盟12 小时前
烧结机一氧化碳治理厂家技术路线与市场格局分析
大数据·人工智能·python
Ulyanov2 小时前
雷达信号处理核心算法与仿真实现
python·目标跟踪·信号处理·系统仿真·雷达电子对抗
用户0332126663672 小时前
使用 Python 压缩 PDF 文件的大小
python
姜太小白2 小时前
【Linux】CentOS 7 VNC 远程桌面配置
linux·python·centos
Ai.den2 小时前
Windows 安装 DeerFlow 2.0
人工智能·windows·python·ai
weixin_433179332 小时前
python - 存储数据
python
何伯特2 小时前
手撕Transformer:一个完整的机器翻译实例详解
深度学习·transformer·机器翻译
阿坤带你走近大数据2 小时前
数据API接口的数据源和目标源分别是什么?怎么设置?
java·python·api