int8_to_float(output_tensor->data.int8, output_float, load_class_num);

#include <cmath>
#include <cstddef>
#include <cstdint>
#include <cstring>

#include "tensorflow/lite/core/c/common.h"
#include "tensorflow/lite/micro/micro_interpreter.h"
#include "tensorflow/lite/micro/micro_log.h"
#include "tensorflow/lite/micro/micro_mutable_op_resolver.h"
#include "tensorflow/lite/micro/system_setup.h"
#include "tensorflow/lite/micro/tflite_bridge/micro_error_reporter.h"
#include "tensorflow/lite/micro/tflite_bridge/micro_error_reporter.h"
#include "tensorflow/lite/schema/schema_generated.h"

// External model data (must be defined in another file)

extern const char g_model_data[];

extern const int g_model_data_len;

// Tensor arena handed to the interpreter as working memory.
// Placed in the ".tensor_arena" linker section and aligned to 32 bytes.
__attribute__((section(".tensor_arena"), aligned(32)))
uint8_t g_tensor_arena[32 * 1024] = {0};

extern const int g_tensor_arena_len;

// Quantization parameters (must stay consistent with the Python training side)

const float input_scale = 0.05299549177289009f;

const int32_t input_zero_point = 4;

const float output_scale = 0.00390625f;

const int32_t output_zero_point = -128;

// Other declarations

const int kMaxOps = 13; // Number of supported operators

const int load_class_num = 5; // Number of classes

namespace {

// Operator resolver with room for kMaxOps registrations.
// MicroMutableOpResolver is a class template and needs its capacity argument.
using OpResolver = tflite::MicroMutableOpResolver<kMaxOps>;

}  // namespace

// Algorithm context struct: groups an op resolver with an interpreter pointer.

struct ModelContext {

// Op resolver instance held by value.
OpResolver resolver;

// Interpreter pointer; this struct declares no constructor or destructor,
// so it neither creates nor releases the interpreter itself.
tflite::MicroInterpreter* interpreter;

};

// Fixed-size bump allocator backed by an in-object 64 KiB buffer.
//
// Allocate() hands out consecutive slices of the buffer, each starting at an
// address suitably aligned for any fundamental type, so the returned storage
// can be used for placement-new. Individual allocations cannot be freed;
// Reset() makes the whole buffer available again. The pool never runs
// destructors for objects constructed in its storage.
class MemoryPool {
 public:
  // Reserves `size` bytes and returns a pointer to them, or nullptr when the
  // remaining space is insufficient.
  void* Allocate(size_t size) {
    constexpr size_t kAlignment = alignof(std::max_align_t);
    // current_position_ never exceeds pool_size_, so rounding up cannot wrap.
    const size_t start =
        (current_position_ + kAlignment - 1) / kAlignment * kAlignment;
    // Compare against the space left rather than adding `size` to the offset,
    // so a huge request cannot overflow and slip past the check.
    if (start > pool_size_ || size > pool_size_ - start) {
      return nullptr;
    }
    current_position_ = start + size;
    return pool_ + start;
  }

  // Discards every allocation made so far.
  void Reset() { current_position_ = 0; }

 private:
  static constexpr size_t pool_size_ = 64 * 1024;  // Pool capacity in bytes

  // Backing storage; aligned so that aligned offsets yield aligned addresses.
  alignas(std::max_align_t) uint8_t pool_[pool_size_];

  size_t current_position_ = 0;  // Offset of the first unused byte
};

// Internal helper functions (defined later in this file)

int n1m_argmax(const float* data, int size);

void n1m_input_float_to_int8(const float* input_f32, int8_t* output_i8, int size);

void n1m_output_int8_to_float(const int8_t* input_i8, float* output_f32, int size);

// 注册支持的算子

TFliteStatus RegisterOps(OpResolver& resolver) {

TF_LITE_ENSURE_STATUS(resolver.AddFullyConnected());

TF_LITE_ENSURE_STATUS(resolver.AddConv2D());

TF_LITE_ENSURE_STATUS(resolver.AddDepthwiseConv2D());

TF_LITE_ENSURE_STATUS(resolver.AddMaxPool2D());

TF_LITE_ENSURE_STATUS(resolver.AddAveragePool2D());

TF_LITE_ENSURE_STATUS(resolver.AddSoftmax());

TF_LITE_ENSURE_STATUS(resolver.AddReshape());

TF_LITE_ENSURE_STATUS(resolver.AddExpandDims());

TF_LITE_ENSURE_STATUS(resolver.AddMean());

TF_LITE_ENSURE_STATUS(resolver.AddShape());

TF_LITE_ENSURE_STATUS(resolver.AddStridedSlice());

TF_LITE_ENSURE_STATUS(resolver.AddPack());

TF_LITE_ENSURE_STATUS(resolver.AddDequantize());

复制代码
return kTfLiteOk;

}

// Quantizes float input samples to int8.
//
// Each sample is mapped to round(x / input_scale + input_zero_point) and then
// saturated to the int8 range [-128, 127]. A NaN sample is stored as the zero
// point, because converting NaN to an integer is undefined behaviour.
//
// @param input_f32  Source buffer holding at least `size` floats.
// @param output_i8  Destination buffer with room for at least `size` int8 values.
// @param size       Number of elements to convert; nothing is written if <= 0.
void n1m_input_float_to_int8(const float* input_f32, int8_t* output_i8, int size) {
  for (int i = 0; i < size; ++i) {
    float scaled = std::round(input_f32[i] / input_scale + input_zero_point);
    if (std::isnan(scaled)) {
      // No meaningful quantized value exists for NaN; fall back to the zero point.
      scaled = static_cast<float>(input_zero_point);
    }

    int32_t q;
    if (scaled > 127.0f) {
      q = 127;
    } else if (scaled < -128.0f) {
      q = -128;
    } else {
      q = static_cast<int32_t>(scaled);
    }
    output_i8[i] = static_cast<int8_t>(q);
  }
}

// Dequantizes int8 model outputs to floats.
//
// Each element becomes (q - output_zero_point) * output_scale.
//
// @param input_i8    Source buffer holding at least `size` int8 values.
// @param output_f32  Destination buffer with room for at least `size` floats.
// @param size        Number of elements to convert.
void n1m_output_int8_to_float(const int8_t* input_i8, float* output_f32, int size) {
  int idx = 0;
  while (idx < size) {
    const int32_t centered = input_i8[idx] - output_zero_point;
    output_f32[idx] = centered * output_scale;
    ++idx;
  }
}

// Returns the index of the largest element in `data` (argmax).
//
// When several elements share the maximum, the lowest index wins.
//
// @param data         Buffer holding at least `num_classes` floats.
// @param num_classes  Number of elements to scan.
// @return Index of the maximum element; 0 when fewer than two elements are scanned.
int n1m_argmax(const float* data, int num_classes) {
  int best = 0;
  int candidate = 1;
  while (candidate < num_classes) {
    // Strict comparison keeps the earliest index on ties.
    if (data[candidate] > data[best]) {
      best = candidate;
    }
    ++candidate;
  }
  return best;
}

// 主要推理函数

extern "C" int CPGC_ALG_NILM_RUN_T1(const float* VI_input, const int data_len) {

MemoryPool memory_pool; // 使用内存池

复制代码
// 初始化上下文
ModelContext* ctx = static_cast<ModelContext*>(memory_pool.Allocate(sizeof(ModelContext)));
if (!ctx) {
    MicroPrintf("Failed to allocate ModelContext");
    return CPGC_ALG_NILM_ERR_OUT_OF_MEMORY;
}

memset(g_tensor_arena, 0, sizeof(g_tensor_arena));

// 加载模型
const tflite::Model* model = tflite::GetModel(g_model_data);
if (!model) {
    MicroPrintf("Failed to load model - GetModel returned null");
    return CPGC_ALG_NILM_ERR_MODEL_INCOMPLETE;
}

if (model->version() != TFLITE_SCHEMA_VERSION) {
    MicroPrintf("Model version mismatch: got %d, expected %d", model->version(), TFLITE_SCHEMA_VERSION);
    return CPGC_ALG_NILM_ERR_MODEL_INCOMPLETE;
}

// 创建解析器并注册操作
OpResolver resolver;
TFliteStatus resolver_status = RegisterOps(resolver);
if (resolver_status != kTfLiteOk) {
    MicroPrintf("RegisterOps failed: %d", resolver_status);
    return CPGC_ALG_NILM_ERR_UNKNOWN_ERROR;
}

// 创建解释器
ctx->interpreter = new (memory_pool.Allocate(sizeof(tflite::MicroInterpreter))) tflite::MicroInterpreter(
    model,
    resolver,
    g_tensor_arena,
    sizeof(g_tensor_arena)
);

if (ctx->interpreter == nullptr) {
    MicroPrintf("Failed to create interpreter - out of memory");
    return CPGC_ALG_NILM_ERR_OUT_OF_MEMORY;
}

// 分配张量
TFliteStatus status = ctx->interpreter->AllocateTensors();
if (status != kTfLiteOk) {
    MicroPrintf("AllocateTensors failed: %d", status);
    MicroPrintf("Arena used: %u bytes", ctx->interpreter->arena_used_bytes());
    return CPGC_ALG_NILM_ERR_UNKNOWN_ERROR;
}

MicroPrintf("Model loaded successfully");
MicroPrintf("Arena used: %u bytes", ctx->interpreter->arena_used_bytes());

// 获取输入 tensor
TfLiteTensor* input_tensor = ctx->interpreter->input(0);
if (!input_tensor || !input_tensor->data.data) {
    MicroPrintf("ERROR: Failed to get input tensor");
    return CPGC_ALG_NILM_ERR_UNKNOWN_ERROR;
}

// 量化输入
int8_t quantized_input[256] = {0};
n1m_input_float_to_int8(VI_input, quantized_input, data_len);
memcpy(input_tensor->data.int8, quantized_input, data_len);

// 推理
TFliteStatus invoke_status = ctx->interpreter->Invoke();
if (invoke_status != kTfLiteOk) {
    MicroPrintf("Invoke failed: %d", invoke_status);
    return CPGC_ALG_NILM_ERR_UNKNOWN_ERROR;
}

// 获取输出并反量化
TfLiteTensor* output_tensor = ctx->interpreter->output(0);
if (!output_tensor || !output_tensor->data.data) {
    MicroPrintf("ERROR: Failed to get output tensor");
    return CPGC_ALG_NILM_ERR_UNKNOWN_ERROR;
}

// 打印原始 int8 输出
MicroPrintf("Raw int8 output (first 5):");
for (int i = 0; i < load_class_num; ++i) {
    MicroPrintf(" [%d]", output_tensor->data.int8[i]);
}

// 反量化为 float
float output_float[10];  // 假设输出大小 <=10
n1m_output_int8_to_float(output_tensor->data.int8, output_float, load_class_num);

MicroPrintf("Float output after dequantize (first 5):");
for (int i = 0; i < load_class_num; ++i) {
    MicroPrintf(" [%d]: %.6f", i, output_float[i]);
}

// 获取预测类别
int pred_label = n1m_argmax(output_float, load_class_num);
MicroPrintf("Prediction: Class %d", pred_label);

return pred_label;

}

相关推荐
视觉AI9 小时前
为什么 transformers 要 import TensorFlow
人工智能·tensorflow·neo4j
封奚泽优2 天前
Neo4j中导入.owl数据
知识图谱·neo4j·owl·rdf
Doro再努力2 天前
Neo4j图数据库:简述增删改查
数据库·neo4j
武子康3 天前
Java-165 Neo4j 图论详解 欧拉路径与欧拉回路 10 分钟跑通:Python NetworkX 判定实战
java·数据库·性能优化·系统架构·nosql·neo4j·图论
麦麦大数据5 天前
F042 A星算法课程推荐(A*算法) | 课程知识图谱|课程推荐vue+flask+neo4j B/S架构前后端分离|课程知识图谱构造
vue.js·算法·知识图谱·neo4j·a星算法·路径推荐·课程推荐
rengang666 天前
132-Spring AI Alibaba Vector Neo4j 示例
人工智能·spring·neo4j·rag·spring ai·ai应用编程
小宋10217 天前
Neo4j-图数据库入门图文保姆攻略
数据库·neo4j
消失在人海中8 天前
图形数据库Neo4J简介
数据库·oracle·neo4j
麦麦大数据8 天前
F035 vue+neo4j中医南药药膳知识图谱可视化系统 | vue+flask
vue.js·知识图谱·neo4j·中医·中药·药膳·南药