#include <immintrin.h>
#include <intrin.h>
#include <chrono>
#include <cmath>
#include <cstdint>
#include <iomanip>
#include <iostream>
#include <string>
#include <utility>
#include <vector>
// ------------------------------------------------------------
// CPU 特性检测
// ------------------------------------------------------------
// Returns true when the CPU implements AVX and the OS has enabled saving of the
// XMM and YMM register state (checked through XGETBV), i.e. 256-bit AVX code
// can run safely.
static bool os_supports_ymm_state() {
    // CPUID.1:ECX bit 27 = OSXSAVE (XGETBV usable), bit 28 = AVX. Both must be set.
    int regs[4] = {};
    __cpuid(regs, 1);
    const int ecx = regs[2];
    const int required_ecx_bits = (1 << 27) | (1 << 28);
    if ((ecx & required_ecx_bits) != required_ecx_bits) {
        return false;
    }
    // XCR0 bits 1 and 2 (mask 0x6): SSE and AVX state are both managed by the OS.
    const unsigned long long xcr0_ymm_mask = 0x6ULL;
    return (_xgetbv(0) & xcr0_ymm_mask) == xcr0_ymm_mask;
}
// Returns true when the CPU reports AVX2 (CPUID.(EAX=7,ECX=0):EBX bit 5) and
// os_supports_ymm_state() confirms the OS saves YMM state.
static bool cpu_supports_avx2() {
    if (!os_supports_ymm_state()) {
        return false;
    }
    int cpu_info[4] = {};
    // Leaf 7 is only valid when the maximum basic leaf (CPUID.0:EAX) is >= 7.
    // Querying a leaf above the maximum returns the data of the highest basic
    // leaf on Intel parts, so bit 5 could otherwise be a false positive.
    __cpuid(cpu_info, 0);
    if (cpu_info[0] < 7) {
        return false;
    }
    __cpuidex(cpu_info, 7, 0);
    return (cpu_info[1] & (1 << 5)) != 0;
}
// Returns true when the CPU reports AVX-VNNI (CPUID.(EAX=7,ECX=1):EAX bit 4)
// and os_supports_ymm_state() confirms the OS saves YMM state.
static bool cpu_supports_avx_vnni() {
    if (!os_supports_ymm_state()) {
        return false;
    }
    int cpu_info[4] = {};
    // Leaf 7 must exist (maximum basic leaf in CPUID.0:EAX >= 7); out-of-range
    // leaves return unrelated data on Intel parts.
    __cpuid(cpu_info, 0);
    if (cpu_info[0] < 7) {
        return false;
    }
    // CPUID.(EAX=7,ECX=0):EAX reports the highest valid sub-leaf of leaf 7;
    // the AVX-VNNI flag lives in sub-leaf 1.
    __cpuidex(cpu_info, 7, 0);
    if (cpu_info[0] < 1) {
        return false;
    }
    __cpuidex(cpu_info, 7, 1);
    return (cpu_info[0] & (1 << 4)) != 0;
}
// ------------------------------------------------------------
// 点积计算内核(数学定义一致):
// sum = Σ (a[i] * b[i]),其中 a 为 u8,b 为 s8。
// ------------------------------------------------------------
// Serial scalar reference: sum = Σ a[i] * b[i] (a: u8, b: s8) with a single
// accumulator. Kept deliberately naive so it is easy to read and compare against.
// Returns 0 when n <= 0.
static int64_t scalar_dot_u8s8_serial(const uint8_t* a, const int8_t* b, int n) {
    int64_t total = 0;
    const uint8_t* pa = a;
    const int8_t* pb = b;
    for (int remaining = n; remaining > 0; --remaining) {
        // Each product fits in int32 (|255 * -128| = 32640); accumulate in int64.
        const int32_t product = static_cast<int32_t>(*pa++) * static_cast<int32_t>(*pb++);
        total += product;
    }
    return total;
}
// Scalar dot product sum = Σ a[i] * b[i] (a: u8, b: s8) using four independent
// accumulators, which shortens the loop-carried dependency chain compared to a
// single running sum. Leftover elements (n % 4) are added one at a time.
static int64_t scalar_dot_u8s8(const uint8_t* a, const int8_t* b, int n) {
    int64_t partial[4] = {0, 0, 0, 0};
    int idx = 0;
    // `n - idx >= 4` is the overflow-free spelling of `idx + 3 < n`.
    while (n - idx >= 4) {
        for (int k = 0; k < 4; ++k) {
            partial[k] += static_cast<int32_t>(a[idx + k]) * static_cast<int32_t>(b[idx + k]);
        }
        idx += 4;
    }
    int64_t total = partial[0] + partial[1] + partial[2] + partial[3];
    while (idx < n) {
        total += static_cast<int32_t>(a[idx]) * static_cast<int32_t>(b[idx]);
        ++idx;
    }
    return total;
}
// AVX2 实现:
// 1) _mm256_maddubs_epi16:按字节做 u8*s8,并得到打包的 s16 部分和
// 2) _mm256_madd_epi16:将相邻 s16 与常量 1 做乘加,归并为 s32
// 使用 4 个相互独立的累加器提升 ILP(指令级并行度)。
static int64_t avx2_dot_u8s8(const uint8_t* a, const int8_t* b, int n) {
__m256i acc0 = _mm256_setzero_si256();
__m256i acc1 = _mm256_setzero_si256();
__m256i acc2 = _mm256_setzero_si256();
__m256i acc3 = _mm256_setzero_si256();
const __m256i ones16 = _mm256_set1_epi16(1);
int i = 0;
for (; i + 127 < n; i += 128) {
const __m256i va = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(a + i));
const __m256i vb = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(b + i));
const __m256i p0 = _mm256_maddubs_epi16(va, vb);
acc0 = _mm256_add_epi32(acc0, _mm256_madd_epi16(p0, ones16));
const __m256i va1 = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(a + i + 32));
const __m256i vb1 = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(b + i + 32));
const __m256i p1 = _mm256_maddubs_epi16(va1, vb1);
acc1 = _mm256_add_epi32(acc1, _mm256_madd_epi16(p1, ones16));
const __m256i va2 = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(a + i + 64));
const __m256i vb2 = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(b + i + 64));
const __m256i p2 = _mm256_maddubs_epi16(va2, vb2);
acc2 = _mm256_add_epi32(acc2, _mm256_madd_epi16(p2, ones16));
const __m256i va3 = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(a + i + 96));
const __m256i vb3 = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(b + i + 96));
const __m256i p3 = _mm256_maddubs_epi16(va3, vb3);
acc3 = _mm256_add_epi32(acc3, _mm256_madd_epi16(p3, ones16));
}
for (; i + 31 < n; i += 32) {
const __m256i va = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(a + i));
const __m256i vb = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(b + i));
const __m256i p = _mm256_maddubs_epi16(va, vb);
acc0 = _mm256_add_epi32(acc0, _mm256_madd_epi16(p, ones16));
}
const __m256i acc32 = _mm256_add_epi32(_mm256_add_epi32(acc0, acc1), _mm256_add_epi32(acc2, acc3));
alignas(32) int32_t lanes[8] = {};
_mm256_store_si256(reinterpret_cast<__m256i*>(lanes), acc32);
int64_t sum = 0;
for (int k = 0; k < 8; ++k) {
sum += lanes[k];
}
for (; i < n; ++i) {
sum += static_cast<int32_t>(a[i]) * static_cast<int32_t>(b[i]);
}
return sum;
}
#if defined(__AVXVNNI__)
// AVX-VNNI kernel for sum = Σ a[i] * b[i] (a: u8, b: s8).
// _mm256_dpbusd_avx_epi32 multiplies each group of four adjacent u8/s8 pairs
// and adds the group's sum into one s32 lane. This is the non-saturating form,
// so the lane add wraps on overflow. Four independent accumulators are kept for
// instruction-level parallelism.
//
// Because the s32 lanes wrap, they are drained into an int64 total at least
// every kChunk elements. Without that, long inputs silently produce wrong sums.
static int64_t avx_vnni_dot_u8s8(const uint8_t* a, const int8_t* b, int n) {
    // Per lane, one 32-byte step adds at most 4 * 255 * 128 = 130560 in
    // magnitude. 65536 elements = 2048 steps -> at most 267,386,880 per lane even
    // after the four accumulators are combined, well below INT32_MAX.
    // Must stay a multiple of 128 so full chunks end exactly on chunk_end.
    constexpr int kChunk = 1 << 16;
    int64_t sum = 0;
    int i = 0;
    while (n - i >= 32) {
        const int chunk_end = (n - i > kChunk) ? i + kChunk : n;
        __m256i acc0 = _mm256_setzero_si256();
        __m256i acc1 = _mm256_setzero_si256();
        __m256i acc2 = _mm256_setzero_si256();
        __m256i acc3 = _mm256_setzero_si256();
        // `chunk_end - i >= 128` avoids the signed overflow `i + 127` can hit near INT_MAX.
        for (; chunk_end - i >= 128; i += 128) {
            const __m256i va = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(a + i));
            const __m256i vb = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(b + i));
            acc0 = _mm256_dpbusd_avx_epi32(acc0, va, vb);
            const __m256i va1 = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(a + i + 32));
            const __m256i vb1 = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(b + i + 32));
            acc1 = _mm256_dpbusd_avx_epi32(acc1, va1, vb1);
            const __m256i va2 = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(a + i + 64));
            const __m256i vb2 = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(b + i + 64));
            acc2 = _mm256_dpbusd_avx_epi32(acc2, va2, vb2);
            const __m256i va3 = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(a + i + 96));
            const __m256i vb3 = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(b + i + 96));
            acc3 = _mm256_dpbusd_avx_epi32(acc3, va3, vb3);
        }
        for (; chunk_end - i >= 32; i += 32) {
            const __m256i va = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(a + i));
            const __m256i vb = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(b + i));
            acc0 = _mm256_dpbusd_avx_epi32(acc0, va, vb);
        }
        // Drain this chunk's s32 lanes into the int64 total.
        const __m256i acc32 = _mm256_add_epi32(_mm256_add_epi32(acc0, acc1), _mm256_add_epi32(acc2, acc3));
        alignas(32) int32_t lanes[8];
        _mm256_store_si256(reinterpret_cast<__m256i*>(lanes), acc32);
        for (int k = 0; k < 8; ++k) {
            sum += lanes[k];
        }
    }
    // Scalar tail for the final n % 32 elements.
    for (; i < n; ++i) {
        sum += static_cast<int32_t>(a[i]) * static_cast<int32_t>(b[i]);
    }
    return sum;
}
#endif
// AVX2+FMA(float)示例:y[i] = a[i] * b[i] + c[i],并返回 sum(y) 用于防止优化。
static float avx2_fma_f32_fma_sum(const float* a, const float* b, const float* c, int n) {
__m256 acc0 = _mm256_setzero_ps();
__m256 acc1 = _mm256_setzero_ps();
__m256 acc2 = _mm256_setzero_ps();
__m256 acc3 = _mm256_setzero_ps();
int i = 0;
for (; i + 31 < n; i += 32) {
const __m256 a0 = _mm256_loadu_ps(a + i);
const __m256 b0 = _mm256_loadu_ps(b + i);
const __m256 c0 = _mm256_loadu_ps(c + i);
acc0 = _mm256_add_ps(acc0, _mm256_fmadd_ps(a0, b0, c0));
const __m256 a1 = _mm256_loadu_ps(a + i + 8);
const __m256 b1 = _mm256_loadu_ps(b + i + 8);
const __m256 c1 = _mm256_loadu_ps(c + i + 8);
acc1 = _mm256_add_ps(acc1, _mm256_fmadd_ps(a1, b1, c1));
const __m256 a2 = _mm256_loadu_ps(a + i + 16);
const __m256 b2 = _mm256_loadu_ps(b + i + 16);
const __m256 c2 = _mm256_loadu_ps(c + i + 16);
acc2 = _mm256_add_ps(acc2, _mm256_fmadd_ps(a2, b2, c2));
const __m256 a3 = _mm256_loadu_ps(a + i + 24);
const __m256 b3 = _mm256_loadu_ps(b + i + 24);
const __m256 c3 = _mm256_loadu_ps(c + i + 24);
acc3 = _mm256_add_ps(acc3, _mm256_fmadd_ps(a3, b3, c3));
}
__m256 acc = _mm256_add_ps(_mm256_add_ps(acc0, acc1), _mm256_add_ps(acc2, acc3));
alignas(32) float lanes[8] = {};
_mm256_store_ps(lanes, acc);
float sum = 0.0f;
for (int k = 0; k < 8; ++k) {
sum += lanes[k];
}
for (; i < n; ++i) {
sum += a[i] * b[i] + c[i];
}
return sum;
}
// Scalar float reference: sum over i of (a[i] * b[i] + c[i]), accumulated in
// index order. Used to check the mixed AVX2+FMA / VNNI run.
static float scalar_fma_f32_sum(const float* a, const float* b, const float* c, int n) {
    float total = 0.0f;
    int idx = 0;
    while (idx < n) {
        total += a[idx] * b[idx] + c[idx];
        ++idx;
    }
    return total;
}
// Minimal teaching demo for AVX2+FMA. Eight hard-coded floats (exactly one
// 256-bit vector) are evaluated as y[i] = a[i] * b[i] + c[i] twice: once with a
// scalar loop and once with a single _mm256_fmadd_ps. Every element is printed
// so the two paths can be compared side by side.
static void explain_avx2_fma_minimal_demo() {
    // Fixed inputs: eight floats each, i.e. one full 256-bit register.
    alignas(32) const float a[8] = { 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f };
    alignas(32) const float b[8] = { 0.5f, -1.0f, 1.5f, 2.0f, -0.5f, 3.0f, 0.25f, -2.0f };
    alignas(32) const float c[8] = { 0.1f, -0.7f, 2.2f, 1.3f, -1.1f, 0.6f, 3.0f, -0.4f };

    // Prints "<prefix><idx>: a * b + c = value" for one element.
    const auto print_term = [&](const char* prefix, int idx, float value) {
        std::cout << prefix << idx
                  << ": " << a[idx] << " * " << b[idx] << " + " << c[idx]
                  << " = " << value << std::endl;
    };

    std::cout << "\n=== AVX2+FMA Minimal Demo (8 elements) ===" << std::endl;
    std::cout << "Target: y[i] = a[i] * b[i] + c[i]" << std::endl;
    std::cout << "a = [1,2,3,4,5,6,7,8]" << std::endl;
    std::cout << "b = [0.5,-1,1.5,2,-0.5,3,0.25,-2]" << std::endl;
    std::cout << "c = [0.1,-0.7,2.2,1.3,-1.1,0.6,3,-0.4]" << std::endl;

    // Step 1: scalar reference, one element at a time.
    alignas(32) float scalar_y[8] = {};
    float scalar_sum = 0.0f;
    for (int lane = 0; lane < 8; ++lane) {
        scalar_y[lane] = a[lane] * b[lane] + c[lane];
        scalar_sum += scalar_y[lane];
        print_term("Scalar step ", lane, scalar_y[lane]);
    }

    // Steps 2-3: load each input into one 256-bit register and let a single FMA
    // compute vy[lane] = va[lane] * vb[lane] + vc[lane] for all eight lanes.
    const __m256 vy = _mm256_fmadd_ps(_mm256_load_ps(a), _mm256_load_ps(b), _mm256_load_ps(c));

    // Step 4: store the vector back to memory to inspect each lane, then sum.
    alignas(32) float avx_y[8] = {};
    _mm256_store_ps(avx_y, vy);
    float avx_sum = 0.0f;
    for (int lane = 0; lane < 8; ++lane) {
        avx_sum += avx_y[lane];
        print_term("AVX lane ", lane, avx_y[lane]);
    }
    std::cout << "Scalar sum = " << scalar_sum << ", AVX sum = " << avx_sum
              << ", abs_err = " << std::abs(scalar_sum - avx_sum) << std::endl;
}
#if defined(__AVXVNNI__)
// Minimal teaching demo for AVX-VNNI. 32 hard-coded u8/s8 pairs (one 256-bit
// register each) are reduced to a dot product twice: once with a scalar loop
// that prints every term, and once with a single _mm256_dpbusd_avx_epi32. The
// eight per-lane partial sums are printed before they are added up.
static void explain_avx_vnni_minimal_demo() {
    // Hard-coded so each scalar term can be matched against the VNNI lanes.
    alignas(32) const uint8_t a[32] = {
        1, 2, 3, 4, 5, 6, 7, 8,
        9, 10, 11, 12, 13, 14, 15, 16,
        17, 18, 19, 20, 21, 22, 23, 24,
        25, 26, 27, 28, 29, 30, 31, 32
    };
    alignas(32) const int8_t b[32] = {
        -3, -2, -1, 0, 1, 2, 3, -3,
        -2, -1, 0, 1, 2, 3, -3, -2,
        -1, 0, 1, 2, 3, -3, -2, -1,
        0, 1, 2, 3, -3, -2, -1, 0
    };

    std::cout << "\n=== AVX-VNNI Minimal Demo (32 u8/s8 elements) ===" << std::endl;
    std::cout << "Target: dot = sum(a[i] * b[i])" << std::endl;
    std::cout << "a = [1,2,3,...,32]" << std::endl;
    std::cout << "b = [-3,-2,-1,0,1,2,3,-3,...]" << std::endl;

    // Step 1: scalar reference with a running total.
    int32_t scalar_sum = 0;
    for (int idx = 0; idx < 32; ++idx) {
        const int32_t lhs = a[idx];
        const int32_t rhs = b[idx];
        const int32_t term = lhs * rhs;
        scalar_sum += term;
        std::cout << "Scalar step " << idx
                  << ": " << lhs << " * " << rhs
                  << " = " << term << ", running_sum = " << scalar_sum << std::endl;
    }

    // Steps 2-3: one dpbusd over the 32 bytes of each input:
    // acc[lane] += a[4*lane+0]*b[4*lane+0] + ... + a[4*lane+3]*b[4*lane+3]
    const __m256i acc = _mm256_dpbusd_avx_epi32(
        _mm256_setzero_si256(),
        _mm256_load_si256(reinterpret_cast<const __m256i*>(a)),
        _mm256_load_si256(reinterpret_cast<const __m256i*>(b)));

    // Step 4: show each lane's partial sum, then add them up.
    alignas(32) int32_t partial[8] = {};
    _mm256_store_si256(reinterpret_cast<__m256i*>(partial), acc);
    int32_t vnni_sum = 0;
    for (int lane = 0; lane < 8; ++lane) {
        vnni_sum += partial[lane];
        std::cout << "VNNI lane " << lane << " partial_sum = " << partial[lane] << std::endl;
    }
    std::cout << "Scalar sum = " << scalar_sum
              << ", VNNI sum = " << vnni_sum
              << ", abs_err = " << std::abs(scalar_sum - vnni_sum) << std::endl;
}
#endif
#if defined(__AVXVNNI__)
// Mixed workload for one pass: runs the AVX2+FMA float kernel over fa/fb/fc
// (fn elements), then the AVX-VNNI u8/s8 dot product over ia/ib (in elements).
// Returns {float sum, int dot product}.
static std::pair<float, int64_t> mixed_avx2_fma_vnni(
    const float* fa, const float* fb, const float* fc, int fn,
    const uint8_t* ia, const int8_t* ib, int in) {
    std::pair<float, int64_t> out;
    out.first = avx2_fma_f32_fma_sum(fa, fb, fc, fn);
    out.second = avx_vnni_dot_u8s8(ia, ib, in);
    return out;
}
#endif
// Calls `fn` `iterations` times back to back and returns the elapsed wall time
// in microseconds. Every return value is XOR-folded into the volatile `sink` so
// the calls cannot be discarded as dead code.
// steady_clock is used because it is monotonic. high_resolution_clock may be an
// alias of system_clock, which can jump while a measurement is in progress.
template <typename Func>
static double benchmark_us(Func&& fn, int iterations, volatile int64_t& sink) {
    using Clock = std::chrono::steady_clock;
    const Clock::time_point begin = Clock::now();
    for (int i = 0; i < iterations; ++i) {
        sink ^= fn();
    }
    const Clock::time_point end = Clock::now();
    return std::chrono::duration<double, std::micro>(end - begin).count();
}
// Parameters for one benchmark group run by run_bench_suite.
// Default member initializers keep a default-constructed config well defined
// even when it is not value-initialized with `{}`.
struct BenchConfig {
    std::string name;       // label printed in the group header
    int effective_n = 0;    // number of leading input elements each kernel processes
    int warmup = 0;         // untimed calls per kernel before measuring
    int iterations = 0;     // timed calls per kernel
    bool run_simd = false;  // when false, only the scalar baseline runs
};
// Prints one benchmark line to std::cout:
//   [method] result = R, abs_err = E, rel_err = Q, total = T us, per pass = P us[, speedup = Sx]
// where P = total_us / iterations. The speedup suffix is printed only when
// speedup > 0; callers pass 0.0 for the baseline rows.
static void print_method_result(
    const std::string& method,
    int64_t result,
    int64_t abs_err,
    double rel_err,
    double total_us,
    int iterations,
    double speedup) {
    const double per_pass_us = total_us / iterations;
    std::ostream& out = std::cout;
    out << "[" << method << "] result = " << result;
    out << ", abs_err = " << abs_err;
    out << ", rel_err = " << rel_err;
    out << ", total = " << total_us << " us";
    out << ", per pass = " << per_pass_us << " us";
    if (speedup > 0.0) {
        out << ", speedup = " << speedup << "x";
    }
    out << std::endl;
}
// Runs one benchmark group on the first cfg.effective_n elements of a and b.
// The 4-accumulator scalar kernel is always timed; it provides both the
// reference result and the speedup baseline. When cfg.run_simd is set, the AVX2
// kernel and (if compiled in) the AVX-VNNI kernel are also timed and their
// results are compared against the scalar reference.
static void run_bench_suite(const BenchConfig& cfg, const std::vector<uint8_t>& a, const std::vector<int8_t>& b) {
    volatile int64_t sink = 0;
    const uint8_t* const pa = a.data();
    const int8_t* const pb = b.data();
    const int len = cfg.effective_n;

    // Magnitude of `err` relative to `ref`, or 0 when the reference is 0.
    const auto relative_error = [](int64_t err, int64_t ref) -> double {
        return (ref == 0) ? 0.0 : std::abs(static_cast<double>(err) / static_cast<double>(ref));
    };

    // Baseline: 4-accumulator scalar kernel.
    const auto scalar_kernel = [&]() { return scalar_dot_u8s8(pa, pb, len); };
    for (int w = 0; w < cfg.warmup; ++w) {
        sink ^= scalar_kernel();
    }
    const double scalar_us = benchmark_us(scalar_kernel, cfg.iterations, sink);
    const int64_t scalar_res = scalar_kernel();

    std::cout << "\n=== " << cfg.name << " ===" << std::endl;
    std::cout << "effective_n = " << cfg.effective_n
              << ", warmup = " << cfg.warmup
              << ", iterations = " << cfg.iterations << std::endl;
    print_method_result("Scalar", scalar_res, 0, 0.0, scalar_us, cfg.iterations, 0.0);

    if (!cfg.run_simd) {
        return;
    }

    if (!cpu_supports_avx2()) {
        std::cout << "[AVX2 ] unsupported at runtime." << std::endl;
    }
    else {
        const auto avx2_kernel = [&]() { return avx2_dot_u8s8(pa, pb, len); };
        for (int w = 0; w < cfg.warmup; ++w) {
            sink ^= avx2_kernel();
        }
        const double avx2_us = benchmark_us(avx2_kernel, cfg.iterations, sink);
        const int64_t avx2_res = avx2_kernel();
        const int64_t diff = avx2_res - scalar_res;
        print_method_result("AVX2 ", avx2_res, diff, relative_error(diff, scalar_res),
                            avx2_us, cfg.iterations, scalar_us / avx2_us);
    }

#if defined(__AVXVNNI__)
    if (!cpu_supports_avx_vnni()) {
        std::cout << "[VNNI ] runtime unsupported." << std::endl;
    }
    else {
        const auto vnni_kernel = [&]() { return avx_vnni_dot_u8s8(pa, pb, len); };
        for (int w = 0; w < cfg.warmup; ++w) {
            sink ^= vnni_kernel();
        }
        const double vnni_us = benchmark_us(vnni_kernel, cfg.iterations, sink);
        const int64_t vnni_res = vnni_kernel();
        const int64_t diff = vnni_res - scalar_res;
        print_method_result("VNNI ", vnni_res, diff, relative_error(diff, scalar_res),
                            vnni_us, cfg.iterations, scalar_us / vnni_us);
    }
#else
    std::cout << "[VNNI ] not compiled with AVX-VNNI intrinsics." << std::endl;
#endif
}
// CPUID.1:ECX bit 12 reports FMA3. _mm256_fmadd_ps (used by the AVX2+FMA demo
// and the mixed test) needs it in addition to AVX2. The two features are
// reported by separate CPUID bits, so AVX2 support alone does not guarantee FMA.
static bool cpu_supports_fma() {
    if (!os_supports_ymm_state()) {
        return false;
    }
    int cpu_info[4] = {};
    __cpuid(cpu_info, 1);
    return (cpu_info[2] & (1 << 12)) != 0;
}

// Timed-iteration count for the size-dependent benchmark groups. Small inputs
// get more iterations so the total measured time stays meaningful.
static int iterations_for_size(int n) {
    if (n <= 256) {
        return 50000;
    }
    if (n <= 4096) {
        return 10000;
    }
    return 2000;
}

// Benchmark group 1: the single-accumulator scalar kernel over the first n elements.
static void run_serial_baseline(const std::vector<uint8_t>& a, const std::vector<int8_t>& b, int n) {
    volatile int64_t sink = 0;
    const int serial_iters = iterations_for_size(n);
    const int serial_warmup = 100;
    for (int i = 0; i < serial_warmup; ++i) {
        sink ^= scalar_dot_u8s8_serial(a.data(), b.data(), n);
    }
    const double serial_us = benchmark_us([&]() { return scalar_dot_u8s8_serial(a.data(), b.data(), n); }, serial_iters, sink);
    const int64_t serial_res = scalar_dot_u8s8_serial(a.data(), b.data(), n);
    std::cout << "\n=== Serial (single accumulator scalar) ===" << std::endl;
    std::cout << "effective_n = " << n
              << ", warmup = " << serial_warmup
              << ", iterations = " << serial_iters << std::endl;
    print_method_result("Serial", serial_res, 0, 0.0, serial_us, serial_iters, 0.0);
}

#if defined(__AVXVNNI__)
// Benchmark group 4: each pass runs the float FMA kernel and the int8 VNNI
// kernel back to back (via mixed_avx2_fma_vnni). The final results are then
// checked against scalar references. Callers must have verified AVX2, FMA and
// AVX-VNNI support at runtime.
static void run_mixed_throughput(const std::vector<uint8_t>& a, const std::vector<int8_t>& b, int n) {
    // Both the float and the int path use the same length as the other tests.
    const int f_n = n;
    const int i_n = n;
    std::vector<float> fa(f_n), fb(f_n), fc(f_n);
    for (int i = 0; i < f_n; ++i) {
        fa[i] = static_cast<float>((i % 113) * 0.03125f + 1.0f);
        fb[i] = static_cast<float>((i % 97) * 0.015625f + 0.5f);
        fc[i] = static_cast<float>((i % 89) * 0.0078125f + 0.25f);
    }
    const int warmup = 100;
    const int iters = iterations_for_size(n);
    volatile int64_t mix_sink = 0;
    volatile float mix_f_sink = 0.0f;
    for (int i = 0; i < warmup; ++i) {
        const auto r = mixed_avx2_fma_vnni(fa.data(), fb.data(), fc.data(), f_n, a.data(), b.data(), i_n);
        mix_f_sink += r.first;
        mix_sink ^= r.second;
    }
    // steady_clock: monotonic, unlike high_resolution_clock on some standard libraries.
    const auto begin = std::chrono::steady_clock::now();
    for (int i = 0; i < iters; ++i) {
        const auto r = mixed_avx2_fma_vnni(fa.data(), fb.data(), fc.data(), f_n, a.data(), b.data(), i_n);
        mix_f_sink += r.first;
        mix_sink ^= r.second;
    }
    const auto end = std::chrono::steady_clock::now();
    const double mix_us = std::chrono::duration<double, std::micro>(end - begin).count();
    const auto final_r = mixed_avx2_fma_vnni(fa.data(), fb.data(), fc.data(), f_n, a.data(), b.data(), i_n);
    const float float_scalar_ref = scalar_fma_f32_sum(fa.data(), fb.data(), fc.data(), f_n);
    const int64_t int_scalar_ref = scalar_dot_u8s8_serial(a.data(), b.data(), i_n);
    const float float_abs_err = std::abs(final_r.first - float_scalar_ref);
    const int64_t int_abs_err = final_r.second - int_scalar_ref;
    std::cout << "\n=== Mixed Throughput (AVX2+FMA + VNNI) ===" << std::endl;
    std::cout << "float_n = " << f_n << ", int_n = " << i_n
              << ", warmup = " << warmup << ", iterations = " << iters << std::endl;
    std::cout << "[Mixed] float_result = " << final_r.first
              << ", int_result = " << final_r.second
              << ", total = " << mix_us << " us"
              << ", per pass = " << (mix_us / iters) << " us" << std::endl;
    std::cout << "[Mixed-Check] float_scalar_ref = " << float_scalar_ref
              << ", float_abs_err = " << float_abs_err << std::endl;
    std::cout << "[Mixed-Check] int_scalar_ref = " << int_scalar_ref
              << ", int_abs_err = " << int_abs_err << std::endl;
    (void)mix_sink;
    (void)mix_f_sink;
}
#endif

int main() {
    int n = 0;
    std::cout << "Input n (u8/s8 vector length): ";
    if (!(std::cin >> n) || n <= 0) {
        std::cerr << "Invalid input." << std::endl;
        return 1;
    }
    // Input data shared by every benchmark group so results are directly comparable:
    //   a[i] (u8): cycles through 0..250  -> a = [0, 1, 2, 3, 4, ...]
    //   b[i] (s8): cycles through -63..63 -> b = [-63, -62, -61, -60, -59, ...]
    std::vector<uint8_t> a(n);
    std::vector<int8_t> b(n);
    for (int i = 0; i < n; ++i) {
        a[i] = static_cast<uint8_t>(i % 251);
        b[i] = static_cast<int8_t>((i % 127) - 63);
    }
    std::cout << std::fixed << std::setprecision(6);
    std::cout << "n(input) = " << n << std::endl;

    // Step-by-step AVX2+FMA walkthrough. _mm256_fmadd_ps needs FMA as well as AVX2.
    if (cpu_supports_avx2() && cpu_supports_fma()) {
        explain_avx2_fma_minimal_demo();
    }
    else {
        std::cout << "\n=== AVX2+FMA Minimal Demo (8 elements) ===" << std::endl;
        std::cout << "Current CPU/OS does not support AVX2+FMA, demo skipped." << std::endl;
    }
#if defined(__AVXVNNI__)
    if (cpu_supports_avx_vnni()) {
        explain_avx_vnni_minimal_demo();
    }
    else {
        std::cout << "\n=== AVX-VNNI Minimal Demo (32 u8/s8 elements) ===" << std::endl;
        std::cout << "Current CPU/OS does not support AVX-VNNI, demo skipped." << std::endl;
    }
#else
    std::cout << "\n=== AVX-VNNI Minimal Demo (32 u8/s8 elements) ===" << std::endl;
    std::cout << "This build does not enable AVX-VNNI intrinsics, demo skipped." << std::endl;
#endif

    // 1) Serial baseline: single-accumulator scalar kernel only.
    run_serial_baseline(a, b, n);

    // 2) Streaming: the full input, closer to a linear memory-streaming workload.
    {
        BenchConfig cfg{};
        cfg.name = "Streaming (full input size)";
        cfg.effective_n = n;
        cfg.warmup = 100;
        cfg.iterations = iterations_for_size(n);
        cfg.run_simd = true;
        run_bench_suite(cfg, a, b);
    }

    // 3) L1/L2-friendly: cap effective_n at 4096 so the working set tends to stay
    //    cache-resident and compute throughput dominates.
    {
        BenchConfig cfg{};
        cfg.name = "L1/L2-friendly (cache-resident working set)";
        cfg.effective_n = (n < 4096) ? n : 4096;
        cfg.warmup = 500;
        cfg.iterations = 80000;
        cfg.run_simd = true;
        run_bench_suite(cfg, a, b);
    }

    // 4) Mixed throughput: float FMA and int8 VNNI in the same pass.
    //    Requires AVX2, FMA and AVX-VNNI at runtime, and VNNI at compile time.
#if defined(__AVXVNNI__)
    const bool mixed_supported = cpu_supports_avx2() && cpu_supports_fma() && cpu_supports_avx_vnni();
#else
    const bool mixed_supported = false;
#endif
    if (mixed_supported) {
#if defined(__AVXVNNI__)
        run_mixed_throughput(a, b, n);
#endif
    }
    else {
        std::cout << "\n=== Mixed Throughput (AVX2+FMA + VNNI) ===" << std::endl;
        std::cout << "[Mixed] skipped: requires AVX2, FMA and AVX-VNNI runtime/compiler support." << std::endl;
    }
    return 0;
}
从运行结果可以看到,使用专用加速指令集(AVX2、AVX-VNNI)的内核确实比标量实现快很多。以下是 n = 1024 时的一次运行输出:
Input n (u8/s8 vector length): 1024
n(input) = 1024
=== AVX2+FMA Minimal Demo (8 elements) ===
Target: y[i] = a[i] * b[i] + c[i]
a = [1,2,3,4,5,6,7,8]
b = [0.5,-1,1.5,2,-0.5,3,0.25,-2]
c = [0.1,-0.7,2.2,1.3,-1.1,0.6,3,-0.4]
Scalar step 0: 1.000000 * 0.500000 + 0.100000 = 0.600000
Scalar step 1: 2.000000 * -1.000000 + -0.700000 = -2.700000
Scalar step 2: 3.000000 * 1.500000 + 2.200000 = 6.700000
Scalar step 3: 4.000000 * 2.000000 + 1.300000 = 9.300000
Scalar step 4: 5.000000 * -0.500000 + -1.100000 = -3.600000
Scalar step 5: 6.000000 * 3.000000 + 0.600000 = 18.600000
Scalar step 6: 7.000000 * 0.250000 + 3.000000 = 4.750000
Scalar step 7: 8.000000 * -2.000000 + -0.400000 = -16.400000
AVX lane 0: 1.000000 * 0.500000 + 0.100000 = 0.600000
AVX lane 1: 2.000000 * -1.000000 + -0.700000 = -2.700000
AVX lane 2: 3.000000 * 1.500000 + 2.200000 = 6.700000
AVX lane 3: 4.000000 * 2.000000 + 1.300000 = 9.300000
AVX lane 4: 5.000000 * -0.500000 + -1.100000 = -3.600000
AVX lane 5: 6.000000 * 3.000000 + 0.600000 = 18.600000
AVX lane 6: 7.000000 * 0.250000 + 3.000000 = 4.750000
AVX lane 7: 8.000000 * -2.000000 + -0.400000 = -16.400000
Scalar sum = 17.250002, AVX sum = 17.250002, abs_err = 0.000000
=== AVX-VNNI Minimal Demo (32 u8/s8 elements) ===
Target: dot = sum(a[i] * b[i])
a = [1,2,3,...,32]
b = [-3,-2,-1,0,1,2,3,-3,...]
Scalar step 0: 1 * -3 = -3, running_sum = -3
Scalar step 1: 2 * -2 = -4, running_sum = -7
Scalar step 2: 3 * -1 = -3, running_sum = -10
Scalar step 3: 4 * 0 = 0, running_sum = -10
Scalar step 4: 5 * 1 = 5, running_sum = -5
Scalar step 5: 6 * 2 = 12, running_sum = 7
Scalar step 6: 7 * 3 = 21, running_sum = 28
Scalar step 7: 8 * -3 = -24, running_sum = 4
Scalar step 8: 9 * -2 = -18, running_sum = -14
Scalar step 9: 10 * -1 = -10, running_sum = -24
Scalar step 10: 11 * 0 = 0, running_sum = -24
Scalar step 11: 12 * 1 = 12, running_sum = -12
Scalar step 12: 13 * 2 = 26, running_sum = 14
Scalar step 13: 14 * 3 = 42, running_sum = 56
Scalar step 14: 15 * -3 = -45, running_sum = 11
Scalar step 15: 16 * -2 = -32, running_sum = -21
Scalar step 16: 17 * -1 = -17, running_sum = -38
Scalar step 17: 18 * 0 = 0, running_sum = -38
Scalar step 18: 19 * 1 = 19, running_sum = -19
Scalar step 19: 20 * 2 = 40, running_sum = 21
Scalar step 20: 21 * 3 = 63, running_sum = 84
Scalar step 21: 22 * -3 = -66, running_sum = 18
Scalar step 22: 23 * -2 = -46, running_sum = -28
Scalar step 23: 24 * -1 = -24, running_sum = -52
Scalar step 24: 25 * 0 = 0, running_sum = -52
Scalar step 25: 26 * 1 = 26, running_sum = -26
Scalar step 26: 27 * 2 = 54, running_sum = 28
Scalar step 27: 28 * 3 = 84, running_sum = 112
Scalar step 28: 29 * -3 = -87, running_sum = 25
Scalar step 29: 30 * -2 = -60, running_sum = -35
Scalar step 30: 31 * -1 = -31, running_sum = -66
Scalar step 31: 32 * 0 = 0, running_sum = -66
VNNI lane 0 partial_sum = -10
VNNI lane 1 partial_sum = 14
VNNI lane 2 partial_sum = -16
VNNI lane 3 partial_sum = -9
VNNI lane 4 partial_sum = 42
VNNI lane 5 partial_sum = -73
VNNI lane 6 partial_sum = 164
VNNI lane 7 partial_sum = -178
Scalar sum = -66, VNNI sum = -66, abs_err = 0
=== Serial (single accumulator scalar) ===
effective_n = 1024, warmup = 100, iterations = 10000
[Serial] result = 913898, abs_err = 0, rel_err = 0.000000, total = 7224.700000 us, per pass = 0.722470 us
=== Streaming (full input size) ===
effective_n = 1024, warmup = 100, iterations = 10000
[Scalar] result = 913898, abs_err = 0, rel_err = 0.000000, total = 6741.500000 us, per pass = 0.674150 us
[AVX2 ] result = 913898, abs_err = 0, rel_err = 0.000000, total = 1121.200000 us, per pass = 0.112120 us, speedup = 6.012754x
[VNNI ] result = 913898, abs_err = 0, rel_err = 0.000000, total = 780.200000 us, per pass = 0.078020 us, speedup = 8.640733x
=== L1/L2-friendly (cache-resident working set) ===
effective_n = 1024, warmup = 500, iterations = 80000
[Scalar] result = 913898, abs_err = 0, rel_err = 0.000000, total = 51868.200000 us, per pass = 0.648352 us
[AVX2 ] result = 913898, abs_err = 0, rel_err = 0.000000, total = 8255.100000 us, per pass = 0.103189 us, speedup = 6.283170x
[VNNI ] result = 913898, abs_err = 0, rel_err = 0.000000, total = 6119.500000 us, per pass = 0.076494 us, speedup = 8.475889x
=== Mixed Throughput (AVX2+FMA + VNNI) ===
float_n = 1024, int_n = 1024, warmup = 100, iterations = 10000
[Mixed] float_result = 4065.320801, int_result = 913898, total = 4289.400000 us, per pass = 0.428940 us
[Mixed-Check] float_scalar_ref = 4065.320801, float_abs_err = 0.000000
[Mixed-Check] int_scalar_ref = 913898, int_abs_err = 0