最近在做AI信息在各个平台流转的框架设计,想要设计一种可以灵活扩展、不改变原有代码的框架,了解到了类注册。
具体需求是这样的:AI算法在客户本地电脑和云端都有部署,原先AI在这两个平台下的输出格式并不统一,且每个AI功能都有自己的输出格式,导致两个平台下的AI信息无法共享,带来了计算资源的浪费,管理起来也比较混乱,因此需要一种模式将所有AI输出规范起来。
我的解决思路大概就是将所有AI信息都规范输出到同一个json里面,具体实现就是:定义一个基类,该基类有两个方法需要重载,一个是toJson用于序列化到json,一个是fromJson用于从json中解析,后续每个AI功能的输出都要继承该基类,最后用一个Manager类通过map数据结构管理所有的AI信息。
注册类机制模板实现:
考虑到便于后续AI输出功能的灵活扩展,以及不改变原有代码,用到了类注册机制管理这些子类的创建,最后将类注册机制抽象为如下模板:
cpp
// pluginFactoryTemplate.h
#ifndef PLUGIN_FACTORY_TEMPLATE_H_
#define PLUGIN_FACTORY_TEMPLATE_H_
#include<map>
#define REGISTER_PLUGIN(Base,DerivedBase,pluginName) \
static bool regPlugin_##pluginName=(PluginFatoryTemplate<Base>::register_plugin<DerivedBase>(pluginName),true);
#define REGISTER_PLUGIN(Base,DerivedBase) \
static bool regPlugin_##DerivedBase=(PluginFatoryTemplate<Base>::register_plugin<DerivedBase>(#DerivedBase),true);
#define CREATE_PLUGIN(Base,pluginName) \
PluginFatoryTemplate<Base>::create(pluginName)
template<typename Base>
class PluginFatoryTemplate{
public:
// 创建函数类型
using CreateFunc=Base* (*)();
// 注册插件
template<typename DerivedBase>
static void register_plugin(const std::string& pluginName){
getRegistry()[pluginName]=[](){
return (Base*) (new DerivedBase());
};
}
// 创建插件实例
static std::shared_ptr<Base> create(const std::string& pluginName){
auto it=getRegistry().find(pluginName);
if(it!=getRegistry().end()){
return std::shared_ptr<Base>(it->second());
}
std::cout<<"Error:Can't find "<<pluginName<<std::endl;
return nullptr;
}
private:
static std::map<std::string,CreateFunc>& getRegistry(){
static std::map<std::string,CreateFunc> registry;
return registry;
}
}
#endif
使用:
定义基类:
cpp
// BaseAIOutput.h
class BaseAIOutput{
public:
visual std::string toJson()=0;
visual void fromJson(std::string json)=0;
}
定义子类:
cpp
// AI1Output.h
#include"BaseAIOutput.h"
class AI1Output:public BaseAIOutput{
public:
virsual std::string toJson();
virsual void fromJson(std::string json);
}
cpp
// AI1Output.cpp
#include"AI1Output.h"
#include"pluginFatoryTemplate.h"
// 注册子类
REGISTER_PLUGIN(BaseAIOutput,AI1Output)
std::string AI1Output::toJson(){
return "this is json";
}
void AI1Output::fromJson(std::string json){
std::cout<<"from json"<<std::endl;
}
主函数:
cpp
// main.cpp
#include"BaseAIOutput.h"
#include"pluginFatoryTemplate.h"
void main(){
// 创建子类对象
auto ptrAI1Output=CREATE_PLUGIN(BaseAIOutput,"AI1Output");
return 0;
}