遗传算法拟合公式

基本思想

遗传算法一个比较有趣的应用是,给定一组数据,用初等函数拟合出一个公式,尽可能接近这组数据

熟悉深度学习的同学可能会发现,这其实类似深度学习,不同点在于,深度学习使用多层网络+激活函数来拟合,内部是黑箱,具体得到了一个什么公式谁也说不清。

但是用遗传算法来拟合,会先把每个初等函数都编码成一个基因,最后得到最优的基因后,可以解码出具体是什么公式!这叫做符号回归,也就是不是给定函数结构、只预测参数,而是函数结构、参数都预测,这里"符号"就是用什么函数类型的意思。

这里因为每个个体要代表一个表达式,所以基因没有采用二进制串,而是一个抽象语法树。交叉时,分别随机选一个子树,进行交换。变异时有一定概率把一个子树变成完全随机的新子树,一定概率只改一个点的符号,一定概率只改一个点的常量值。

开始随机生成一个表达式,然后在他上面取一些数据点,传给遗传算法来预测。每轮的基因适应度,就是检查这些数据点的函数值,和真实函数值差多少,拿出来做一个均方差(MSE)。根据均方差,保留一些精英到下一代,剩下不能保留到下一代的繁殖,繁殖时适应度高的有优先择偶权。这块和一般的遗传算法差不多

特殊之处在于,这里初始生成的表达式都很简单,只用加减乘除,三角函数,且只在几类固定的模板中挑选,这是为了避免预测难度过大。所以为了避免预测时,过拟合出一个复杂的公式,偏离正确答案,我们需要在预测时采用类似深度学习正则化的思想,对预测语法树的复杂度做一个惩罚,这里的具体实现是,限制节点数,深度上限,并且在适应度中加一个正比于节点数的项。

实现

cpp 复制代码
#include <algorithm>
#include <array>
#include <chrono>
#include <cmath>
#include <iomanip>
#include <iostream>
#include <limits>
#include <memory>
#include <numeric>
#include <random>
#include <sstream>
#include <string>
#include <utility>
#include <vector>

namespace {

// Guard bound applied to every intermediate value during evaluation, so that
// trees cannot produce extreme magnitudes while evolving.
constexpr double kEvalLimit = 1e6;
// Protected-division threshold: when the denominator is closer to zero than
// this, the real division is skipped.
constexpr double kDivisionEpsilon = 1e-6;
// Penalty fitness assigned outright when an expression is invalid or its
// error explodes.
constexpr double kFitnessPenalty = 1e12;

// Clamp `value` into the inclusive range [low, high].
// Uses std::clamp (C++17, <algorithm> is already included) instead of the
// hand-rolled max/min composition; behavior is identical for low <= high.
double clamp_value(double value, double low, double high) {
    return std::clamp(value, low, high);
}

// Sanitize an evaluation result: non-finite values collapse to 0, finite
// values are limited to [-kEvalLimit, kEvalLimit].
double clamp_eval(double value) {
    return std::isfinite(value) ? clamp_value(value, -kEvalLimit, kEvalLimit)
                                : 0.0;
}

// Division that never blows up: when the denominator is (nearly) zero, fall
// back to the clamped numerator instead of dividing.
double protected_div(double lhs, double rhs) {
    const bool denominator_near_zero = std::abs(rhs) < kDivisionEpsilon;
    return denominator_near_zero ? clamp_eval(lhs) : clamp_eval(lhs / rhs);
}

// Primitive operator set available to expression trees.
// A richer operator set increases expressive power, but it also enlarges the
// search space and makes the evolution harder to stabilize.
enum class NodeType {
    Variable,  // the input variable x (leaf)
    Constant,  // a numeric literal (leaf)
    Add,       // binary +
    Sub,       // binary -
    Mul,       // binary *
    Div,       // protected binary / (see protected_div)
    Sin,       // unary sine
    Cos,       // unary cosine
};

// Number of child operands a node of the given type expects (0, 1, or 2).
int arity(NodeType type) {
    switch (type) {
        case NodeType::Sin:
        case NodeType::Cos:
            return 1;
        case NodeType::Add:
        case NodeType::Sub:
        case NodeType::Mul:
        case NodeType::Div:
            return 2;
        default:
            // Variable / Constant are leaves.
            return 0;
    }
}

// A node of the expression syntax tree — the GA genome is a tree of these.
struct ExprNode {
    NodeType type = NodeType::Constant;
    double constant = 0.0;  // only meaningful when type == Constant
    // Unary operators use only left; binary operators use both left and right.
    std::unique_ptr<ExprNode> left;
    std::unique_ptr<ExprNode> right;
};

// Deep-copy the (sub)tree rooted at `node`; a null input clones to null.
std::unique_ptr<ExprNode> clone_tree(const ExprNode* node) {
    if (!node) {
        return nullptr;
    }
    auto duplicate = std::make_unique<ExprNode>();
    duplicate->type = node->type;
    duplicate->constant = node->constant;
    duplicate->left = clone_tree(node->left.get());
    duplicate->right = clone_tree(node->right.get());
    return duplicate;
}

// Total number of nodes in the (sub)tree rooted at `node` (0 for null).
int node_count(const ExprNode* node) {
    if (!node) {
        return 0;
    }
    int total = 1;
    total += node_count(node->left.get());
    total += node_count(node->right.get());
    return total;
}

// Depth of the tree: an empty tree is 0, a lone node is 1.
int tree_depth(const ExprNode* node) {
    if (!node) {
        return 0;
    }
    const int left_depth = tree_depth(node->left.get());
    const int right_depth = tree_depth(node->right.get());
    return 1 + std::max(left_depth, right_depth);
}

// True when at least one Variable node appears anywhere in the (sub)tree.
bool contains_variable(const ExprNode* node) {
    if (node == nullptr) {
        return false;
    }
    return node->type == NodeType::Variable ||
           contains_variable(node->left.get()) ||
           contains_variable(node->right.get());
}

// Evaluate the expression tree at point x. Every intermediate result is
// clamped (and division is protected) so evolving trees cannot produce
// runaway magnitudes or NaN/Inf.
double eval_tree(const ExprNode* node, double x) {
    // Leaves and unary operators first; they touch at most the left child.
    switch (node->type) {
        case NodeType::Variable:
            return clamp_eval(x);
        case NodeType::Constant:
            return clamp_eval(node->constant);
        case NodeType::Sin:
            return clamp_eval(std::sin(eval_tree(node->left.get(), x)));
        case NodeType::Cos:
            return clamp_eval(std::cos(eval_tree(node->left.get(), x)));
        default:
            break;
    }
    // Remaining types are binary: evaluate both operands, then combine.
    const double lhs = eval_tree(node->left.get(), x);
    const double rhs = eval_tree(node->right.get(), x);
    switch (node->type) {
        case NodeType::Add:
            return clamp_eval(lhs + rhs);
        case NodeType::Sub:
            return clamp_eval(lhs - rhs);
        case NodeType::Mul:
            return clamp_eval(lhs * rhs);
        case NodeType::Div:
            return protected_div(lhs, rhs);
        default:
            return 0.0;
    }
}

// Render a double with up to three decimals, trimming trailing zeros and a
// dangling decimal point; "-0" normalizes to "0".
std::string format_number(double value) {
    std::ostringstream stream;
    stream << std::fixed << std::setprecision(3) << value;
    std::string text = stream.str();
    // Fixed precision 3 guarantees a '.', so stripping zeros cannot eat
    // significant integer digits.
    text.erase(text.find_last_not_of('0') + 1);
    if (!text.empty() && text.back() == '.') {
        text.pop_back();
    }
    return text == "-0" ? "0" : text;
}

// Serialize an expression tree to fully parenthesized infix notation.
std::string to_string(const ExprNode* node) {
    // Shared formatter for the four binary operators.
    const auto binary = [&](const char* op) {
        return "(" + to_string(node->left.get()) + op +
               to_string(node->right.get()) + ")";
    };
    switch (node->type) {
        case NodeType::Variable:
            return std::string("x");
        case NodeType::Constant:
            return format_number(node->constant);
        case NodeType::Add:
            return binary(" + ");
        case NodeType::Sub:
            return binary(" - ");
        case NodeType::Mul:
            return binary(" * ");
        case NodeType::Div:
            return binary(" / ");
        case NodeType::Sin:
            return "sin(" + to_string(node->left.get()) + ")";
        case NodeType::Cos:
            return "cos(" + to_string(node->left.get()) + ")";
    }
    return "?";
}

void collect_paths(const ExprNode* node,
                   std::vector<std::vector<int>>& paths,
                   std::vector<int>& current) {
    if (node == nullptr) {
        return;
    }
    // 用从根走到该节点的路径给每个子树编号:0 表示左儿子,1 表示右儿子。
    paths.push_back(current);
    if (node->left) {
        current.push_back(0);
        collect_paths(node->left.get(), paths, current);
        current.pop_back();
    }
    if (node->right) {
        current.push_back(1);
        collect_paths(node->right.get(), paths, current);
        current.pop_back();
    }
}

// Walk `path` (starting at `index`) from `root` and return the owning
// unique_ptr slot at its end, so callers can replace that subtree in place.
std::unique_ptr<ExprNode>* get_slot(std::unique_ptr<ExprNode>& root,
                                    const std::vector<int>& path,
                                    std::size_t index = 0) {
    std::unique_ptr<ExprNode>* slot = &root;
    for (std::size_t i = index; i < path.size(); ++i) {
        slot = (path[i] == 0) ? &(*slot)->left : &(*slot)->right;
    }
    return slot;
}

// Read-only lookup of the node addressed by `path` (0 = left, 1 = right).
const ExprNode* get_node(const std::unique_ptr<ExprNode>& root, const std::vector<int>& path) {
    const ExprNode* node = root.get();
    for (std::size_t i = 0; i < path.size(); ++i) {
        node = (path[i] == 0) ? node->left.get() : node->right.get();
    }
    return node;
}

// One (x, y) data point sampled from the hidden target expression.
struct Sample {
    double x;  // input value
    double y;  // target output at x
};

// Generates random expression trees and perturbs numeric constants.
// This class defines the "prior" over the search space: which structures the
// initial trees favor and which constant values are likely both directly
// affect whether the GA finds promising candidate formulas early on.
class ExpressionFactory {
public:
    explicit ExpressionFactory(std::mt19937& rng) : rng_(rng) {}

    // Build a random tree of at most `max_depth` levels.
    // `force_function` forbids a bare terminal at the root.
    std::unique_ptr<ExprNode> random_tree(int max_depth, bool force_function) {
        // Mix "grow" and "full" style initialization: either stop at a terminal
        // or keep expanding until depth runs out, boosting initial diversity.
        if (max_depth <= 1 || (!force_function && prob_(rng_) < 0.32)) {
            return random_terminal();
        }

        if (prob_(rng_) < 0.30) {
            auto node = std::make_unique<ExprNode>();
            node->type = random_unary();
            node->left = random_tree(max_depth - 1, false);
            return node;
        }

        auto node = std::make_unique<ExprNode>();
        node->type = random_binary();
        node->left = random_tree(max_depth - 1, false);
        node->right = random_tree(max_depth - 1, false);
        return node;
    }

    // Leaf node: 60% the variable x, otherwise a constant from the pool.
    std::unique_ptr<ExprNode> random_terminal() {
        auto node = std::make_unique<ExprNode>();
        if (prob_(rng_) < 0.60) {
            node->type = NodeType::Variable;
        } else {
            node->type = NodeType::Constant;
            node->constant = random_constant();
        }
        return node;
    }

    NodeType random_unary() {
        std::uniform_int_distribution<int> dist(0, 1);
        return dist(rng_) == 0 ? NodeType::Sin : NodeType::Cos;
    }

    NodeType random_binary() {
        std::uniform_int_distribution<int> dist(0, 3);
        switch (dist(rng_)) {
            case 0:
                return NodeType::Add;
            case 1:
                return NodeType::Sub;
            case 2:
                return NodeType::Mul;
            default:
                return NodeType::Div;
        }
    }

    double random_constant() {
        // The constant pool is deliberately small so interpretable expressions
        // assemble quickly in early generations. A constexpr std::array avoids
        // the heap allocation and thread-safe init guard that a function-local
        // static std::vector carries.
        static constexpr std::array<double, 10> pool = {-3.0, -2.0, -1.5, -1.0, -0.5,
                                                        0.5,  1.0,  1.5,  2.0,  3.0};
        std::uniform_int_distribution<int> dist(0, static_cast<int>(pool.size()) - 1);
        return pool[dist(rng_)];
    }

    // 50%: resample from the pool; 50%: add Gaussian noise, clamped to [-5, 5].
    void mutate_constant(ExprNode& node) {
        std::normal_distribution<double> noise(0.0, 0.30);
        if (prob_(rng_) < 0.50) {
            node.constant = random_constant();
        } else {
            node.constant = clamp_value(node.constant + noise(rng_), -5.0, 5.0);
        }
    }

private:
    std::mt19937& rng_;  // shared engine owned by the caller; kept by reference
    std::uniform_real_distribution<double> prob_{0.0, 1.0};
};

// Genetic-programming engine: evolves a population of expression trees to fit
// the training samples. The exact sequence of RNG draws determines the run,
// so the code below is deliberately kept order-sensitive and deterministic.
class SymbolicRegressionGA {
public:
    struct Individual {
        std::unique_ptr<ExprNode> root;      // candidate expression (the genome)
        double fitness = kFitnessPenalty;    // combined objective; lower is better
        double train_mse = kFitnessPenalty;  // MSE on the training samples
        double test_mse = kFitnessPenalty;   // MSE on the held-out samples
    };

    // Takes ownership of both sample sets; `seed` makes the run reproducible.
    SymbolicRegressionGA(std::vector<Sample> train_samples,
                         std::vector<Sample> test_samples,
                         unsigned seed)
        : train_samples_(std::move(train_samples)),
          test_samples_(std::move(test_samples)),
          rng_(seed),
          factory_(rng_) {}

    // Hyperparameter setters; call these before run().
    void set_population_size(int value) { population_size_ = value; }
    void set_generations(int value) { generations_ = value; }
    void set_elite_count(int value) { elite_count_ = value; }
    void set_tournament_size(int value) { tournament_size_ = value; }
    void set_crossover_rate(double value) { crossover_rate_ = value; }
    void set_mutation_rate(double value) { mutation_rate_ = value; }
    void set_max_tree_depth(int value) { max_tree_depth_ = value; }
    void set_max_tree_nodes(int value) { max_tree_nodes_ = value; }

    // Main evolution loop:
    // 1. sort the current population by fitness;
    // 2. carry the elites over unchanged;
    // 3. repeatedly select, cross over and mutate to produce offspring;
    // 4. evaluate the offspring, then replace the old population.
    //
    // There is no explicit early-stopping condition; the loop always runs the
    // full number of generations. That keeps the implementation simple — a
    // "stop when the best has not improved for a while" rule could be added
    // later to speed things up.
    Individual run() {
        std::vector<Individual> population = initialize_population();
        Individual best = clone_individual(population.front());

        for (int generation = 0; generation < generations_; ++generation) {
            std::sort(population.begin(), population.end(),
                      [](const Individual& lhs, const Individual& rhs) {
                          return lhs.fitness < rhs.fitness;
                      });

            if (population.front().fitness < best.fitness) {
                best = clone_individual(population.front());
            }

            // Progress log every 40 generations plus the final generation.
            if (generation % 40 == 0 || generation + 1 == generations_) {
                std::cout << "generation " << std::setw(3) << generation
                          << "  train_mse=" << std::setw(12) << std::setprecision(8)
                          << population.front().train_mse << "  size="
                          << node_count(population.front().root.get()) << "  expr="
                          << to_string(population.front().root.get()) << '\n';
            }

            std::vector<Individual> next_population;
            next_population.reserve(population_size_);

            // Elitism: the best few individuals enter the next generation
            // unchanged, which prevents the best-so-far from regressing.
            for (int i = 0; i < elite_count_ && i < population_size_; ++i) {
                next_population.push_back(clone_individual(population[i]));
            }

            while (static_cast<int>(next_population.size()) < population_size_) {
                const Individual& parent_a = select(population);
                const Individual& parent_b = select(population);

                Individual child;
                if (prob_(rng_) < crossover_rate_) {
                    child.root = crossover(parent_a.root, parent_b.root);
                } else {
                    child.root = clone_tree(parent_a.root.get());
                }

                if (prob_(rng_) < mutation_rate_) {
                    mutate(child.root);
                }

                // Cap tree depth and node count so expressions cannot bloat
                // without bound; oversized children are replaced wholesale.
                if (!is_valid_tree(child.root.get())) {
                    child.root = factory_.random_tree(3, true);
                }

                evaluate_individual(child);
                next_population.push_back(std::move(child));
            }

            population.swap(next_population);
        }

        std::sort(population.begin(), population.end(),
                  [](const Individual& lhs, const Individual& rhs) {
                      return lhs.fitness < rhs.fitness;
                  });
        if (population.front().fitness < best.fitness) {
            best = clone_individual(population.front());
        }
        return best;
    }

private:
    // Initialization mixes "a few hand-picked seeds + many random trees".
    // The seeds put the search near common base structures (x, x*x, sin(x));
    // the random trees supply diversity so the whole population does not
    // orbit a handful of trivial shapes.
    std::vector<Individual> initialize_population() {
        std::vector<Individual> population;
        population.reserve(population_size_);

        auto seed_expr = [&](std::unique_ptr<ExprNode> root) {
            Individual individual;
            individual.root = std::move(root);
            evaluate_individual(individual);
            population.push_back(std::move(individual));
        };

        // Simple, frequently useful base structures to bootstrap the search.
        seed_expr(make_variable());
        seed_expr(make_constant(1.0));
        seed_expr(make_binary(NodeType::Mul, make_variable(), make_variable()));
        seed_expr(make_unary(NodeType::Sin, make_variable()));
        seed_expr(make_unary(NodeType::Cos, make_variable()));
        seed_expr(make_binary(NodeType::Add, make_variable(), make_constant(1.0)));

        std::uniform_int_distribution<int> depth_dist(2, 4);
        while (static_cast<int>(population.size()) < population_size_) {
            // Alternate between fuller and sparser trees for diversity.
            bool full = (population.size() % 2 == 0);
            auto root = factory_.random_tree(depth_dist(rng_), full);
            if (!contains_variable(root.get())) {
                // Pure constants easily trap symbolic regression in local
                // optima, so force the variable x into the tree.
                root = make_binary(NodeType::Add, std::move(root), make_variable());
            }
            Individual individual;
            individual.root = std::move(root);
            evaluate_individual(individual);
            population.push_back(std::move(individual));
        }

        return population;
    }

    // Deep copy of an individual, including its cached scores.
    Individual clone_individual(const Individual& other) const {
        Individual copy;
        copy.root = clone_tree(other.root.get());
        copy.fitness = other.fitness;
        copy.train_mse = other.train_mse;
        copy.test_mse = other.test_mse;
        return copy;
    }

    // Selection decides who is more likely to reproduce.
    // Tournament selection is used instead of roulette-wheel: it is more
    // robust and insensitive to the absolute scale of fitness values.
    const Individual& select(const std::vector<Individual>& population) {
        // Tournament selection keeps pressure toward fitter programs without
        // requiring probabilities derived from raw fitness values.
        std::uniform_int_distribution<int> dist(0, static_cast<int>(population.size()) - 1);
        int best = dist(rng_);
        for (int i = 1; i < tournament_size_; ++i) {
            int candidate = dist(rng_);
            if (population[candidate].fitness < population[best].fitness) {
                best = candidate;
            }
        }
        return population[best];
    }

    // Crossover is, at heart, "recombining existing substructures".
    // If a parent has already discovered a useful local pattern — for example
    // sin(x + c) or x / c — subtree crossover has a chance to splice those
    // patterns into a more complete expression.
    std::unique_ptr<ExprNode> crossover(const std::unique_ptr<ExprNode>& lhs,
                                        const std::unique_ptr<ExprNode>& rhs) {
        auto child = clone_tree(lhs.get());

        std::vector<std::vector<int>> lhs_paths;
        std::vector<int> current;
        collect_paths(child.get(), lhs_paths, current);

        std::vector<std::vector<int>> rhs_paths;
        collect_paths(rhs.get(), rhs_paths, current);

        std::uniform_int_distribution<int> lhs_dist(0, static_cast<int>(lhs_paths.size()) - 1);
        std::uniform_int_distribution<int> rhs_dist(0, static_cast<int>(rhs_paths.size()) - 1);
        const auto& lhs_path = lhs_paths[lhs_dist(rng_)];
        const auto& rhs_path = rhs_paths[rhs_dist(rng_)];

        // Standard GP subtree crossover: replace one random subtree in lhs
        // with one random subtree from rhs.
        std::unique_ptr<ExprNode>* slot = get_slot(child, lhs_path);
        *slot = clone_tree(get_node(rhs, rhs_path));

        if (!contains_variable(child.get())) {
            child = make_binary(NodeType::Add, std::move(child), make_variable());
        }
        return child;
    }

    // Mutation injects new information and delays premature convergence.
    // Structural mutation explores in large steps, operator mutation makes
    // small structural adjustments, constant mutation tunes parameters.
    void mutate(std::unique_ptr<ExprNode>& root) {
        std::vector<std::vector<int>> paths;
        std::vector<int> current;
        collect_paths(root.get(), paths, current);

        std::uniform_int_distribution<int> path_dist(0, static_cast<int>(paths.size()) - 1);
        std::uniform_real_distribution<double> choice(0.0, 1.0);
        double mode = choice(rng_);

        if (mode < 0.40) {
            // Structural mutation: replace a random subtree with a fresh one.
            std::unique_ptr<ExprNode>* slot = get_slot(root, paths[path_dist(rng_)]);
            *slot = factory_.random_tree(3, false);
        } else if (mode < 0.75) {
            // Operator mutation: keep the subtree shape, change one operator.
            std::vector<std::vector<int>> operator_paths;
            for (const auto& path : paths) {
                const ExprNode* node = get_node(root, path);
                if (arity(node->type) > 0) {
                    operator_paths.push_back(path);
                }
            }
            if (!operator_paths.empty()) {
                std::uniform_int_distribution<int> op_dist(
                    0, static_cast<int>(operator_paths.size()) - 1);
                std::unique_ptr<ExprNode>* slot = get_slot(root, operator_paths[op_dist(rng_)]);
                if (arity((*slot)->type) == 1) {
                    (*slot)->type = factory_.random_unary();
                } else {
                    (*slot)->type = factory_.random_binary();
                }
            }
        } else {
            // Numeric mutation: if the tree has constants, resample or perturb
            // one of them; otherwise fall back to a small subtree replacement.
            std::vector<std::vector<int>> constant_paths;
            for (const auto& path : paths) {
                const ExprNode* node = get_node(root, path);
                if (node->type == NodeType::Constant) {
                    constant_paths.push_back(path);
                }
            }
            if (constant_paths.empty()) {
                std::unique_ptr<ExprNode>* slot = get_slot(root, paths[path_dist(rng_)]);
                *slot = factory_.random_tree(2, false);
            } else {
                std::uniform_int_distribution<int> const_dist(
                    0, static_cast<int>(constant_paths.size()) - 1);
                std::unique_ptr<ExprNode>* slot =
                    get_slot(root, constant_paths[const_dist(rng_)]);
                factory_.mutate_constant(**slot);
            }
        }

        if (!contains_variable(root.get())) {
            root = make_binary(NodeType::Add, std::move(root), make_variable());
        }
    }

    // The fitness function is the steering wheel of the whole search.
    // This implementation combines training error, expression complexity,
    // and generalization on the extra sample grid.
    void evaluate_individual(Individual& individual) const {
        individual.train_mse = mean_squared_error(individual.root.get(), train_samples_);
        individual.test_mse = mean_squared_error(individual.root.get(), test_samples_);
        const int size = node_count(individual.root.get());
        // Optimize training error first, but mildly discourage both overgrown
        // trees and solutions that only fit the training grid.
        individual.fitness =
            individual.train_mse + size * complexity_penalty_ + 0.10 * individual.test_mse;
    }

    // Squared error per sample with overflow protection during accumulation.
    // As soon as a prediction turns invalid, or the running total grows past
    // any useful comparison range, return the penalty value immediately.
    double mean_squared_error(const ExprNode* root, const std::vector<Sample>& samples) const {
        long double total = 0.0L;
        for (const auto& sample : samples) {
            double prediction = eval_tree(root, sample.x);
            if (!std::isfinite(prediction)) {
                return kFitnessPenalty;
            }
            double diff = prediction - sample.y;
            if (!std::isfinite(diff)) {
                return kFitnessPenalty;
            }
            total += static_cast<long double>(diff) * static_cast<long double>(diff);
            if (total > static_cast<long double>(kFitnessPenalty)) {
                return kFitnessPenalty;
            }
        }
        return static_cast<double>(total / samples.size());
    }

    bool is_valid_tree(const ExprNode* root) const {
        if (root == nullptr) {
            return false;
        }
        // Bound both depth and node count to limit the search space and the
        // complexity of evolved expressions.
        return node_count(root) <= max_tree_nodes_ && tree_depth(root) <= max_tree_depth_;
    }

    // Small builders used by the hand-seeded part of initialization.
    static std::unique_ptr<ExprNode> make_variable() {
        auto node = std::make_unique<ExprNode>();
        node->type = NodeType::Variable;
        return node;
    }

    static std::unique_ptr<ExprNode> make_constant(double value) {
        auto node = std::make_unique<ExprNode>();
        node->type = NodeType::Constant;
        node->constant = value;
        return node;
    }

    static std::unique_ptr<ExprNode> make_unary(NodeType type, std::unique_ptr<ExprNode> child) {
        auto node = std::make_unique<ExprNode>();
        node->type = type;
        node->left = std::move(child);
        return node;
    }

    static std::unique_ptr<ExprNode> make_binary(NodeType type,
                                                 std::unique_ptr<ExprNode> lhs,
                                                 std::unique_ptr<ExprNode> rhs) {
        auto node = std::make_unique<ExprNode>();
        node->type = type;
        node->left = std::move(lhs);
        node->right = std::move(rhs);
        return node;
    }

    std::vector<Sample> train_samples_;
    std::vector<Sample> test_samples_;
    int population_size_ = 240;
    int generations_ = 320;
    int elite_count_ = 10;
    int tournament_size_ = 5;
    int max_tree_depth_ = 8;
    int max_tree_nodes_ = 40;
    double crossover_rate_ = 0.90;
    double mutation_rate_ = 0.38;
    double complexity_penalty_ = 2e-3;  // per-node fitness penalty (regularizer)
    std::mt19937 rng_;
    ExpressionFactory factory_;
    std::uniform_real_distribution<double> prob_{0.0, 1.0};
};

// Convenience constructor for a Variable leaf ("x").
std::unique_ptr<ExprNode> make_variable_node() {
    auto leaf = std::make_unique<ExprNode>();
    leaf->type = NodeType::Variable;
    return leaf;
}

// Convenience constructor for a Constant leaf holding `value`.
std::unique_ptr<ExprNode> make_constant_node(double value) {
    auto leaf = std::make_unique<ExprNode>();
    leaf->type = NodeType::Constant;
    leaf->constant = value;
    return leaf;
}

// Build a unary operator node wrapping `child` as its only operand.
std::unique_ptr<ExprNode> make_unary_node(NodeType type, std::unique_ptr<ExprNode> child) {
    auto parent = std::make_unique<ExprNode>();
    parent->type = type;
    parent->left = std::move(child);
    return parent;
}

// Build a binary operator node with `lhs` and `rhs` as its operands.
std::unique_ptr<ExprNode> make_binary_node(NodeType type,
                                           std::unique_ptr<ExprNode> lhs,
                                           std::unique_ptr<ExprNode> rhs) {
    auto parent = std::make_unique<ExprNode>();
    parent->type = type;
    parent->left = std::move(lhs);
    parent->right = std::move(rhs);
    return parent;
}

// Produces the hidden demo target expression only.
// It plays the "problem setter": first pick a random template, then fill in
// the constant parameters; the genetic algorithm later reconstructs this
// target from the sampled data points.
std::unique_ptr<ExprNode> generate_target_expression(std::mt19937& rng) {
    static const std::vector<double> linear_pool = {-1.5, -1.0, -0.5, 0.5, 1.0, 1.5};
    static const std::vector<double> denom_pool = {1.5, 2.0, 3.0};

    std::uniform_int_distribution<int> pick_template(0, 4);
    std::uniform_int_distribution<int> pick_linear(0, static_cast<int>(linear_pool.size()) - 1);
    std::uniform_int_distribution<int> pick_denom(0, static_cast<int>(denom_pool.size()) - 1);

    const double a = linear_pool[pick_linear(rng)];  // multiplier coefficient
    const double b = denom_pool[pick_denom(rng)];    // divisor / cosine coefficient

    // Keep the hidden target family simple enough that the demo usually
    // recovers a recognizable expression within a few hundred generations.
    switch (pick_template(rng)) {
        case 0:
            // sin(x) + a * x
            return make_binary_node(
                NodeType::Add,
                make_unary_node(NodeType::Sin, make_variable_node()),
                make_binary_node(NodeType::Mul, make_constant_node(a), make_variable_node()));
        case 1:
            // cos(x) - x / b
            return make_binary_node(
                NodeType::Sub,
                make_unary_node(NodeType::Cos, make_variable_node()),
                make_binary_node(NodeType::Div, make_variable_node(), make_constant_node(b)));
        case 2:
            // x * x + a * sin(x)
            return make_binary_node(
                NodeType::Add,
                make_binary_node(NodeType::Mul, make_variable_node(), make_variable_node()),
                make_binary_node(NodeType::Mul,
                                 make_constant_node(a),
                                 make_unary_node(NodeType::Sin, make_variable_node())));
        case 3:
            // x * x - b * cos(x)
            return make_binary_node(
                NodeType::Sub,
                make_binary_node(NodeType::Mul, make_variable_node(), make_variable_node()),
                make_binary_node(NodeType::Mul,
                                 make_constant_node(b),
                                 make_unary_node(NodeType::Cos, make_variable_node())));
        default:
            // sin(x + 0.5) + x / b
            return make_binary_node(
                NodeType::Add,
                make_unary_node(NodeType::Sin,
                                make_binary_node(NodeType::Add,
                                                 make_variable_node(),
                                                 make_constant_node(0.5))),
                make_binary_node(NodeType::Div, make_variable_node(), make_constant_node(b)));
    }
}

std::vector<Sample> make_train_samples(const ExprNode* target) {
    std::vector<Sample> samples;
    // 训练集使用较稀疏的采样点,GA 进化时只看得到这些数据。
    for (int i = 0; i < 21; ++i) {
        double x = -2.5 + 5.0 * i / 20.0;
        samples.push_back({x, eval_tree(target, x)});
    }
    return samples;
}

std::vector<Sample> make_test_samples(const ExprNode* target) {
    std::vector<Sample> samples;
    // 测试集使用更密且略微错开的采样点,用来看泛化而不是死记训练点。
    for (int i = 0; i < 40; ++i) {
        double x = -2.4 + 4.8 * i / 39.0;
        samples.push_back({x, eval_tree(target, x)});
    }
    return samples;
}

// Average |lhs(x) - rhs(x)| over the given sample abscissas.
double mean_absolute_error(const ExprNode* lhs,
                           const ExprNode* rhs,
                           const std::vector<double>& xs) {
    double sum = 0.0;
    for (std::size_t i = 0; i < xs.size(); ++i) {
        sum += std::abs(eval_tree(lhs, xs[i]) - eval_tree(rhs, xs[i]));
    }
    return sum / xs.size();
}

// Largest |lhs(x) - rhs(x)| over the given sample abscissas.
double max_absolute_error(const ExprNode* lhs,
                          const ExprNode* rhs,
                          const std::vector<double>& xs) {
    double worst = 0.0;
    for (double x : xs) {
        const double diff = std::abs(eval_tree(lhs, x) - eval_tree(rhs, x));
        if (diff > worst) {
            worst = diff;
        }
    }
    return worst;
}

}  // namespace

// Program flow:
// 1. fix the random seeds so the experiment is reproducible;
// 2. build the training and test sets from a hidden target expression;
// 3. configure the genetic algorithm and run the evolution;
// 4. print the recovered expression together with its error metrics.
int main() {
    std::ios::sync_with_stdio(false);
    std::cin.tie(nullptr);

//    const unsigned target_seed = 20260515U;
    const unsigned target_seed = 5;
    const unsigned evolution_seed = 20260516U;

    std::mt19937 target_rng(target_seed);
    // Generate the hidden target expression first, then sample the datasets
    // from that expression.
    auto target = generate_target_expression(target_rng);
    auto train_samples = make_train_samples(target.get());
    auto test_samples = make_test_samples(target.get());

    std::cout << std::fixed << std::setprecision(6);
    std::cout << "Symbolic regression demo based on genetic programming\n";
    std::cout << "primitive set: +, -, *, /, sin, cos\n";
    std::cout << "target seed: " << target_seed << '\n';
    std::cout << "evolution seed: " << evolution_seed << '\n';
    std::cout << "target expression: " << to_string(target.get()) << "\n\n";

    std::cout << "training samples (only data passed to GA):\n";
    for (const auto& sample : train_samples) {
        std::cout << "x=" << std::setw(8) << sample.x << "  y=" << std::setw(12) << sample.y
                  << '\n';
    }
    std::cout << '\n';

    // These parameters control population size, generation count, and the
    // strength of crossover / mutation.
    SymbolicRegressionGA solver(train_samples, test_samples, evolution_seed);
    solver.set_population_size(260);
    solver.set_generations(360);
    solver.set_elite_count(12);
    solver.set_tournament_size(6);
    solver.set_crossover_rate(0.92);
    solver.set_mutation_rate(0.42);
    solver.set_max_tree_depth(8);
    solver.set_max_tree_nodes(42);

    auto best = solver.run();

    std::vector<double> compare_xs;
    for (int i = 0; i < 81; ++i) {
        compare_xs.push_back(-2.5 + 5.0 * i / 80.0);
    }

    std::cout << "\nrecovered expression: " << to_string(best.root.get()) << '\n';
    std::cout << "train mse: " << best.train_mse << '\n';
    std::cout << "test mse:  " << best.test_mse << '\n';
    std::cout << "mean abs error on dense grid: "
              << mean_absolute_error(target.get(), best.root.get(), compare_xs) << '\n';
    std::cout << "max  abs error on dense grid: "
              << max_absolute_error(target.get(), best.root.get(), compare_xs) << '\n';

    std::cout << "\ncomparison on representative points:\n";
    for (int i = 0; i <= 10; ++i) {
        double x = -2.5 + 5.0 * i / 10.0;
        double target_y = eval_tree(target.get(), x);
        double predicted_y = eval_tree(best.root.get(), x);
        std::cout << "x=" << std::setw(8) << x << "  target=" << std::setw(12) << target_y
                  << "  predicted=" << std::setw(12) << predicted_y
                  << "  abs_diff=" << std::setw(12) << std::abs(target_y - predicted_y) << '\n';
    }

    return 0;
}
相关推荐
卓豪终端管理5 天前
告别数据残留:如何为退役终端选择正确的“清除”方式
支持向量机·启发式算法
开开心心就好10 天前
一键扫描电脑重复文件的实用工具
linux·运维·服务器·随机森林·智能手机·excel·启发式算法
天辛大师11 天前
天辛大师谈人工智能时代,如何用AI研究历代放生劝善忏悔文
大数据·人工智能·随机森林·启发式算法
apollowing12 天前
启发式算法WebApp实验室:从搜索策略到群体智能的能力进阶(上)
算法·启发式算法·web app
开开心心就好13 天前
仅168KB的桌面图标自动隐藏工具
windows·计算机视觉·计算机外设·excel·启发式算法·宽度优先·csdn开发云
apollowing16 天前
启发式算法WebApp实验室:从搜索策略到群体智能的能力进阶(二十二)
算法·启发式算法·web app
天辛大师17 天前
AI助力旅游扩大化,五一旅游公园通游年票普惠研究
大数据·启发式算法·旅游
apollowing17 天前
启发式算法WebApp实验室:从搜索策略到群体智能的能力进阶(优)
算法·启发式算法·web app
zB6822HbX23 天前
Ledger官方授权正式落地中国大陆,京东独家首发开启安全新纪元
安全·启发式算法·ai写作