掩码(Mask)机制 结合 多头自注意力函数

一、核心参数(固定配置)

  • L=4:序列长度(4个Token)

  • D=512:单个Token总特征维度

  • h=8:多头注意力的头数

  • d_k=64:单个头的特征维度(D/h)

  • scale=1/8:缩放因子(1/√d_k,防止数值溢出)

  • INF_NEG=-1e9:替代负无穷(浮点型无法直接表示)

二、大致流程

复制代码
#include <Eigen/Dense>
#include <iostream>
#include <cstdlib>
#include <ctime>
#include <cmath>

// 核心参数定义
const int L = 4;    // 序列长度
const int D = 512;  // 总特征维度
const int h = 8;    // 多头注意力头数
const int d_k = D / h;  // 单头特征维度
const float scale = 1.0f / sqrt(static_cast<float>(d_k)); // 缩放因子
const float INF_NEG = -1e9; // 替代负无穷

// 按行Softmax归一化(行和为1,值非负)
Eigen::MatrixXf row_softmax(const Eigen::MatrixXf& mat) {
    Eigen::MatrixXf result(mat.rows(), mat.cols());
    for (int i = 0; i < mat.rows(); ++i) {
        Eigen::VectorXf row = mat.row(i);
        float max_val = row.maxCoeff(); // 防止指数溢出
        Eigen::VectorXf exp_row = (row.array() - max_val).exp();
        float sum_exp = exp_row.sum();
        result.row(i) = exp_row / sum_exp; // 行归一化
    }
    return result;
}

// 生成因果掩码矩阵(下三角/对角线为0,上三角为-1e9)
Eigen::MatrixXf create_causal_mask(int seq_len) {
    Eigen::MatrixXf mask = Eigen::MatrixXf::Zero(seq_len, seq_len);
    for (int i = 0; i < seq_len; ++i) {
        for (int j = 0; j < seq_len; ++j) {
            if (j > i) mask(i, j) = INF_NEG; // 屏蔽未来Token
        }
    }
    return mask;
}

// 带因果掩码的Scaled Dot-Product Attention
// 输入:Q/K/V(4×64)、mask(4×4)  输出:注意力加权结果(4×64)
Eigen::MatrixXf scaled_dot_product_attention(const Eigen::MatrixXf& Q, const Eigen::MatrixXf& K, const Eigen::MatrixXf& V, const Eigen::MatrixXf& mask) {
    Eigen::MatrixXf qk_t = Q * K.transpose(); // QK^T计算
    qk_t *= scale; // 缩放
    qk_t += mask; // 集成因果掩码
    Eigen::MatrixXf attn_weights = row_softmax(qk_t); // 注意力权重
    return attn_weights * V; // 加权求和
}

// 多头自注意力实现(集成因果掩码)
// 输入:Q/K/V(4×512)、W_o(512×512)  输出:注意力结果(4×512),输出所有头权重
Eigen::MatrixXf multi_head_attention(const Eigen::MatrixXf& Q, const Eigen::MatrixXf& K, const Eigen::MatrixXf& V,
    const Eigen::MatrixXf& W_o, Eigen::MatrixXf& attn_weights_all) {
    int batch = L;
    std::vector<Eigen::MatrixXf> head_outputs(h);
    std::vector<Eigen::MatrixXf> head_weights(h);

    Eigen::MatrixXf causal_mask = create_causal_mask(batch); // 生成掩码

    // 拆分Q/K/V到各头并计算注意力
    for (int i = 0; i < h; ++i) {
        Eigen::MatrixXf Q_head = Q.block(0, i * d_k, batch, d_k);
        Eigen::MatrixXf K_head = K.block(0, i * d_k, batch, d_k);
        Eigen::MatrixXf V_head = V.block(0, i * d_k, batch, d_k);

        Eigen::MatrixXf qk_t = Q_head * K_head.transpose() * scale;
        qk_t += causal_mask; // 头级别集成掩码
        head_weights[i] = row_softmax(qk_t);
        head_outputs[i] = head_weights[i] * V_head;
    }

    // 拼接所有头输出
    Eigen::MatrixXf concat_output(batch, D);
    for (int i = 0; i < h; ++i) {
        concat_output.block(0, i * d_k, batch, d_k) = head_outputs[i];
    }

    Eigen::MatrixXf final_output = concat_output * W_o; // 线性变换

    // 拼接所有头权重(验证用)
    attn_weights_all = Eigen::MatrixXf::Zero(batch * h, batch);
    for (int i = 0; i < h; ++i) {
        attn_weights_all.block(i * batch, 0, batch, batch) = head_weights[i];
    }

    return final_output;
}

// 验证注意力权重合法性(非负、行和为1)
void check_attn_weights(const Eigen::MatrixXf& weights) {
    std::cout << "=== 注意力权重基础验证 ===" << std::endl;
    bool all_non_negative = true;
    for (int i = 0; i < weights.rows() && all_non_negative; ++i) {
        for (int j = 0; j < weights.cols(); ++j) {
            if (weights(i, j) < -1e-6) { all_non_negative = false; break; }
        }
    }
    std::cout << "权重是否全非负:" << (all_non_negative ? "是" : "否") << std::endl;

    bool row_sum_1 = true;
    for (int i = 0; i < weights.rows() && row_sum_1; ++i) {
        float sum = weights.row(i).sum();
        if (fabs(sum - 1.0f) > 1e-4) { row_sum_1 = false; std::cout << "第" << i << "行和:" << sum << std::endl; }
    }
    std::cout << "权重是否行和为1:" << (row_sum_1 ? "是" : "否") << std::endl;
}

// 验证因果掩码效果(上三角权重是否接近0)
void check_causal_mask_effect(const Eigen::MatrixXf& weights, int seq_len) {
    std::cout << "\n=== 因果掩码效果验证 ===" << std::endl;
    bool upper_tri_zero = true;
    const float eps = 1e-4;
    for (int head = 0; head < h && upper_tri_zero; ++head) {
        int row_start = head * seq_len;
        for (int i = 0; i < seq_len && upper_tri_zero; ++i) {
            for (int j = 0; j < seq_len; ++j) {
                if (j > i && fabs(weights(row_start + i, j)) > eps) {
                    upper_tri_zero = false;
                    std::cout << "第" << head << "个头(" << i << "," << j << ")权重:" << weights(row_start + i, j) << std::endl;
                }
            }
        }
    }
    std::cout << "注意力权重上三角是否接近0:" << (upper_tri_zero ? "是" : "否") << std::endl;
}

int main() {
    std::srand(static_cast<unsigned int>(std::time(nullptr)));
    Eigen::MatrixXf::Random(1, 1);

    // 生成随机输入
    Eigen::MatrixXf Q = Eigen::MatrixXf::Random(L, D);
    Eigen::MatrixXf K = Eigen::MatrixXf::Random(L, D);
    Eigen::MatrixXf V = Eigen::MatrixXf::Random(L, D);
    Eigen::MatrixXf W_o = Eigen::MatrixXf::Random(D, D);

    Eigen::MatrixXf attn_weights_all;
    Eigen::MatrixXf mha_output = multi_head_attention(Q, K, V, W_o, attn_weights_all); // 执行多头注意力

    // 输出验证结果
    std::cout << "=== 输出维度验证 ===" << std::endl;
    std::cout << "多头注意力输出维度:" << mha_output.rows() << "×" << mha_output.cols() << std::endl;
    std::cout << "期望维度:4×512" << std::endl;

    check_attn_weights(attn_weights_all);
    check_causal_mask_effect(attn_weights_all, L);

    std::cout << "\n=== 第一个头的注意力权重矩阵 ===" << std::endl;
    std::cout << attn_weights_all.block(0, 0, L, L) << std::endl;

    return 0;
}
  • 必须在row_softmax之前,加到缩放后的Q×K^T上

  • 通过INF_NEG使未来Token经softmax后权重趋近于0,实现因果约束

  • 所有矩阵运算需保证维度一致(如Q×K^T为4×4,与掩码维度匹配)

相关推荐
会叫的恐龙3 小时前
C++ 核心知识点汇总(第六日)(字符串)
c++·算法·字符串
小糯米6013 小时前
C++顺序表和vector
开发语言·c++·算法
We་ct4 小时前
LeetCode 56. 合并区间:区间重叠问题的核心解法与代码解析
前端·算法·leetcode·typescript
Lionel6894 小时前
分步实现 Flutter 鸿蒙轮播图核心功能(搜索框 + 指示灯)
算法·图搜索算法
小妖6664 小时前
js 实现快速排序算法
数据结构·算法·排序算法
xsyaaaan4 小时前
代码随想录Day30动态规划:背包问题二维_背包问题一维_416分割等和子集
算法·动态规划
zheyutao5 小时前
字符串哈希
算法
A尘埃5 小时前
保险公司车险理赔欺诈检测(随机森林)
算法·随机森林·机器学习
大江东去浪淘尽千古风流人物6 小时前
【VLN】VLN(Vision-and-Language Navigation视觉语言导航)算法本质,范式难点及解决方向(1)
人工智能·python·算法