Windows 下,C++ 的 AVX2+FMA / AVX-VNNI 加速计算 demo

复制代码
#include <immintrin.h>
#include <intrin.h>
#include <chrono>
#include <cmath>
#include <cstdint>
#include <iomanip>
#include <iostream>
#include <string>
#include <utility>
#include <vector>

// ------------------------------------------------------------
// CPU 特性检测
// ------------------------------------------------------------
// Returns true when 256-bit YMM registers may be used.
// CPUID.1:ECX must report both OSXSAVE (bit 27, required before XGETBV may be
// executed) and AVX (bit 28). XCR0 must also have the SSE (bit 1) and AVX
// (bit 2) state components enabled, meaning the OS saves/restores YMM state.
static bool os_supports_ymm_state() {
    int regs[4] = {};
    __cpuid(regs, 1);
    const int ecx = regs[2];
    constexpr int kOsxsaveBit = 1 << 27;
    constexpr int kAvxBit = 1 << 28;
    if ((ecx & kOsxsaveBit) == 0 || (ecx & kAvxBit) == 0) {
        return false;
    }
    constexpr unsigned long long kXmmYmmMask = 0x6;  // XCR0 bit 1 (SSE) | bit 2 (AVX)
    return (_xgetbv(0) & kXmmYmmMask) == kXmmYmmMask;
}

// Returns true when AVX2 instructions can be used at runtime.
// This requires OS-enabled YMM state and CPUID.(EAX=7,ECX=0):EBX bit 5.
// Leaf 7 is queried only after CPUID.0:EAX confirms it is within the maximum
// basic leaf. For an out-of-range leaf, Intel CPUs return data from the highest
// basic leaf instead, which could make bit 5 look set by accident.
static bool cpu_supports_avx2() {
    if (!os_supports_ymm_state()) {
        return false;
    }
    int cpu_info[4] = {};
    __cpuid(cpu_info, 0);
    if (cpu_info[0] < 7) {
        return false;
    }
    __cpuidex(cpu_info, 7, 0);
    return (cpu_info[1] & (1 << 5)) != 0;
}

// Returns true when AVX-VNNI (VEX-encoded VPDPBUSD etc.) can be used at runtime.
// This requires OS-enabled YMM state and CPUID.(EAX=7,ECX=1):EAX bit 4.
// Before reading sub-leaf 1 we confirm:
//   - leaf 7 exists (CPUID.0:EAX >= 7), and
//   - leaf 7 reports at least one extra sub-leaf (CPUID.(EAX=7,ECX=0):EAX >= 1).
// Without these checks, the bits read could belong to a different leaf.
static bool cpu_supports_avx_vnni() {
    if (!os_supports_ymm_state()) {
        return false;
    }
    int cpu_info[4] = {};
    __cpuid(cpu_info, 0);
    if (cpu_info[0] < 7) {
        return false;
    }
    __cpuidex(cpu_info, 7, 0);
    if (cpu_info[0] < 1) {  // EAX = maximum supported sub-leaf of leaf 7
        return false;
    }
    __cpuidex(cpu_info, 7, 1);
    return (cpu_info[0] & (1 << 4)) != 0;
}

// ------------------------------------------------------------
// 点积计算内核(数学定义一致):
//   sum = Σ (a[i] * b[i]),其中 a 为 u8,b 为 s8。
// ------------------------------------------------------------

// Plain serial scalar reference: one accumulator, one element per step.
// It is deliberately as simple as possible so it can act as ground truth for
// the other kernels. Each product is formed in 32-bit and summed in 64-bit.
static int64_t scalar_dot_u8s8_serial(const uint8_t* a, const int8_t* b, int n) {
    int64_t total = 0;
    const uint8_t* pa = a;
    const int8_t* pb = b;
    for (int remaining = n; remaining > 0; --remaining) {
        total += static_cast<int32_t>(*pa++) * static_cast<int32_t>(*pb++);
    }
    return total;
}

// Scalar dot product with four independent accumulators. Spreading the sum
// over four partials shortens the loop-carried dependency chain compared with
// the single-accumulator version. Elements that do not fill a complete group
// of four are added by a serial tail loop.
static int64_t scalar_dot_u8s8(const uint8_t* a, const int8_t* b, int n) {
    int64_t partial[4] = { 0, 0, 0, 0 };
    int i = 0;
    while (i + 3 < n) {
        for (int lane = 0; lane < 4; ++lane) {
            partial[lane] += static_cast<int32_t>(a[i + lane]) * static_cast<int32_t>(b[i + lane]);
        }
        i += 4;
    }
    int64_t result = (partial[0] + partial[1]) + (partial[2] + partial[3]);
    while (i < n) {
        result += static_cast<int32_t>(a[i]) * static_cast<int32_t>(b[i]);
        ++i;
    }
    return result;
}

// AVX2 实现:
// 1) _mm256_maddubs_epi16:按字节做 u8*s8,并得到打包的 s16 部分和
// 2) _mm256_madd_epi16:将相邻 s16 与常量 1 做乘加,归并为 s32
// 使用 4 个相互独立的累加器提升 ILP(指令级并行度)。
static int64_t avx2_dot_u8s8(const uint8_t* a, const int8_t* b, int n) {
    __m256i acc0 = _mm256_setzero_si256();
    __m256i acc1 = _mm256_setzero_si256();
    __m256i acc2 = _mm256_setzero_si256();
    __m256i acc3 = _mm256_setzero_si256();
    const __m256i ones16 = _mm256_set1_epi16(1);
    int i = 0;
    for (; i + 127 < n; i += 128) {
        const __m256i va = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(a + i));
        const __m256i vb = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(b + i));
        const __m256i p0 = _mm256_maddubs_epi16(va, vb);
        acc0 = _mm256_add_epi32(acc0, _mm256_madd_epi16(p0, ones16));

        const __m256i va1 = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(a + i + 32));
        const __m256i vb1 = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(b + i + 32));
        const __m256i p1 = _mm256_maddubs_epi16(va1, vb1);
        acc1 = _mm256_add_epi32(acc1, _mm256_madd_epi16(p1, ones16));

        const __m256i va2 = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(a + i + 64));
        const __m256i vb2 = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(b + i + 64));
        const __m256i p2 = _mm256_maddubs_epi16(va2, vb2);
        acc2 = _mm256_add_epi32(acc2, _mm256_madd_epi16(p2, ones16));

        const __m256i va3 = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(a + i + 96));
        const __m256i vb3 = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(b + i + 96));
        const __m256i p3 = _mm256_maddubs_epi16(va3, vb3);
        acc3 = _mm256_add_epi32(acc3, _mm256_madd_epi16(p3, ones16));
    }
    for (; i + 31 < n; i += 32) {
        const __m256i va = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(a + i));
        const __m256i vb = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(b + i));
        const __m256i p = _mm256_maddubs_epi16(va, vb);
        acc0 = _mm256_add_epi32(acc0, _mm256_madd_epi16(p, ones16));
    }

    const __m256i acc32 = _mm256_add_epi32(_mm256_add_epi32(acc0, acc1), _mm256_add_epi32(acc2, acc3));
    alignas(32) int32_t lanes[8] = {};
    _mm256_store_si256(reinterpret_cast<__m256i*>(lanes), acc32);
    int64_t sum = 0;
    for (int k = 0; k < 8; ++k) {
        sum += lanes[k];
    }
    for (; i < n; ++i) {
        sum += static_cast<int32_t>(a[i]) * static_cast<int32_t>(b[i]);
    }
    return sum;
}

#if defined(__AVXVNNI__)
// AVX-VNNI 实现:
// _mm256_dpbusd_avx_epi32 会按每 4 个字节做点积,并累加到 s32 lane。
// 同样使用 4 个独立累加器。
static int64_t avx_vnni_dot_u8s8(const uint8_t* a, const int8_t* b, int n) {
    __m256i acc0 = _mm256_setzero_si256();
    __m256i acc1 = _mm256_setzero_si256();
    __m256i acc2 = _mm256_setzero_si256();
    __m256i acc3 = _mm256_setzero_si256();
    int i = 0;
    for (; i + 127 < n; i += 128) {
        const __m256i va = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(a + i));
        const __m256i vb = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(b + i));
        acc0 = _mm256_dpbusd_avx_epi32(acc0, va, vb);

        const __m256i va1 = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(a + i + 32));
        const __m256i vb1 = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(b + i + 32));
        acc1 = _mm256_dpbusd_avx_epi32(acc1, va1, vb1);

        const __m256i va2 = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(a + i + 64));
        const __m256i vb2 = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(b + i + 64));
        acc2 = _mm256_dpbusd_avx_epi32(acc2, va2, vb2);

        const __m256i va3 = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(a + i + 96));
        const __m256i vb3 = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(b + i + 96));
        acc3 = _mm256_dpbusd_avx_epi32(acc3, va3, vb3);
    }
    for (; i + 31 < n; i += 32) {
        const __m256i va = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(a + i));
        const __m256i vb = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(b + i));
        acc0 = _mm256_dpbusd_avx_epi32(acc0, va, vb);
    }

    const __m256i acc32 = _mm256_add_epi32(_mm256_add_epi32(acc0, acc1), _mm256_add_epi32(acc2, acc3));
    alignas(32) int32_t lanes[8] = {};
    _mm256_store_si256(reinterpret_cast<__m256i*>(lanes), acc32);
    int64_t sum = 0;
    for (int k = 0; k < 8; ++k) {
        sum += lanes[k];
    }
    for (; i < n; ++i) {
        sum += static_cast<int32_t>(a[i]) * static_cast<int32_t>(b[i]);
    }
    return sum;
}
#endif

// AVX2+FMA(float)示例:y[i] = a[i] * b[i] + c[i],并返回 sum(y) 用于防止优化。
static float avx2_fma_f32_fma_sum(const float* a, const float* b, const float* c, int n) {
    __m256 acc0 = _mm256_setzero_ps();
    __m256 acc1 = _mm256_setzero_ps();
    __m256 acc2 = _mm256_setzero_ps();
    __m256 acc3 = _mm256_setzero_ps();
    int i = 0;

    for (; i + 31 < n; i += 32) {
        const __m256 a0 = _mm256_loadu_ps(a + i);
        const __m256 b0 = _mm256_loadu_ps(b + i);
        const __m256 c0 = _mm256_loadu_ps(c + i);
        acc0 = _mm256_add_ps(acc0, _mm256_fmadd_ps(a0, b0, c0));

        const __m256 a1 = _mm256_loadu_ps(a + i + 8);
        const __m256 b1 = _mm256_loadu_ps(b + i + 8);
        const __m256 c1 = _mm256_loadu_ps(c + i + 8);
        acc1 = _mm256_add_ps(acc1, _mm256_fmadd_ps(a1, b1, c1));

        const __m256 a2 = _mm256_loadu_ps(a + i + 16);
        const __m256 b2 = _mm256_loadu_ps(b + i + 16);
        const __m256 c2 = _mm256_loadu_ps(c + i + 16);
        acc2 = _mm256_add_ps(acc2, _mm256_fmadd_ps(a2, b2, c2));

        const __m256 a3 = _mm256_loadu_ps(a + i + 24);
        const __m256 b3 = _mm256_loadu_ps(b + i + 24);
        const __m256 c3 = _mm256_loadu_ps(c + i + 24);
        acc3 = _mm256_add_ps(acc3, _mm256_fmadd_ps(a3, b3, c3));
    }

    __m256 acc = _mm256_add_ps(_mm256_add_ps(acc0, acc1), _mm256_add_ps(acc2, acc3));
    alignas(32) float lanes[8] = {};
    _mm256_store_ps(lanes, acc);
    float sum = 0.0f;
    for (int k = 0; k < 8; ++k) {
        sum += lanes[k];
    }
    for (; i < n; ++i) {
        sum += a[i] * b[i] + c[i];
    }
    return sum;
}

// Scalar reference for the float path: returns sum(a[i] * b[i] + c[i]),
// accumulated left to right. Used to check the mixed benchmark's float result.
static float scalar_fma_f32_sum(const float* a, const float* b, const float* c, int n) {
    float acc = 0.0f;
    for (int k = 0; k < n; ++k) {
        const float term = a[k] * b[k] + c[k];
        acc += term;
    }
    return acc;
}

// Minimal teaching demo for AVX2+FMA. Eight hard-coded floats (exactly one
// 256-bit vector) go through y[i] = a[i] * b[i] + c[i], first with a scalar
// loop and then with a single _mm256_fmadd_ps. Both results are printed per
// element, followed by both sums, so they can be compared.
static void explain_avx2_fma_minimal_demo() {
    constexpr int kLanes = 8;
    alignas(32) const float a[kLanes] = { 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f };
    alignas(32) const float b[kLanes] = { 0.5f, -1.0f, 1.5f, 2.0f, -0.5f, 3.0f, 0.25f, -2.0f };
    alignas(32) const float c[kLanes] = { 0.1f, -0.7f, 2.2f, 1.3f, -1.1f, 0.6f, 3.0f, -0.4f };

    const char* const banner[] = {
        "\n=== AVX2+FMA Minimal Demo (8 elements) ===",
        "Target: y[i] = a[i] * b[i] + c[i]",
        "a = [1,2,3,4,5,6,7,8]",
        "b = [0.5,-1,1.5,2,-0.5,3,0.25,-2]",
        "c = [0.1,-0.7,2.2,1.3,-1.1,0.6,3,-0.4]",
    };
    for (const char* line : banner) {
        std::cout << line << std::endl;
    }

    // Prints "<label><i>: a * b + c = y" for element i.
    const auto print_term = [&](const char* label, int i, float y) {
        std::cout << label << i
                  << ": " << a[i] << " * " << b[i] << " + " << c[i]
                  << " = " << y << std::endl;
    };

    // Step 1: scalar reference, one element at a time.
    float scalar_sum = 0.0f;
    for (int i = 0; i < kLanes; ++i) {
        const float y = a[i] * b[i] + c[i];
        scalar_sum += y;
        print_term("Scalar step ", i, y);
    }

    // Steps 2-3: load all eight lanes at once and compute every a*b+c with one FMA.
    const __m256 vy = _mm256_fmadd_ps(_mm256_load_ps(a), _mm256_load_ps(b), _mm256_load_ps(c));

    // Step 4: spill the vector to memory so each lane can be inspected and summed.
    alignas(32) float avx_y[kLanes];
    _mm256_store_ps(avx_y, vy);
    float avx_sum = 0.0f;
    for (int i = 0; i < kLanes; ++i) {
        avx_sum += avx_y[i];
        print_term("AVX lane ", i, avx_y[i]);
    }

    std::cout << "Scalar sum = " << scalar_sum << ", AVX sum = " << avx_sum
              << ", abs_err = " << std::abs(scalar_sum - avx_sum) << std::endl;
}

#if defined(__AVXVNNI__)
// Minimal teaching demo for AVX-VNNI. 32 u8/s8 pairs (one 256-bit vector each)
// are multiplied and summed, first with a scalar loop (printing every term and
// the running sum) and then with a single _mm256_dpbusd_avx_epi32. That
// instruction adds each group of four adjacent byte products into one of eight
// s32 lanes. The lanes are printed and summed for comparison.
static void explain_avx_vnni_minimal_demo() {
    constexpr int kBytes = 32;
    // a = 1, 2, ..., 32; b repeats the 7-element pattern -3, -2, -1, 0, 1, 2, 3.
    alignas(32) uint8_t a[kBytes];
    alignas(32) int8_t b[kBytes];
    for (int i = 0; i < kBytes; ++i) {
        a[i] = static_cast<uint8_t>(i + 1);
        b[i] = static_cast<int8_t>(i % 7 - 3);
    }

    std::cout << "\n=== AVX-VNNI Minimal Demo (32 u8/s8 elements) ===" << std::endl
              << "Target: dot = sum(a[i] * b[i])" << std::endl
              << "a = [1,2,3,...,32]" << std::endl
              << "b = [-3,-2,-1,0,1,2,3,-3,...]" << std::endl;

    // Step 1: scalar reference with every term and the running sum printed.
    int32_t running = 0;
    for (int i = 0; i < kBytes; ++i) {
        const int32_t x = a[i];
        const int32_t w = b[i];
        const int32_t term = x * w;
        running += term;
        std::cout << "Scalar step " << i
                  << ": " << x << " * " << w
                  << " = " << term << ", running_sum = " << running << std::endl;
    }
    const int32_t scalar_sum = running;

    // Steps 2-3: load 32 bytes of each input and let one VNNI instruction compute
    // acc[lane] = sum over k in 0..3 of a[4*lane+k] * b[4*lane+k].
    const __m256i acc = _mm256_dpbusd_avx_epi32(
        _mm256_setzero_si256(),
        _mm256_load_si256(reinterpret_cast<const __m256i*>(a)),
        _mm256_load_si256(reinterpret_cast<const __m256i*>(b)));

    // Step 4: print the 8 per-lane partial sums and add them up.
    alignas(32) int32_t lanes[8];
    _mm256_store_si256(reinterpret_cast<__m256i*>(lanes), acc);
    int32_t vnni_sum = 0;
    for (int lane = 0; lane < 8; ++lane) {
        vnni_sum += lanes[lane];
        std::cout << "VNNI lane " << lane << " partial_sum = " << lanes[lane] << std::endl;
    }

    std::cout << "Scalar sum = " << scalar_sum
              << ", VNNI sum = " << vnni_sum
              << ", abs_err = " << std::abs(scalar_sum - vnni_sum) << std::endl;
}
#endif

#if defined(__AVXVNNI__)
// Mixed workload: a single pass runs the AVX2+FMA float kernel and then the
// AVX-VNNI u8/s8 dot product.
// Returns {sum of fa*fb+fc over fn floats, dot of ia and ib over `in` bytes}.
static std::pair<float, int64_t> mixed_avx2_fma_vnni(
    const float* fa, const float* fb, const float* fc, int fn,
    const uint8_t* ia, const int8_t* ib, int in) {
    std::pair<float, int64_t> out;
    out.first = avx2_fma_f32_fma_sum(fa, fb, fc, fn);  // float path runs first
    out.second = avx_vnni_dot_u8s8(ia, ib, in);
    return out;
}
#endif

// Times `iterations` back-to-back calls of fn() and returns the elapsed time
// in microseconds. Each result is XOR-folded into the volatile `sink` so the
// optimizer cannot drop the calls.
// steady_clock is used because it is guaranteed to be monotonic, whereas
// high_resolution_clock may alias an adjustable (non-steady) clock.
template <typename Func>
static double benchmark_us(Func&& fn, int iterations, volatile int64_t& sink) {
    using clock = std::chrono::steady_clock;
    const auto begin = clock::now();
    for (int i = 0; i < iterations; ++i) {
        // Plain read-modify-write: compound assignment on a volatile is deprecated in C++20.
        sink = sink ^ fn();
    }
    const auto end = clock::now();
    return std::chrono::duration<double, std::micro>(end - begin).count();
}

// Parameters for one benchmark group run by run_bench_suite().
// Every field has a default, so a plain `BenchConfig cfg;` is fully initialized
// rather than holding indeterminate values.
struct BenchConfig {
    std::string name;        // Title printed in the group header.
    int effective_n = 0;     // Number of leading input elements processed per call.
    int warmup = 0;          // Untimed calls per kernel before timing starts.
    int iterations = 0;      // Timed calls per kernel.
    bool run_simd = false;   // When false, only the scalar baseline is run.
};

// Prints one benchmark line to std::cout:
//   [method] result = R, abs_err = E, rel_err = RE, total = T us, per pass = P us[, speedup = Sx]
// "per pass" is total_us / iterations. The speedup column is emitted only when
// speedup is strictly positive (the scalar baselines pass 0.0 to suppress it).
static void print_method_result(
    const std::string& method,
    int64_t result,
    int64_t abs_err,
    double rel_err,
    double total_us,
    int iterations,
    double speedup) {
    const double per_pass_us = total_us / iterations;
    std::ostream& out = std::cout;
    out << "[" << method << "] result = " << result;
    out << ", abs_err = " << abs_err;
    out << ", rel_err = " << rel_err;
    out << ", total = " << total_us << " us";
    out << ", per pass = " << per_pass_us << " us";
    if (speedup > 0.0) {
        out << ", speedup = " << speedup << "x";
    }
    out << std::endl;
}

// Runs `kernel` cfg.warmup times untimed, then times cfg.iterations calls with
// benchmark_us (elapsed microseconds are written to elapsed_us). Finally it
// returns one more untimed result, which the caller uses for correctness checks.
template <typename Kernel>
static int64_t bench_kernel(const BenchConfig& cfg, Kernel&& kernel, volatile int64_t& sink, double& elapsed_us) {
    for (int w = 0; w < cfg.warmup; ++w) {
        sink ^= kernel();
    }
    elapsed_us = benchmark_us(kernel, cfg.iterations, sink);
    return kernel();
}

// Runs one benchmark group on the first cfg.effective_n elements of a/b.
// The 4-accumulator scalar kernel is always measured first. It serves as both
// the reference result and the speedup baseline. When cfg.run_simd is set, the
// AVX2 kernel and (if compiled in) the AVX-VNNI kernel are then measured and
// compared against it, provided the CPU supports them at runtime.
static void run_bench_suite(const BenchConfig& cfg, const std::vector<uint8_t>& a, const std::vector<int8_t>& b) {
    volatile int64_t sink = 0;
    const uint8_t* const pa = a.data();
    const int8_t* const pb = b.data();
    const int len = cfg.effective_n;

    // Baseline: 4-accumulator scalar kernel.
    double scalar_us = 0.0;
    const int64_t scalar_res = bench_kernel(cfg, [&]() { return scalar_dot_u8s8(pa, pb, len); }, sink, scalar_us);

    std::cout << "\n=== " << cfg.name << " ===" << std::endl;
    std::cout << "effective_n = " << cfg.effective_n
              << ", warmup = " << cfg.warmup
              << ", iterations = " << cfg.iterations << std::endl;
    print_method_result("Scalar", scalar_res, 0, 0.0, scalar_us, cfg.iterations, 0.0);

    if (!cfg.run_simd) {
        return;
    }

    // Prints a SIMD kernel's line: error and relative error are measured
    // against the scalar result, and speedup against the scalar time.
    const auto report = [&](const char* label, int64_t res, double us) {
        const int64_t err = res - scalar_res;
        const double rel_err = (scalar_res == 0) ? 0.0 : std::abs(static_cast<double>(err) / static_cast<double>(scalar_res));
        print_method_result(label, res, err, rel_err, us, cfg.iterations, scalar_us / us);
    };

    if (!cpu_supports_avx2()) {
        std::cout << "[AVX2 ] unsupported at runtime." << std::endl;
    }
    else {
        double avx2_us = 0.0;
        const int64_t avx2_res = bench_kernel(cfg, [&]() { return avx2_dot_u8s8(pa, pb, len); }, sink, avx2_us);
        report("AVX2 ", avx2_res, avx2_us);
    }

#if defined(__AVXVNNI__)
    if (!cpu_supports_avx_vnni()) {
        std::cout << "[VNNI ] runtime unsupported." << std::endl;
    }
    else {
        double vnni_us = 0.0;
        const int64_t vnni_res = bench_kernel(cfg, [&]() { return avx_vnni_dot_u8s8(pa, pb, len); }, sink, vnni_us);
        report("VNNI ", vnni_res, vnni_us);
    }
#else
    std::cout << "[VNNI ] not compiled with AVX-VNNI intrinsics." << std::endl;
#endif
}

// Returns true when FMA3 instructions (e.g. _mm256_fmadd_ps) can be used.
// This requires CPUID.1:ECX bit 12 and OS-enabled YMM state.
// AVX2 and FMA are separate CPUID feature flags, so the AVX2+FMA paths must
// test both rather than inferring FMA from AVX2.
static bool cpu_supports_fma() {
    if (!os_supports_ymm_state()) {
        return false;
    }
    int cpu_info[4] = {};
    __cpuid(cpu_info, 1);
    return (cpu_info[2] & (1 << 12)) != 0;
}

// Timed-iteration count for an input of length n. Smaller inputs get more
// iterations so that each measurement stays long enough to be meaningful.
static int pick_iterations(int n) {
    if (n <= 256) {
        return 50000;
    }
    if (n <= 4096) {
        return 10000;
    }
    return 2000;
}

// Entry point. Reads the vector length n from stdin and runs the AVX2+FMA and
// AVX-VNNI teaching demos. It then benchmarks four groups on deterministic
// synthetic data: serial, streaming, cache-friendly, and mixed AVX2+FMA + VNNI.
int main() {
    int n = 0;
    std::cout << "Input n (u8/s8 vector length): ";
    if (!(std::cin >> n) || n <= 0) {
        std::cerr << "Invalid input." << std::endl;
        return 1;
    }

    // Demo input data. Every benchmark group reuses the same data so results are comparable.
    //   a[i] (u8): cycles through 0..250,   e.g. a = [0,1,2,3,4,...]
    //   b[i] (s8): cycles through -63..63,  e.g. b = [-63,-62,-61,-60,-59,...]
    std::vector<uint8_t> a(n);
    std::vector<int8_t> b(n);
    for (int i = 0; i < n; ++i) {
        a[i] = static_cast<uint8_t>(i % 251);
        b[i] = static_cast<int8_t>((i % 127) - 63);
    }

    std::cout << std::fixed << std::setprecision(6);
    std::cout << "n(input) = " << n << std::endl;

    // Minimal AVX2+FMA walkthrough. It executes _mm256_fmadd_ps, so FMA support
    // must be confirmed in addition to AVX2.
    if (cpu_supports_avx2() && cpu_supports_fma()) {
        explain_avx2_fma_minimal_demo();
    }
    else {
        std::cout << "\n=== AVX2+FMA Minimal Demo (8 elements) ===" << std::endl;
        std::cout << "Current CPU/OS does not support AVX2+FMA, demo skipped." << std::endl;
    }

#if defined(__AVXVNNI__)
    if (cpu_supports_avx_vnni()) {
        explain_avx_vnni_minimal_demo();
    }
    else {
        std::cout << "\n=== AVX-VNNI Minimal Demo (32 u8/s8 elements) ===" << std::endl;
        std::cout << "Current CPU/OS does not support AVX-VNNI, demo skipped." << std::endl;
    }
#else
    std::cout << "\n=== AVX-VNNI Minimal Demo (32 u8/s8 elements) ===" << std::endl;
    std::cout << "This build does not enable AVX-VNNI intrinsics, demo skipped." << std::endl;
#endif

    // 1) Serial baseline: the single-accumulator scalar kernel only.
    {
        volatile int64_t sink = 0;
        const int serial_iters = pick_iterations(n);
        const int serial_warmup = 100;
        for (int i = 0; i < serial_warmup; ++i) {
            sink ^= scalar_dot_u8s8_serial(a.data(), b.data(), n);
        }
        const double serial_us = benchmark_us([&]() { return scalar_dot_u8s8_serial(a.data(), b.data(), n); }, serial_iters, sink);
        const int64_t serial_res = scalar_dot_u8s8_serial(a.data(), b.data(), n);

        std::cout << "\n=== Serial (single accumulator scalar) ===" << std::endl;
        std::cout << "effective_n = " << n
                  << ", warmup = " << serial_warmup
                  << ", iterations = " << serial_iters << std::endl;
        print_method_result("Serial", serial_res, 0, 0.0, serial_us, serial_iters, 0.0);
    }

    // 2) Streaming: uses the full n, closest to a linear streaming-memory workload.
    {
        BenchConfig cfg{};
        cfg.name = "Streaming (full input size)";
        cfg.effective_n = n;
        cfg.warmup = 100;
        cfg.iterations = pick_iterations(n);
        cfg.run_simd = true;
        run_bench_suite(cfg, a, b);
    }

    // 3) L1/L2-friendly: caps effective_n at 4096 so the working set is likely
    //    to stay cache-resident, emphasizing compute throughput.
    {
        BenchConfig cfg{};
        cfg.name = "L1/L2-friendly (cache-resident working set)";
        cfg.effective_n = (n < 4096) ? n : 4096;
        cfg.warmup = 500;
        cfg.iterations = 80000;
        cfg.run_simd = true;
        run_bench_suite(cfg, a, b);
    }

    // 4) Mixed throughput (AVX2+FMA + VNNI): each pass runs the float FMA
    //    kernel and the int8 VNNI kernel back to back.
    if (cpu_supports_avx2() && cpu_supports_fma()
#if defined(__AVXVNNI__)
        && cpu_supports_avx_vnni()
#else
        && false
#endif
        ) {
        // The float and integer paths both use the same n as the main benchmarks.
        const int f_n = n;
        const int i_n = n;

        std::vector<float> fa(f_n), fb(f_n), fc(f_n);
        for (int i = 0; i < f_n; ++i) {
            fa[i] = static_cast<float>((i % 113) * 0.03125f + 1.0f);
            fb[i] = static_cast<float>((i % 97) * 0.015625f + 0.5f);
            fc[i] = static_cast<float>((i % 89) * 0.0078125f + 0.25f);
        }

        const int warmup = 100;
        const int iters = pick_iterations(n);
        volatile int64_t mix_sink = 0;
        volatile float mix_f_sink = 0.0f;

#if defined(__AVXVNNI__)
        for (int i = 0; i < warmup; ++i) {
            const auto r = mixed_avx2_fma_vnni(fa.data(), fb.data(), fc.data(), f_n, a.data(), b.data(), i_n);
            // Plain assignment: compound assignment to a volatile float is deprecated in C++20.
            mix_f_sink = mix_f_sink + r.first;
            mix_sink ^= r.second;
        }

        const auto begin = std::chrono::steady_clock::now();
        for (int i = 0; i < iters; ++i) {
            const auto r = mixed_avx2_fma_vnni(fa.data(), fb.data(), fc.data(), f_n, a.data(), b.data(), i_n);
            mix_f_sink = mix_f_sink + r.first;
            mix_sink ^= r.second;
        }
        const auto end = std::chrono::steady_clock::now();
        const double mix_us = std::chrono::duration<double, std::micro>(end - begin).count();
        const auto final_r = mixed_avx2_fma_vnni(fa.data(), fb.data(), fc.data(), f_n, a.data(), b.data(), i_n);
        const float float_scalar_ref = scalar_fma_f32_sum(fa.data(), fb.data(), fc.data(), f_n);
        const int64_t int_scalar_ref = scalar_dot_u8s8_serial(a.data(), b.data(), i_n);
        // The float sums are accumulated in different orders, so a small nonzero error is possible for large n.
        const float float_abs_err = std::abs(final_r.first - float_scalar_ref);
        const int64_t int_abs_err = final_r.second - int_scalar_ref;

        std::cout << "\n=== Mixed Throughput (AVX2+FMA + VNNI) ===" << std::endl;
        std::cout << "float_n = " << f_n << ", int_n = " << i_n
                  << ", warmup = " << warmup << ", iterations = " << iters << std::endl;
        std::cout << "[Mixed] float_result = " << final_r.first
                  << ", int_result = " << final_r.second
                  << ", total = " << mix_us << " us"
                  << ", per pass = " << (mix_us / iters) << " us" << std::endl;
        std::cout << "[Mixed-Check] float_scalar_ref = " << float_scalar_ref
                  << ", float_abs_err = " << float_abs_err << std::endl;
        std::cout << "[Mixed-Check] int_scalar_ref = " << int_scalar_ref
                  << ", int_abs_err = " << int_abs_err << std::endl;
#endif
        (void)warmup;
        (void)iters;
        (void)mix_sink;
        (void)mix_f_sink;
    }
    else {
        std::cout << "\n=== Mixed Throughput (AVX2+FMA + VNNI) ===" << std::endl;
        std::cout << "[Mixed] skipped: requires AVX2, FMA and AVX-VNNI runtime/compiler support." << std::endl;
    }

    return 0;
}

从运行结果可以看到,使用专用 SIMD 指令集(AVX2、AVX-VNNI)的内核确实比标量实现快很多:本例中分别约为 6 倍和 8.5 倍。

Input n (u8/s8 vector length): 1024

n(input) = 1024

=== AVX2+FMA Minimal Demo (8 elements) ===

Target: y[i] = a[i] * b[i] + c[i]

a = [1,2,3,4,5,6,7,8]

b = [0.5,-1,1.5,2,-0.5,3,0.25,-2]

c = [0.1,-0.7,2.2,1.3,-1.1,0.6,3,-0.4]

Scalar step 0: 1.000000 * 0.500000 + 0.100000 = 0.600000

Scalar step 1: 2.000000 * -1.000000 + -0.700000 = -2.700000

Scalar step 2: 3.000000 * 1.500000 + 2.200000 = 6.700000

Scalar step 3: 4.000000 * 2.000000 + 1.300000 = 9.300000

Scalar step 4: 5.000000 * -0.500000 + -1.100000 = -3.600000

Scalar step 5: 6.000000 * 3.000000 + 0.600000 = 18.600000

Scalar step 6: 7.000000 * 0.250000 + 3.000000 = 4.750000

Scalar step 7: 8.000000 * -2.000000 + -0.400000 = -16.400000

AVX lane 0: 1.000000 * 0.500000 + 0.100000 = 0.600000

AVX lane 1: 2.000000 * -1.000000 + -0.700000 = -2.700000

AVX lane 2: 3.000000 * 1.500000 + 2.200000 = 6.700000

AVX lane 3: 4.000000 * 2.000000 + 1.300000 = 9.300000

AVX lane 4: 5.000000 * -0.500000 + -1.100000 = -3.600000

AVX lane 5: 6.000000 * 3.000000 + 0.600000 = 18.600000

AVX lane 6: 7.000000 * 0.250000 + 3.000000 = 4.750000

AVX lane 7: 8.000000 * -2.000000 + -0.400000 = -16.400000

Scalar sum = 17.250002, AVX sum = 17.250002, abs_err = 0.000000

=== AVX-VNNI Minimal Demo (32 u8/s8 elements) ===

Target: dot = sum(a[i] * b[i])

a = [1,2,3,...,32]

b = [-3,-2,-1,0,1,2,3,-3,...]

Scalar step 0: 1 * -3 = -3, running_sum = -3

Scalar step 1: 2 * -2 = -4, running_sum = -7

Scalar step 2: 3 * -1 = -3, running_sum = -10

Scalar step 3: 4 * 0 = 0, running_sum = -10

Scalar step 4: 5 * 1 = 5, running_sum = -5

Scalar step 5: 6 * 2 = 12, running_sum = 7

Scalar step 6: 7 * 3 = 21, running_sum = 28

Scalar step 7: 8 * -3 = -24, running_sum = 4

Scalar step 8: 9 * -2 = -18, running_sum = -14

Scalar step 9: 10 * -1 = -10, running_sum = -24

Scalar step 10: 11 * 0 = 0, running_sum = -24

Scalar step 11: 12 * 1 = 12, running_sum = -12

Scalar step 12: 13 * 2 = 26, running_sum = 14

Scalar step 13: 14 * 3 = 42, running_sum = 56

Scalar step 14: 15 * -3 = -45, running_sum = 11

Scalar step 15: 16 * -2 = -32, running_sum = -21

Scalar step 16: 17 * -1 = -17, running_sum = -38

Scalar step 17: 18 * 0 = 0, running_sum = -38

Scalar step 18: 19 * 1 = 19, running_sum = -19

Scalar step 19: 20 * 2 = 40, running_sum = 21

Scalar step 20: 21 * 3 = 63, running_sum = 84

Scalar step 21: 22 * -3 = -66, running_sum = 18

Scalar step 22: 23 * -2 = -46, running_sum = -28

Scalar step 23: 24 * -1 = -24, running_sum = -52

Scalar step 24: 25 * 0 = 0, running_sum = -52

Scalar step 25: 26 * 1 = 26, running_sum = -26

Scalar step 26: 27 * 2 = 54, running_sum = 28

Scalar step 27: 28 * 3 = 84, running_sum = 112

Scalar step 28: 29 * -3 = -87, running_sum = 25

Scalar step 29: 30 * -2 = -60, running_sum = -35

Scalar step 30: 31 * -1 = -31, running_sum = -66

Scalar step 31: 32 * 0 = 0, running_sum = -66

VNNI lane 0 partial_sum = -10

VNNI lane 1 partial_sum = 14

VNNI lane 2 partial_sum = -16

VNNI lane 3 partial_sum = -9

VNNI lane 4 partial_sum = 42

VNNI lane 5 partial_sum = -73

VNNI lane 6 partial_sum = 164

VNNI lane 7 partial_sum = -178

Scalar sum = -66, VNNI sum = -66, abs_err = 0

=== Serial (single accumulator scalar) ===

effective_n = 1024, warmup = 100, iterations = 10000

[Serial] result = 913898, abs_err = 0, rel_err = 0.000000, total = 7224.700000 us, per pass = 0.722470 us

=== Streaming (full input size) ===

effective_n = 1024, warmup = 100, iterations = 10000

[Scalar] result = 913898, abs_err = 0, rel_err = 0.000000, total = 6741.500000 us, per pass = 0.674150 us

[AVX2 ] result = 913898, abs_err = 0, rel_err = 0.000000, total = 1121.200000 us, per pass = 0.112120 us, speedup = 6.012754x

[VNNI ] result = 913898, abs_err = 0, rel_err = 0.000000, total = 780.200000 us, per pass = 0.078020 us, speedup = 8.640733x

=== L1/L2-friendly (cache-resident working set) ===

effective_n = 1024, warmup = 500, iterations = 80000

[Scalar] result = 913898, abs_err = 0, rel_err = 0.000000, total = 51868.200000 us, per pass = 0.648352 us

[AVX2 ] result = 913898, abs_err = 0, rel_err = 0.000000, total = 8255.100000 us, per pass = 0.103189 us, speedup = 6.283170x

[VNNI ] result = 913898, abs_err = 0, rel_err = 0.000000, total = 6119.500000 us, per pass = 0.076494 us, speedup = 8.475889x

=== Mixed Throughput (AVX2+FMA + VNNI) ===

float_n = 1024, int_n = 1024, warmup = 100, iterations = 10000

[Mixed] float_result = 4065.320801, int_result = 913898, total = 4289.400000 us, per pass = 0.428940 us

[Mixed-Check] float_scalar_ref = 4065.320801, float_abs_err = 0.000000

[Mixed-Check] int_scalar_ref = 913898, int_abs_err = 0

相关推荐
计算机安禾几秒前
【算法分析与设计】第37篇:平面扫描与线段交问题
java·大数据·数据库·算法·机器学习
8Qi82 分钟前
LeetCode 236. 二叉树的最近公共祖先(LCA)
算法·leetcode·二叉树·递归·lca·后序遍历
兰令水5 分钟前
leecodecode【二叉树排序+最近公共祖先】【2026.6.2打卡-java版本】
java·数据结构·算法·leetcode
人道领域6 分钟前
【LeetCode刷题日记】77&&216.回溯算法剪枝优化在组合问题中的应用
java·算法·leetcode
Deepoch9 分钟前
Deepoc数学大模型:以低幻觉特性护航半导体精准设计与制造
大数据·人工智能·算法·半导体·deepoc
诸葛务农9 分钟前
共沸脱水技术及其在光刻胶用PGMEA纯化中的应用(上)
java·数据库·算法
£suPerpanda10 分钟前
AtCoder Beginner Contest 453
c++·算法
蜗牛~turbo15 分钟前
金蝶云星空 二开得到来源单单据体2数据包
windows·c#·金蝶·dynamicobject
xxxxxue17 分钟前
Windows 通过 右键菜单 调用 Python 脚本
开发语言·windows·python·右键菜单
词元Max19 分钟前
4.2 决策树与随机森林
算法·决策树·随机森林