01_svm_二分类

环境

  • MinGW:7.3.0
  • Dlib:20.0

dlib 库SVM 二分类

模型训练

dlib库二分类问题中1表示正类,-1表示反类(这一点与其他库不太一样)

训练模型前需要将样本数据转换为std::vector<dlib::matrix<double,T,1>>形式(matrix 是一个多维数组的模板类,可以用来表示不同类型的矩阵)

c++ 复制代码
typedef dlib::matrix<double,30,1> cancer_type; // 定义一个矩阵类型cancer_type,用于存储样本数据
// 加载数据
vector<BreastCancer> vec = LoadCancer(pfile);  // 通过文件加载数据(sklearn中的breast_cancer.csv)
vector<cancer_type> train_datas;
vector<double> train_labels;
for (int i = 0; i < vec.size(); i++)
{
    BreastCancer cancer = vec.at(i);
    cancer_type cc_type;
    for (int j = 0; j < 30; j++)
    {
        cc_type(j) = cancer.ft[j];
    }
    train_datas.push_back(cc_type);
    train_labels.push_back(cancer.tag > 0 ? 1 : -1); // 注意正、反类的取值
}

svm算法对数据缩放敏感(可以参考蜥蜴书《Python机器学习基础教程》),这里也对数据进行缩放处理

c++ 复制代码
dlib::vector_normalizer<cancer_type> normalizer; // 数据预处理,使不同特征的量纲统一,帮助模型训练时更高效收敛。
										     // sklearn中也称这种预处理算法为无监督算法
normalizer.train(train_datas);
for (int i = 0; i < train_datas.size(); i++)
{
    train_datas[i] = normalizer(train_datas[i]);
}

创建模型并训练

c++ 复制代码
typedef dlib::radial_basis_kernel<cancer_type> kernel_type; // rbf 径向基函数核,主要用于支持向量机(SVM)、高斯过程回归等机器学习算法,能够实现输入空间的非线性映射。
typedef dlib::decision_function<kernel_type> dec_func_type; // dlib::decision_function是dlib库中用于支持向量机(SVM)分类器决策的核心接口,其主要功能是通过训练后的模型对新样本进行分类或回归预测
typedef dlib::normalized_function<dec_func_type> func_type; // 用于对决策函数进行归一化处理的工具类,主要用于支持向量机(SVM)等分类器的输出标准化。

typedef dlib::probabilistic_decision_function<kernel_type> pro_func_type; // 与decision_function类似,但输出的是概率值
typedef dlib::normalized_function<pro_func_type> p_func_type;

// 打乱数据
dlib::randomize_samples(train_datas, train_labels);
// 创建训练器
dlib::svm_c_trainer<kernel_type> trainer;
trainer.set_c(10);                     // 设置C参数
trainer.set_kernel(kernel_type(0.01)); // 设置gamma参数
// 训练模型
p_func_type learned_func;
learned_func.normalizer = normalizer;
// learned_func.function =  trainer.train(train_datas, train_labels); // 训练模型
learned_func.function = dlib::train_probabilistic_decision_function(trainer, train_datas, train_labels, 3);// 训练模型,数值3表示对样本数据的折叠数,可以参考sklearn中的模型评估

模型验证

c++ 复制代码
int ok_count = 0;
for (int i = 0; i < vec.size(); i++)
{
    BreastCancer cancer = vec.at(i);
    cancer_type cc_type;
    for (int j = 0; j < 30; j++)
    {
        cc_type(j) = cancer.ft[j];
    }
    double ret_d = learned_func(cc_type);
    cout << "probabilistic : " << ret_d << endl; // 概率值,大于0.5的是正类,否则视为反类
    int ret = ret_d > 0.5 ? 1 : 0;
    if (ret == cancer.tag)
    ok_count += 1;
}
cout << "accurary:" << (ok_count * 1.0) / vec.size() << endl;

模型保存与加载

通过dlib::serialize可以保存模型

c++ 复制代码
dlib::serialize("svm_cancer_model.dat") << learned_func; // 保存模型到svm_cancer_model.dat

通过dlib::deserialize加载模型

c++ 复制代码
p_func_type learned_func; //
dlib::deserialize(file_name) >> learned_func;

其他

加载breast_cancer.csv数据

c++ 复制代码
typedef struct {
    float ft[30];
    int tag;
} BreastCancer;


vector<BreastCancer> LoadCancer(const string &fpath)
{
    vector<BreastCancer> vec;
    fstream input_file(fpath);
    string line;
    if (input_file.is_open())
    {
        getline(input_file, line); // 跳过头
        while (getline(input_file, line))
        {
            vector<string> sp_vec = SplitString(line, ',');
            if (sp_vec.size() == 0)
                continue;
            BreastCancer cancer;
            for (int i = 0; i < 30; i++)
            {
                cancer.ft[i] = atof(sp_vec.at(i).c_str());
            }
            cancer.tag = atoi(sp_vec.at(30).c_str());
            vec.push_back(cancer);
        }
    }
    return vec;
}

vector<string> SplitString(const string &str, char delim)
{
    vector<string> vec;
    istringstream iss(str);
    string token;
    while (getline(iss, token, delim))
    {
        vec.push_back(token);
    }
    return vec;
}
相关推荐
Kisorge13 小时前
【电机控制】基于STM32F103C8T6的二轮平衡车设计——LQR线性二次线控制器(算法篇)
stm32·嵌入式硬件·算法
铭哥的编程日记14 小时前
深入浅出蓝桥杯:算法基础概念与实战应用(二)基础算法(下)
算法·职场和发展·蓝桥杯
Swift社区14 小时前
LeetCode 421 - 数组中两个数的最大异或值
算法·leetcode·职场和发展
cici1587414 小时前
基于高光谱成像和偏最小二乘法(PLS)的苹果糖度检测MATLAB实现
算法·matlab·最小二乘法
StarPrayers.15 小时前
自蒸馏学习方法
人工智能·算法·学习方法
大锦终15 小时前
【动规】背包问题
c++·算法·动态规划
智者知已应修善业16 小时前
【c语言蓝桥杯计算卡片题】2023-2-12
c语言·c++·经验分享·笔记·算法·蓝桥杯
hansang_IR16 小时前
【题解】洛谷 P2330 [SCOI2005] 繁忙的都市 [生成树]
c++·算法·最小生成树
Croa-vo16 小时前
PayPal OA 全流程复盘|题型体验 + 成绩反馈 + 通关经验
数据结构·经验分享·算法·面试·职场和发展
AndrewHZ17 小时前
【图像处理基石】 怎么让图片变成波普风?
图像处理·算法·计算机视觉·风格迁移·cv