01_svm_二分类

环境

  • MinGW:7.3.0
  • Dlib:20.0

dlib 库SVM 二分类

模型训练

dlib库二分类问题中1表示正类,-1表示反类(这一点与其他库不太一样)

训练模型前需要将样本数据转换为std::vector<dlib::matrix<double,T,1>>形式(matrix 是一个多维数组的模板类,可以用来表示不同类型的矩阵)

c++ 复制代码
typedef dlib::matrix<double,30,1> cancer_type; // 定义一个矩阵类型cancer_type,用于存储样本数据
// 加载数据
vector<BreastCancer> vec = LoadCancer(pfile);  // 通过文件加载数据(sklearn中的breast_cancer.csv)
vector<cancer_type> train_datas;
vector<double> train_labels;
for (int i = 0; i < vec.size(); i++)
{
    BreastCancer cancer = vec.at(i);
    cancer_type cc_type;
    for (int j = 0; j < 30; j++)
    {
        cc_type(j) = cancer.ft[j];
    }
    train_datas.push_back(cc_type);
    train_labels.push_back(cancer.tag > 0 ? 1 : -1); // 注意正、反类的取值
}

svm算法对数据缩放敏感(可以参考蜥蜴书《Python机器学习基础教程》),这里也对数据进行缩放处理

c++ 复制代码
dlib::vector_normalizer<cancer_type> normalizer; // 数据预处理,使不同特征的量纲统一,帮助模型训练时更高效收敛。
										     // sklearn中也称这种预处理算法为无监督算法
normalizer.train(train_datas);
for (int i = 0; i < train_datas.size(); i++)
{
    train_datas[i] = normalizer(train_datas[i]);
}

创建模型并训练

c++ 复制代码
typedef dlib::radial_basis_kernel<cancer_type> kernel_type; // rbf 径向基函数核,主要用于支持向量机(SVM)、高斯过程回归等机器学习算法,能够实现输入空间的非线性映射。
typedef dlib::decision_function<kernel_type> dec_func_type; // dlib::decision_function是dlib库中用于支持向量机(SVM)分类器决策的核心接口,其主要功能是通过训练后的模型对新样本进行分类或回归预测
typedef dlib::normalized_function<dec_func_type> func_type; // 用于对决策函数进行归一化处理的工具类,主要用于支持向量机(SVM)等分类器的输出标准化。

typedef dlib::probabilistic_decision_function<kernel_type> pro_func_type; // 与decision_function类似,但输出的是概率值
typedef dlib::normalized_function<pro_func_type> p_func_type;

// 打乱数据
dlib::randomize_samples(train_datas, train_labels);
// 创建训练器
dlib::svm_c_trainer<kernel_type> trainer;
trainer.set_c(10);                     // 设置C参数
trainer.set_kernel(kernel_type(0.01)); // 设置gamma参数
// 训练模型
p_func_type learned_func;
learned_func.normalizer = normalizer;
// learned_func.function =  trainer.train(train_datas, train_labels); // 训练模型
learned_func.function = dlib::train_probabilistic_decision_function(trainer, train_datas, train_labels, 3);// 训练模型,数值3表示对样本数据的折叠数,可以参考sklearn中的模型评估

模型验证

c++ 复制代码
int ok_count = 0;
for (int i = 0; i < vec.size(); i++)
{
    BreastCancer cancer = vec.at(i);
    cancer_type cc_type;
    for (int j = 0; j < 30; j++)
    {
        cc_type(j) = cancer.ft[j];
    }
    double ret_d = learned_func(cc_type);
    cout << "probabilistic : " << ret_d << endl; // 概率值,大于0.5的是正类,否则视为反类
    int ret = ret_d > 0.5 ? 1 : 0;
    if (ret == cancer.tag)
    ok_count += 1;
}
cout << "accurary:" << (ok_count * 1.0) / vec.size() << endl;

模型保存与加载

通过dlib::serialize可以保存模型

c++ 复制代码
dlib::serialize("svm_cancer_model.dat") << learned_func; // 保存模型到svm_cancer_model.dat

通过dlib::deserialize加载模型

c++ 复制代码
p_func_type learned_func; //
dlib::deserialize(file_name) >> learned_func;

其他

加载breast_cancer.csv数据

c++ 复制代码
typedef struct {
    float ft[30];
    int tag;
} BreastCancer;


vector<BreastCancer> LoadCancer(const string &fpath)
{
    vector<BreastCancer> vec;
    fstream input_file(fpath);
    string line;
    if (input_file.is_open())
    {
        getline(input_file, line); // 跳过头
        while (getline(input_file, line))
        {
            vector<string> sp_vec = SplitString(line, ',');
            if (sp_vec.size() == 0)
                continue;
            BreastCancer cancer;
            for (int i = 0; i < 30; i++)
            {
                cancer.ft[i] = atof(sp_vec.at(i).c_str());
            }
            cancer.tag = atoi(sp_vec.at(30).c_str());
            vec.push_back(cancer);
        }
    }
    return vec;
}

vector<string> SplitString(const string &str, char delim)
{
    vector<string> vec;
    istringstream iss(str);
    string token;
    while (getline(iss, token, delim))
    {
        vec.push_back(token);
    }
    return vec;
}
相关推荐
isyoungboy4 小时前
使用SVM构建光照鲁棒的颜色分类器:从特征提取到SVM
算法·机器学习·支持向量机
白杆杆红伞伞4 小时前
02_svm_多分类
机器学习·支持向量机·分类·dlib
极客数模4 小时前
2025年MathorCup 大数据竞赛明日开赛,注意事项!论文提交规范、模板、承诺书正确使用!2025年第六届MathorCup数学应用挑战赛——大数据竞赛
大数据·python·算法·matlab·图论·比赛推荐
.小小陈.4 小时前
数据结构3:复杂度
c语言·开发语言·数据结构·笔记·学习·算法·visual studio
立志成为大牛的小牛4 小时前
数据结构——二十四、图(王道408)
数据结构·学习·程序人生·考研·算法
TT哇4 小时前
【优先级队列(堆)】2.数据流中的第 K ⼤元素(easy)
算法·1024程序员节
包饭厅咸鱼4 小时前
QT----使用onnxRuntime运行图像分类模型
开发语言·qt·分类
Matlab程序猿小助手5 小时前
【MATLAB源码-第303期】基于matlab的蒲公英优化算法(DO)机器人栅格路径规划,输出做短路径图和适应度曲线.
开发语言·算法·matlab·机器人·kmeans
CoderIsArt5 小时前
CORDIC三角计算技术
人工智能·算法·机器学习