一、核心参数(固定配置)
-
L=4:序列长度(4个Token)
-
D=512:单个Token总特征维度
-
h=8:多头注意力的头数
-
d_k=64:单个头的特征维度(D/h)
-
scale=1/8:缩放因子(1/√d_k,防止数值溢出)
-
INF_NEG=-1e9:替代负无穷(浮点型无法直接表示)
二、大致流程

#include <Eigen/Dense>
#include <iostream>
#include <cstdlib>
#include <ctime>
#include <cmath>
// 核心参数定义
const int L = 4; // 序列长度
const int D = 512; // 总特征维度
const int h = 8; // 多头注意力头数
const int d_k = D / h; // 单头特征维度
const float scale = 1.0f / sqrt(static_cast<float>(d_k)); // 缩放因子
const float INF_NEG = -1e9; // 替代负无穷
// 按行Softmax归一化(行和为1,值非负)
Eigen::MatrixXf row_softmax(const Eigen::MatrixXf& mat) {
Eigen::MatrixXf result(mat.rows(), mat.cols());
for (int i = 0; i < mat.rows(); ++i) {
Eigen::VectorXf row = mat.row(i);
float max_val = row.maxCoeff(); // 防止指数溢出
Eigen::VectorXf exp_row = (row.array() - max_val).exp();
float sum_exp = exp_row.sum();
result.row(i) = exp_row / sum_exp; // 行归一化
}
return result;
}
// 生成因果掩码矩阵(下三角/对角线为0,上三角为-1e9)
Eigen::MatrixXf create_causal_mask(int seq_len) {
Eigen::MatrixXf mask = Eigen::MatrixXf::Zero(seq_len, seq_len);
for (int i = 0; i < seq_len; ++i) {
for (int j = 0; j < seq_len; ++j) {
if (j > i) mask(i, j) = INF_NEG; // 屏蔽未来Token
}
}
return mask;
}
// 带因果掩码的Scaled Dot-Product Attention
// 输入:Q/K/V(4×64)、mask(4×4) 输出:注意力加权结果(4×64)
Eigen::MatrixXf scaled_dot_product_attention(const Eigen::MatrixXf& Q, const Eigen::MatrixXf& K, const Eigen::MatrixXf& V, const Eigen::MatrixXf& mask) {
Eigen::MatrixXf qk_t = Q * K.transpose(); // QK^T计算
qk_t *= scale; // 缩放
qk_t += mask; // 集成因果掩码
Eigen::MatrixXf attn_weights = row_softmax(qk_t); // 注意力权重
return attn_weights * V; // 加权求和
}
// 多头自注意力实现(集成因果掩码)
// 输入:Q/K/V(4×512)、W_o(512×512) 输出:注意力结果(4×512),输出所有头权重
Eigen::MatrixXf multi_head_attention(const Eigen::MatrixXf& Q, const Eigen::MatrixXf& K, const Eigen::MatrixXf& V,
const Eigen::MatrixXf& W_o, Eigen::MatrixXf& attn_weights_all) {
int batch = L;
std::vector<Eigen::MatrixXf> head_outputs(h);
std::vector<Eigen::MatrixXf> head_weights(h);
Eigen::MatrixXf causal_mask = create_causal_mask(batch); // 生成掩码
// 拆分Q/K/V到各头并计算注意力
for (int i = 0; i < h; ++i) {
Eigen::MatrixXf Q_head = Q.block(0, i * d_k, batch, d_k);
Eigen::MatrixXf K_head = K.block(0, i * d_k, batch, d_k);
Eigen::MatrixXf V_head = V.block(0, i * d_k, batch, d_k);
Eigen::MatrixXf qk_t = Q_head * K_head.transpose() * scale;
qk_t += causal_mask; // 头级别集成掩码
head_weights[i] = row_softmax(qk_t);
head_outputs[i] = head_weights[i] * V_head;
}
// 拼接所有头输出
Eigen::MatrixXf concat_output(batch, D);
for (int i = 0; i < h; ++i) {
concat_output.block(0, i * d_k, batch, d_k) = head_outputs[i];
}
Eigen::MatrixXf final_output = concat_output * W_o; // 线性变换
// 拼接所有头权重(验证用)
attn_weights_all = Eigen::MatrixXf::Zero(batch * h, batch);
for (int i = 0; i < h; ++i) {
attn_weights_all.block(i * batch, 0, batch, batch) = head_weights[i];
}
return final_output;
}
// 验证注意力权重合法性(非负、行和为1)
void check_attn_weights(const Eigen::MatrixXf& weights) {
std::cout << "=== 注意力权重基础验证 ===" << std::endl;
bool all_non_negative = true;
for (int i = 0; i < weights.rows() && all_non_negative; ++i) {
for (int j = 0; j < weights.cols(); ++j) {
if (weights(i, j) < -1e-6) { all_non_negative = false; break; }
}
}
std::cout << "权重是否全非负:" << (all_non_negative ? "是" : "否") << std::endl;
bool row_sum_1 = true;
for (int i = 0; i < weights.rows() && row_sum_1; ++i) {
float sum = weights.row(i).sum();
if (fabs(sum - 1.0f) > 1e-4) { row_sum_1 = false; std::cout << "第" << i << "行和:" << sum << std::endl; }
}
std::cout << "权重是否行和为1:" << (row_sum_1 ? "是" : "否") << std::endl;
}
// 验证因果掩码效果(上三角权重是否接近0)
void check_causal_mask_effect(const Eigen::MatrixXf& weights, int seq_len) {
std::cout << "\n=== 因果掩码效果验证 ===" << std::endl;
bool upper_tri_zero = true;
const float eps = 1e-4;
for (int head = 0; head < h && upper_tri_zero; ++head) {
int row_start = head * seq_len;
for (int i = 0; i < seq_len && upper_tri_zero; ++i) {
for (int j = 0; j < seq_len; ++j) {
if (j > i && fabs(weights(row_start + i, j)) > eps) {
upper_tri_zero = false;
std::cout << "第" << head << "个头(" << i << "," << j << ")权重:" << weights(row_start + i, j) << std::endl;
}
}
}
}
std::cout << "注意力权重上三角是否接近0:" << (upper_tri_zero ? "是" : "否") << std::endl;
}
int main() {
std::srand(static_cast<unsigned int>(std::time(nullptr)));
Eigen::MatrixXf::Random(1, 1);
// 生成随机输入
Eigen::MatrixXf Q = Eigen::MatrixXf::Random(L, D);
Eigen::MatrixXf K = Eigen::MatrixXf::Random(L, D);
Eigen::MatrixXf V = Eigen::MatrixXf::Random(L, D);
Eigen::MatrixXf W_o = Eigen::MatrixXf::Random(D, D);
Eigen::MatrixXf attn_weights_all;
Eigen::MatrixXf mha_output = multi_head_attention(Q, K, V, W_o, attn_weights_all); // 执行多头注意力
// 输出验证结果
std::cout << "=== 输出维度验证 ===" << std::endl;
std::cout << "多头注意力输出维度:" << mha_output.rows() << "×" << mha_output.cols() << std::endl;
std::cout << "期望维度:4×512" << std::endl;
check_attn_weights(attn_weights_all);
check_causal_mask_effect(attn_weights_all, L);
std::cout << "\n=== 第一个头的注意力权重矩阵 ===" << std::endl;
std::cout << attn_weights_all.block(0, 0, L, L) << std::endl;
return 0;
}
-
必须在row_softmax之前,加到缩放后的Q×K^T上
-
通过INF_NEG使未来Token经softmax后权重趋近于0,实现因果约束
-
所有矩阵运算需保证维度一致(如Q×K^T为4×4,与掩码维度匹配)