在STM32上跑AI推理:TFLM从模型转换到部署的全流程解析

一个Cortex-M4内核的MCU,主频不到200MHz,RAM只有192KB------这样的硬件能跑AI吗?

三年前我可能摇头。但现在,TensorFlow Lite Micro (TFLM) 让这件事变得比想象中更直接。


先把模型塞进单片机

关键不在于"能不能跑",而在于"怎么跑"。PC上的神经网络动辄几百MB参数,FP16推理都嫌慢。但嵌入式的路子不一样------我们不需要大模型,我们需要的是刚好够用的模型

拿一个常见场景来说:在STM32F407上做关键词唤醒(KWS),检测"Hey STM32"这句唤醒词。完整的数据流是这样的:

复制代码
音频采样(16kHz) → 预处理(MFCC) → 模型推理 → 后处理 → 判定结果

这其中的推理部分,用的是一个只有12KB的depthwise separable卷积网络。你没看错,12KB。

先看模型怎么转换成单片机认得的东西。用TensorFlow训练好一个.h5.keras模型,然后走这样一条路:

复制代码
.h5 → TFLite Float → TFLite量化(Int8) → C数组(header file)

命令其实就几行:

复制代码
import tensorflow as tf

converter = tf.lite.TFLiteConverter.from_keras_model(model)
converter.optimizations = [tf.lite.Optimize.DEFAULT]
converter.representative_dataset = representative_dataset
converter.target_spec.supported_types = [tf.int8]
tflite_quant_model = converter.convert()

# 转成C数组
with open('model_data.cc', 'w') as f:
    f.write('#include "model_data.h"\n')
    f.write('const unsigned char model_data[] = {\n')
    for byte in tflite_quant_model:
        f.write(f'{byte},')
    f.write('};\n')

关键的一步是量化和 representative dataset。量化从FP32压到Int8,模型体积直接缩到四分之一,推理速度翻几倍。representative dataset 是校准用的------你给它几百条真实场景的音频片段,它就能算出最佳的量化参数,让精度损失控制在极小范围。

一个有意思的设计是,TFLM的推理完全不依赖操作系统。没有malloc,没有文件系统,没有任何POSIX调用。所有内存都在编译期静态分配。

单片机上的推理引擎怎么运作

TFLM在MCU上的运行时结构很清爽。它不搞动态加载那套,所有算子都在编译期注册好。我们用到的算子就三个:Conv2D、DepthwiseConv2D、FullyConnected,加上一些reshape和softmax。

初始化代码如下:

复制代码
#include "tensorflow/lite/micro/micro_interpreter.h"
#include "tensorflow/lite/micro/micro_mutable_op_resolver.h"

static tflite::MicroMutableOpResolver<6> resolver;
resolver.AddConv2D();
resolver.AddDepthwiseConv2D();
resolver.AddFullyConnected();
resolver.AddSoftmax();
resolver.AddReshape();
resolver.AddLogistic();

static tflite::MicroInterpreter static_interpreter(
    tflite::GetModel(model_data), resolver,
    tensor_arena, kTensorArenaSize);

kTensorArenaSize 怎么定?根据模型算。我们直接给了个足够大的值然后看实际占用------TFLM在初始化后会告诉你实际用了多少,然后可以回来调小。反复两三次就能压到最紧凑。

推理本身就是一个函数调用:

复制代码
// 把MFCC特征塞到输入张量
memcpy(interpreter->input(0)->data.int8, mfcc_features, kFeatureCount);

// 跑推理
interpreter->Invoke();

// 读输出
int8_t score = interpreter->output(0)->data.int8[0];

Invoke 走完一次大约花了28ms。在16kHz的音频流上,我们每帧取40ms的窗口,步进20ms,所以推理甚至比窗口滑动还快,实时性完全够。

让人意外的是arena分配

这里有一个常见误区:把tensor arena设得很大就觉得稳了。其实arena里不仅放中间张量,还放算子的临时buffer、scratch buffers。不同算子的内存需求差异非常大------DepthwiseConv2D的scratch buffer可能比它本身的输出张量还大几倍。

我们踩过这个坑。一开始arena给了24KB,跑float模型没问题,换int8量化模型反而报错------因为int8的DepthwiseConv2D用了更多scratch buffer做中间累加。解决方案是把arena加大到48KB,然后调用 interpreter->arena_used_bytes() 一看,实际只用了31KB。

所以正确的做法是:先给个大数字,跑一次看实际用量,再压到实际值的1.2倍左右留余量。别拍脑袋给个48KB然后代码里写死,也别抠抠搜搜给个刚好值然后换模型就崩。

MFCC预处理这一步别轻视

很多人把精力全放在模型上,预处理随便抄一段。但MFCC的参数会直接影响模型能不能收敛。我们的配置是这样的:

  • 采样率:16kHz
  • FFT窗口:512点(32ms)
  • 帧移:160点(10ms)
  • Mel滤波器组:40个
  • MFCC系数:13维
  • 每帧拼接上下文:前后各3帧 → 共7帧×13维 = 91维特征

这个配置不是随便选的。之所以用13维而不是常用的20维,是因为我们的模型第一个卷积层是3×3的depthwise,输入通道数直接决定了参数量。每减少7维,参数量少了15%左右,而实测准确率只掉了0.3%。这是个划算的交换。

预处理全部在MCU上用定点数做,不用浮点。FFT用CMSIS-DSP的arm_rfft_fast_f32,但输出转成Q15定点后再算Mel滤波器组------这样能避免每次FFT都走一遍浮点转定点的开销。

代码里有一个小技巧:Mel滤波器组的权重矩阵提前在PC上算好,存成uint16_t的查找表,单片机只做乘加移位,不现场算三角滤波器。

复制代码
static const uint16_t mel_filterbank[40][256] = {
    // PC预计算好的权重
};

static void extract_mfcc(const int16_t *audio_frame, int8_t *features) {
    // arm_rfft_fast_f32 跑FFT
    arm_rfft_fast_f32(&fft_instance, float_buf, float_buf, 0);

    // 定点乘加算mel能量
    for (int m = 0; m < 40; m++) {
        uint32_t energy = 0;
        for (int k = 0; k < 256; k++) {
            uint32_t mag_sq = (uint32_t)float_buf[k] * (uint32_t)float_buf[k];
            energy += (mag_sq * mel_filterbank[m][k]) >> 16;
        }
        mel_energy[m] = energy;
    }
    // ...log压缩和DCT省略
}

整段MFCC提取在STM32F407上跑一次大约8ms。加上28ms的推理,整个pipeline一帧处理耗时约36ms。帧移是20ms,所以CPU占用率大约180%------看起来超了?别忘了我们用ping-pong buffer+DMA双缓冲,MFCC处理一帧的同时,DMA在采下一帧,推理又在处理上一帧的MFCC结果。三级流水线一拉开,有效延迟只是单帧的处理时间。

讲讲输出怎么用

模型输出的一个int8值,范围是-128, 127。我们把它映射回概率:

复制代码
float confidence = (score - (-128)) / (127.0f - (-128.0f));
if (confidence > 0.7f) {
    // 检测到唤醒词
    trigger_wakeup();
}

0.7这个阈值不是随便定的。我们采集了200段非唤醒词的音频做负样本测试,发现误触发率在阈值0.5时是每天3.7次,提到0.7降到了每天0.4次,而唤醒率只从97%降到了94%。多数场景下少误触发比偶尔一次没唤醒更让人舒服。

如果你做的是安防类应用,阈值可以再往低调,宁可误报不可漏报。门禁场景则相反,误开一次门比漏开一次代价大得多。这是在部署阶段需要根据实际场景调的,训练时定不了。

这是今天想分享的内容。你的应用场景里,可靠性和灵敏度的天平往哪边倾斜?