// NOTE(review): stray fragment from the copy/paste source (duplicates a line
// inside CPGC_ALG_NILM_RUN_T1); commented out so the file can compile.
// int8_to_float(output_tensor->data.int8, output_float, load_class_num);

#include <cmath>    // roundf
#include <cstddef>  // size_t, max_align_t
#include <cstdint>  // int8_t, uint8_t, int32_t
#include <cstring>  // memset, memcpy
#include <new>      // placement new

#include "tensorflow/lite/core/c/common.h"
#include "tensorflow/lite/micro/micro_interpreter.h"
#include "tensorflow/lite/micro/micro_log.h"
#include "tensorflow/lite/micro/micro_mutable_op_resolver.h"
#include "tensorflow/lite/micro/system_setup.h"
#include "tensorflow/lite/micro/tflite_bridge/micro_error_reporter.h"
#include "tensorflow/lite/schema/schema_generated.h"

// External model data (must be defined in another translation unit,
// e.g. a generated model_data.cc holding the .tflite flatbuffer).

extern const char g_model_data[];

extern const int g_model_data_len;

// Tensor arena for TFLM, 32-byte aligned, placed in a dedicated linker section.
// FIX: the original wrote `attribute ((...))`, which is not valid C++; the GCC
// spelling is `__attribute__((...))`. The second, duplicate `aligned(32)` on
// the definition was redundant and is folded into the single attribute list.
__attribute__((section(".tensor_arena"), aligned(32)))
uint8_t g_tensor_arena[32 * 1024] = {0};

extern const int g_tensor_arena_len;

// Quantization parameters (must match the values produced during Python training)

const float input_scale = 0.05299549177289009f;

const int32_t input_zero_point = 4;

const float output_scale = 0.00390625f;

const int32_t output_zero_point = -128;

// Other declarations

const int kMaxOps = 13; // number of supported operators

const int load_class_num = 5; // number of output classes

namespace {

// FIX: MicroMutableOpResolver is a class template; it requires the maximum
// number of registered operators as a template argument. kMaxOps (13) matches
// the thirteen Add* calls in RegisterOps below.
using OpResolver = tflite::MicroMutableOpResolver<kMaxOps>;

}  // namespace

// Algorithm context: groups the per-run TFLM objects, placement-allocated
// from the memory pool.

struct ModelContext {

OpResolver resolver; // NOTE(review): a local resolver is used in the run function instead — confirm this member is needed

tflite::MicroInterpreter* interpreter; // placement-new'd into the memory pool

};

// Bump-pointer memory pool: sequential allocations from a fixed buffer,
// no per-allocation free; Reset() reclaims everything at once.
// FIX: removed the "复制代码" ("copy code") scrape artifact that sat inside the
// class body and broke compilation, and rounded allocation sizes up so every
// returned pointer is suitably aligned for any object type (the original could
// hand back misaligned storage for placement-new).
class MemoryPool {
 public:
  // Returns a pointer to `size` bytes from the pool, or nullptr when the
  // pool is exhausted. Sizes are rounded up to max_align_t alignment.
  void* Allocate(size_t size) {
    constexpr size_t kAlign = alignof(max_align_t);
    size = (size + kAlign - 1) & ~(kAlign - 1);
    if (current_position_ + size > pool_size_) {
      return nullptr;
    }
    void* ptr = &pool_[current_position_];
    current_position_ += size;
    return ptr;
  }

  // Discards all outstanding allocations; previously returned pointers
  // become invalid.
  void Reset() { current_position_ = 0; }

 private:
  static constexpr size_t pool_size_ = 64 * 1024;  // pool capacity in bytes
  alignas(alignof(max_align_t)) uint8_t pool_[pool_size_];  // backing storage
  size_t current_position_ = 0;  // offset of the next free byte
};

// Internal helper functions

int n1m_argmax(const float* data, int size); // index of the largest element

void n1m_input_float_to_int8(const float* input_f32, int8_t* output_i8, int size); // quantize input

void n1m_output_int8_to_float(const int8_t* input_i8, float* output_f32, int size); // dequantize output

// 注册支持的算子

TFliteStatus RegisterOps(OpResolver& resolver) {

TF_LITE_ENSURE_STATUS(resolver.AddFullyConnected());

TF_LITE_ENSURE_STATUS(resolver.AddConv2D());

TF_LITE_ENSURE_STATUS(resolver.AddDepthwiseConv2D());

TF_LITE_ENSURE_STATUS(resolver.AddMaxPool2D());

TF_LITE_ENSURE_STATUS(resolver.AddAveragePool2D());

TF_LITE_ENSURE_STATUS(resolver.AddSoftmax());

TF_LITE_ENSURE_STATUS(resolver.AddReshape());

TF_LITE_ENSURE_STATUS(resolver.AddExpandDims());

TF_LITE_ENSURE_STATUS(resolver.AddMean());

TF_LITE_ENSURE_STATUS(resolver.AddShape());

TF_LITE_ENSURE_STATUS(resolver.AddStridedSlice());

TF_LITE_ENSURE_STATUS(resolver.AddPack());

TF_LITE_ENSURE_STATUS(resolver.AddDequantize());

复制代码
return kTfLiteOk;

}

// Quantize a float buffer into int8 using the model's input scale and
// zero-point, saturating each value to the int8 range [-128, 127].
void n1m_input_float_to_int8(const float* input_f32, int8_t* output_i8, int size) {
  for (int idx = 0; idx < size; ++idx) {
    // Affine quantization: q = round(x / scale + zero_point).
    const float q = roundf(input_f32[idx] / input_scale + input_zero_point);
    // Saturate to the representable int8 range before narrowing.
    const int32_t clamped = (q > 127.0f)  ? 127
                            : (q < -128.0f) ? -128
                                            : static_cast<int32_t>(q);
    output_i8[idx] = static_cast<int8_t>(clamped);
  }
}

// Dequantize int8 model outputs back to float using the output scale and
// zero-point: f = (q - zero_point) * scale.
void n1m_output_int8_to_float(const int8_t* input_i8, float* output_f32, int size) {
  int remaining = size;
  int pos = 0;
  while (remaining-- > 0) {
    output_f32[pos] = (input_i8[pos] - output_zero_point) * output_scale;
    ++pos;
  }
}

// Returns the index of the largest element in data[0..num_classes).
// Ties resolve to the lowest index (strict > comparison).
// FIX: the original read data[0] unconditionally, which is out-of-bounds when
// num_classes <= 0; now returns -1 for an empty/null input instead.
int n1m_argmax(const float* data, int num_classes) {
  if (data == nullptr || num_classes <= 0) {
    return -1;
  }
  int max_index = 0;
  float max_value = data[0];
  for (int i = 1; i < num_classes; ++i) {
    if (data[i] > max_value) {
      max_value = data[i];
      max_index = i;
    }
  }
  return max_index;
}

// 主要推理函数

extern "C" int CPGC_ALG_NILM_RUN_T1(const float* VI_input, const int data_len) {

MemoryPool memory_pool; // 使用内存池

复制代码
// 初始化上下文
ModelContext* ctx = static_cast<ModelContext*>(memory_pool.Allocate(sizeof(ModelContext)));
if (!ctx) {
    MicroPrintf("Failed to allocate ModelContext");
    return CPGC_ALG_NILM_ERR_OUT_OF_MEMORY;
}

memset(g_tensor_arena, 0, sizeof(g_tensor_arena));

// 加载模型
const tflite::Model* model = tflite::GetModel(g_model_data);
if (!model) {
    MicroPrintf("Failed to load model - GetModel returned null");
    return CPGC_ALG_NILM_ERR_MODEL_INCOMPLETE;
}

if (model->version() != TFLITE_SCHEMA_VERSION) {
    MicroPrintf("Model version mismatch: got %d, expected %d", model->version(), TFLITE_SCHEMA_VERSION);
    return CPGC_ALG_NILM_ERR_MODEL_INCOMPLETE;
}

// 创建解析器并注册操作
OpResolver resolver;
TFliteStatus resolver_status = RegisterOps(resolver);
if (resolver_status != kTfLiteOk) {
    MicroPrintf("RegisterOps failed: %d", resolver_status);
    return CPGC_ALG_NILM_ERR_UNKNOWN_ERROR;
}

// 创建解释器
ctx->interpreter = new (memory_pool.Allocate(sizeof(tflite::MicroInterpreter))) tflite::MicroInterpreter(
    model,
    resolver,
    g_tensor_arena,
    sizeof(g_tensor_arena)
);

if (ctx->interpreter == nullptr) {
    MicroPrintf("Failed to create interpreter - out of memory");
    return CPGC_ALG_NILM_ERR_OUT_OF_MEMORY;
}

// 分配张量
TFliteStatus status = ctx->interpreter->AllocateTensors();
if (status != kTfLiteOk) {
    MicroPrintf("AllocateTensors failed: %d", status);
    MicroPrintf("Arena used: %u bytes", ctx->interpreter->arena_used_bytes());
    return CPGC_ALG_NILM_ERR_UNKNOWN_ERROR;
}

MicroPrintf("Model loaded successfully");
MicroPrintf("Arena used: %u bytes", ctx->interpreter->arena_used_bytes());

// 获取输入 tensor
TfLiteTensor* input_tensor = ctx->interpreter->input(0);
if (!input_tensor || !input_tensor->data.data) {
    MicroPrintf("ERROR: Failed to get input tensor");
    return CPGC_ALG_NILM_ERR_UNKNOWN_ERROR;
}

// 量化输入
int8_t quantized_input[256] = {0};
n1m_input_float_to_int8(VI_input, quantized_input, data_len);
memcpy(input_tensor->data.int8, quantized_input, data_len);

// 推理
TFliteStatus invoke_status = ctx->interpreter->Invoke();
if (invoke_status != kTfLiteOk) {
    MicroPrintf("Invoke failed: %d", invoke_status);
    return CPGC_ALG_NILM_ERR_UNKNOWN_ERROR;
}

// 获取输出并反量化
TfLiteTensor* output_tensor = ctx->interpreter->output(0);
if (!output_tensor || !output_tensor->data.data) {
    MicroPrintf("ERROR: Failed to get output tensor");
    return CPGC_ALG_NILM_ERR_UNKNOWN_ERROR;
}

// 打印原始 int8 输出
MicroPrintf("Raw int8 output (first 5):");
for (int i = 0; i < load_class_num; ++i) {
    MicroPrintf(" [%d]", output_tensor->data.int8[i]);
}

// 反量化为 float
float output_float[10];  // 假设输出大小 <=10
n1m_output_int8_to_float(output_tensor->data.int8, output_float, load_class_num);

MicroPrintf("Float output after dequantize (first 5):");
for (int i = 0; i < load_class_num; ++i) {
    MicroPrintf(" [%d]: %.6f", i, output_float[i]);
}

// 获取预测类别
int pred_label = n1m_argmax(output_float, load_class_num);
MicroPrintf("Prediction: Class %d", pred_label);

return pred_label;

}

/* NOTE(review): the text below is unrelated "related posts" material that was
   scraped along with the article; it is not C++ and is commented out so it
   cannot break compilation. Original text preserved verbatim:

相关推荐
喜欢打篮球的普通人2 天前
MLIR快速入门
neo4j·mlir
ELI_He9992 天前
Neo4j 安装 APOC
neo4j
綮地3 天前
Neo4j 基本处理
neo4j
lzp07913 天前
Neo4j图数据库学习(二)——SpringBoot整合Neo4j
数据库·学习·neo4j
爱折腾的小码农3 天前
neo4j数据库桌面管理工具
数据库·neo4j
Wenhao.7 天前
Docker 安装 neo4j
docker·容器·neo4j
RDCJM8 天前
Neo4j图数据库学习(二)——SpringBoot整合Neo4j
数据库·学习·neo4j
机器不学习我也不学习10 天前
TensorFlow环境安装
neo4j
码农老李11 天前
vxWorks7.0 Simpc运行tensorflow lite example
人工智能·tensorflow·neo4j
小鸡吃米…1 个月前
TensorFlow 实现异或(XOR)运算
人工智能·python·tensorflow·neo4j
*/