#include <cmath>    // roundf
#include <cstdint>  // int8_t, int32_t
#include <cstring>  // memcpy, memset
#include <new>      // placement new

#include "tensorflow/lite/core/c/common.h"
#include "tensorflow/lite/micro/micro_interpreter.h"
#include "tensorflow/lite/micro/micro_log.h"
#include "tensorflow/lite/micro/micro_mutable_op_resolver.h"
#include "tensorflow/lite/micro/system_setup.h"
#include "tensorflow/lite/micro/tflite_bridge/micro_error_reporter.h"
#include "tensorflow/lite/schema/schema_generated.h"

// The CPGC_ALG_NILM_ERR_* error codes used below are assumed to come from a
// project header that is not shown here.

// External model data (defined in a separate translation unit)
extern const char g_model_data[];
extern const int g_model_data_len;
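// A sketch of how that companion file is often produced (an assumption, not
// the project's actual file): run `xxd -i model.tflite > model_data.cc`, then
// adjust the generated array to match these declarations, e.g.:
//
//   // model_data.cc (generated; model bytes elided)
//   extern const char g_model_data[] = { /* ... model bytes ... */ };
//   extern const int g_model_data_len = sizeof(g_model_data);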

// Tensor arena, placed in a dedicated linker section and 32-byte aligned
__attribute__((section(".tensor_arena"), aligned(32)))
uint8_t g_tensor_arena[32 * 1024] = {0};
extern const int g_tensor_arena_len;
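// Sizing note: the arena must hold every tensor the model needs at runtime.
// AllocateTensors() reports the actual usage via arena_used_bytes() (printed
// later in this file), which is the usual way to tune the 32 KB figure.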

// Quantization parameters (must match the values produced by the Python
// training/conversion pipeline)
const float input_scale = 0.05299549177289009f;
const int32_t input_zero_point = 4;
const float output_scale = 0.00390625f;
const int32_t output_zero_point = -128;

// Other declarations
const int kMaxOps = 13;         // number of registered operators
const int load_class_num = 5;   // number of load classes
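// Note: instead of hard-coding these constants, they can be read back from the
// tensors once AllocateTensors() has run, which avoids drift between the
// Python exporter and this file. A minimal sketch (not in the original code):
//
//   float in_scale = input_tensor->params.scale;       // == input_scale
//   int32_t in_zp  = input_tensor->params.zero_point;  // == input_zero_point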

namespace {
// MicroMutableOpResolver is templated on the maximum number of registered ops
using OpResolver = tflite::MicroMutableOpResolver<kMaxOps>;
}

// Inference context
struct ModelContext {
  OpResolver resolver;
  tflite::MicroInterpreter* interpreter;
};

// Simple bump-pointer memory pool
class MemoryPool {
 public:
  void* Allocate(size_t size) {
    if (current_position_ + size <= pool_size_) {
      void* ptr = &pool_[current_position_];
      current_position_ += size;
      return ptr;
    } else {
      return nullptr;
    }
  }

  void Reset() { current_position_ = 0; }

 private:
  static constexpr size_t pool_size_ = 64 * 1024;  // pool capacity in bytes
  uint8_t pool_[pool_size_];                       // backing storage
  size_t current_position_ = 0;
};
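// Usage sketch for the pool above (illustrative, not in the original): blocks
// come back as raw bytes, so objects are constructed with placement new; there
// is no per-block free, only Reset(), which invalidates everything at once.
// A production version would also round sizes up to a suitable alignment.
//
//   MemoryPool pool;
//   void* mem = pool.Allocate(sizeof(ModelContext));
//   ModelContext* ctx = (mem != nullptr) ? new (mem) ModelContext() : nullptr;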

// Internal helper functions
int n1m_argmax(const float* data, int size);
void n1m_input_float_to_int8(const float* input_f32, int8_t* output_i8, int size);
void n1m_output_int8_to_float(const int8_t* input_i8, float* output_f32, int size);

// Register the operators used by the model (the count must not exceed kMaxOps)
TfLiteStatus RegisterOps(OpResolver& resolver) {
  TF_LITE_ENSURE_STATUS(resolver.AddFullyConnected());
  TF_LITE_ENSURE_STATUS(resolver.AddConv2D());
  TF_LITE_ENSURE_STATUS(resolver.AddDepthwiseConv2D());
  TF_LITE_ENSURE_STATUS(resolver.AddMaxPool2D());
  TF_LITE_ENSURE_STATUS(resolver.AddAveragePool2D());
  TF_LITE_ENSURE_STATUS(resolver.AddSoftmax());
  TF_LITE_ENSURE_STATUS(resolver.AddReshape());
  TF_LITE_ENSURE_STATUS(resolver.AddExpandDims());
  TF_LITE_ENSURE_STATUS(resolver.AddMean());
  TF_LITE_ENSURE_STATUS(resolver.AddShape());
  TF_LITE_ENSURE_STATUS(resolver.AddStridedSlice());
  TF_LITE_ENSURE_STATUS(resolver.AddPack());
  TF_LITE_ENSURE_STATUS(resolver.AddDequantize());
  return kTfLiteOk;
}
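// The list above must mirror the ops actually present in the converted model;
// if one is missing, AllocateTensors() typically fails with a log along the
// lines of "Didn't find op for builtin opcode". A viewer such as Netron can be
// used to check which builtin ops a .tflite file contains.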

// Quantize float input to int8
void n1m_input_float_to_int8(const float* input_f32, int8_t* output_i8, int size) {
  for (int i = 0; i < size; ++i) {
    float scaled = roundf(input_f32[i] / input_scale + input_zero_point);
    int32_t q;
    if (scaled > 127.0f) {
      q = 127;
    } else if (scaled < -128.0f) {
      q = -128;
    } else {
      q = static_cast<int32_t>(scaled);
    }
    output_i8[i] = static_cast<int8_t>(q);
  }
}
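// Worked example with the constants above: for an input of 1.0f,
// 1.0 / 0.05299549 ≈ 18.87; adding the zero point 4 gives 22.87, which rounds
// to 23. For 7.0f, 132.09 + 4 = 136.09 exceeds 127, so the value clamps to 127.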

// Dequantize int8 output to float
void n1m_output_int8_to_float(const int8_t* input_i8, float* output_f32, int size) {
  for (int i = 0; i < size; ++i) {
    output_f32[i] = (input_i8[i] - output_zero_point) * output_scale;
  }
}
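// With output_scale = 1/256 and output_zero_point = -128, the int8 range
// [-128, 127] maps onto [0.0, 0.99609375], the standard quantization of a
// softmax output, so the dequantized values can be read as class probabilities.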

// Return the index of the largest element (argmax)
int n1m_argmax(const float* data, int num_classes) {
  int max_index = 0;
  float max_value = data[0];
  for (int i = 1; i < num_classes; ++i) {
    if (data[i] > max_value) {
      max_value = data[i];
      max_index = i;
    }
  }
  return max_index;
}

// Main inference entry point
extern "C" int CPGC_ALG_NILM_RUN_T1(const float* VI_input, const int data_len) {

// The pool is static so its 64 KB buffer lives in .bss rather than on the
// (typically small) MCU stack; Reset() reclaims it for each call
static MemoryPool memory_pool;
memory_pool.Reset();

// Initialize the context (constructed in pool memory via placement new)
void* ctx_mem = memory_pool.Allocate(sizeof(ModelContext));
if (ctx_mem == nullptr) {
    MicroPrintf("Failed to allocate ModelContext");
    return CPGC_ALG_NILM_ERR_OUT_OF_MEMORY;
}
ModelContext* ctx = new (ctx_mem) ModelContext();

memset(g_tensor_arena, 0, sizeof(g_tensor_arena));

// Load the model
const tflite::Model* model = tflite::GetModel(g_model_data);
if (!model) {
    MicroPrintf("Failed to load model - GetModel returned null");
    return CPGC_ALG_NILM_ERR_MODEL_INCOMPLETE;
}

if (model->version() != TFLITE_SCHEMA_VERSION) {
    MicroPrintf("Model version mismatch: got %d, expected %d",
                static_cast<int>(model->version()), TFLITE_SCHEMA_VERSION);
    return CPGC_ALG_NILM_ERR_MODEL_INCOMPLETE;
}

// Create the resolver and register the ops
OpResolver resolver;
TfLiteStatus resolver_status = RegisterOps(resolver);
if (resolver_status != kTfLiteOk) {
    MicroPrintf("RegisterOps failed: %d", resolver_status);
    return CPGC_ALG_NILM_ERR_UNKNOWN_ERROR;
}

// Create the interpreter (check the allocation before placement new, since
// constructing at a null address is undefined behavior)
void* interp_mem = memory_pool.Allocate(sizeof(tflite::MicroInterpreter));
if (interp_mem == nullptr) {
    MicroPrintf("Failed to create interpreter - out of memory");
    return CPGC_ALG_NILM_ERR_OUT_OF_MEMORY;
}
ctx->interpreter = new (interp_mem) tflite::MicroInterpreter(
    model,
    resolver,
    g_tensor_arena,
    sizeof(g_tensor_arena)
);

// Allocate tensors
TfLiteStatus status = ctx->interpreter->AllocateTensors();
if (status != kTfLiteOk) {
    MicroPrintf("AllocateTensors failed: %d", status);
    MicroPrintf("Arena used: %u bytes", static_cast<unsigned>(ctx->interpreter->arena_used_bytes()));
    return CPGC_ALG_NILM_ERR_UNKNOWN_ERROR;
}

MicroPrintf("Model loaded successfully");
MicroPrintf("Arena used: %u bytes", static_cast<unsigned>(ctx->interpreter->arena_used_bytes()));

// Get the input tensor
TfLiteTensor* input_tensor = ctx->interpreter->input(0);
if (!input_tensor || !input_tensor->data.data) {
    MicroPrintf("ERROR: Failed to get input tensor");
    return CPGC_ALG_NILM_ERR_UNKNOWN_ERROR;
}

// Quantize the input (guard against overflowing the staging buffer and the tensor)
int8_t quantized_input[256] = {0};
if (data_len < 0 || data_len > static_cast<int>(sizeof(quantized_input)) ||
    static_cast<size_t>(data_len) > input_tensor->bytes) {
    MicroPrintf("ERROR: data_len %d exceeds input capacity", data_len);
    return CPGC_ALG_NILM_ERR_UNKNOWN_ERROR;
}
n1m_input_float_to_int8(VI_input, quantized_input, data_len);
memcpy(input_tensor->data.int8, quantized_input, data_len);

// Run inference
TfLiteStatus invoke_status = ctx->interpreter->Invoke();
if (invoke_status != kTfLiteOk) {
    MicroPrintf("Invoke failed: %d", invoke_status);
    return CPGC_ALG_NILM_ERR_UNKNOWN_ERROR;
}

// Get the output tensor
TfLiteTensor* output_tensor = ctx->interpreter->output(0);
if (!output_tensor || !output_tensor->data.data) {
    MicroPrintf("ERROR: Failed to get output tensor");
    return CPGC_ALG_NILM_ERR_UNKNOWN_ERROR;
}

// Print the raw int8 output
MicroPrintf("Raw int8 output:");
for (int i = 0; i < load_class_num; ++i) {
    MicroPrintf(" [%d]", output_tensor->data.int8[i]);
}

// Dequantize to float
float output_float[10];  // assumes the output has at most 10 classes
n1m_output_int8_to_float(output_tensor->data.int8, output_float, load_class_num);

MicroPrintf("Float output after dequantize:");
for (int i = 0; i < load_class_num; ++i) {
    MicroPrintf(" [%d]: %.6f", i, output_float[i]);
}

// Take the predicted class
int pred_label = n1m_argmax(output_float, load_class_num);
MicroPrintf("Prediction: Class %d", pred_label);

return pred_label;

}
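// Hypothetical caller sketch (buffer name and length are assumptions, and the
// check assumes the CPGC_ALG_NILM_ERR_* codes are negative; the V-I feature
// vector must be preprocessed exactly as during training):
//
//   float vi_features[256] = { /* filled by the feature extractor */ };
//   int cls = CPGC_ALG_NILM_RUN_T1(vi_features, 256);
//   if (cls >= 0) {
//     // cls is the predicted load class in [0, load_class_num)
//   }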
