Informed RRT*实现椭圆启发式采样

我来为您添加Informed RRT*功能,在找到第一条路径后使用椭圆采样来加速收敛。以下是需要新增的函数和修改:

以下代码只含新增的部分

cpp 复制代码
class RRTStar {
private:
    // 在私有成员变量中添加
    bool pathFound;
    double bestPathCost;
    std::shared_ptr<Node> goalNode;
    
    // ... 其他现有成员
    
public:
    // 在构造函数中初始化新变量
    RRTStar(const Point& start, const Point& goal, 
            double stepSize = 0.5, double goalRadius = 0.5, 
            double searchRadius = 1.0, int maxIterations = 5000)
        : start(start), goal(goal), stepSize(stepSize), 
          goalRadius(goalRadius), searchRadius(searchRadius),
          maxIterations(maxIterations), nodeCounter(0), gen(rd()),
          pathFound(false), bestPathCost(std::numeric_limits<double>::max()),
          goalNode(nullptr) {
        
        minX = std::min(start.x, goal.x) - 5;
        maxX = std::max(start.x, goal.x) + 5;
        minY = std::min(start.y, goal.y) - 5;
        maxY = std::max(start.y, goal.y) + 5;
        
        xDist = std::uniform_real_distribution<>(minX, maxX);
        yDist = std::uniform_real_distribution<>(minY, maxY);
        
        addNode(std::make_shared<Node>(0, start, -1, 0.0));
    }
    
    // 新增函数:生成椭圆内的随机点
    Point generateEllipseRandomPoint(double cBest) {
        // cBest是当前最佳路径长度
        // cMin是起点到终点的直线距离(椭圆焦距距离)
        double cMin = start.distanceTo(goal);
        
        // 如果cBest小于cMin或未找到路径,返回整个空间的随机点
        if (cBest < cMin || !pathFound) {
            return generateRandomPoint(0.1);
        }
        
        // 计算椭圆中心、旋转角和半轴长度
        Point center(
            (start.x + goal.x) / 2.0,
            (start.y + goal.y) / 2.0
        );
        
        // 椭圆的长轴长度
        double a = cBest / 2.0;
        
        // 计算短轴长度:b = sqrt(cBest² - cMin²) / 2
        double b = std::sqrt(std::pow(cBest, 2) - std::pow(cMin, 2)) / 2.0;
        
        // 计算旋转角(从起点指向终点)
        double theta = std::atan2(goal.y - start.y, goal.x - start.x);
        
        // 在单位圆内生成随机点
        std::uniform_real_distribution<> angleDist(0.0, 2.0 * M_PI);
        std::uniform_real_distribution<> radiusDist(0.0, 1.0);
        
        double r = std::sqrt(radiusDist(gen)); // 均匀分布在圆内
        double phi = angleDist(gen);
        
        // 将单位圆内的点映射到椭圆
        double x = r * std::cos(phi);
        double y = r * std::sin(phi);
        
        // 应用椭圆变换:缩放、旋转、平移
        double transformedX = center.x + a * x * std::cos(theta) - b * y * std::sin(theta);
        double transformedY = center.y + a * x * std::sin(theta) + b * y * std::cos(theta);
        
        return Point(transformedX, transformedY);
    }
    
    // 修改后的generateRandomPoint函数,添加椭圆采样逻辑
    Point generateRandomPoint(double goalBias = 0.1) {
        // 如果已经找到路径,使用椭圆采样
        if (pathFound) {
            // 10%的概率使用原始采样策略,90%的概率使用椭圆采样
            std::uniform_real_distribution<> strategyDist(0.0, 1.0);
            if (strategyDist(gen) < 0.9) {
                return generateEllipseRandomPoint(bestPathCost);
            }
        }
        
        // 原始采样策略
        std::uniform_real_distribution<> biasDist(0.0, 1.0);
        
        if (biasDist(gen) < goalBias) {
            return goal;
        }
        
        return Point(xDist(gen), yDist(gen));
    }
    
    // 新增函数:更新最佳路径
    void updateBestPath(std::shared_ptr<Node> newGoalNode) {
        if (newGoalNode) {
            double newCost = newGoalNode->getCost();
            
            if (newCost < bestPathCost) {
                bestPathCost = newCost;
                goalNode = newGoalNode;
                pathFound = true;
                
                std::cout << "更新最佳路径,成本: " << bestPathCost << std::endl;
                
                // 动态调整搜索半径(可选)
                // searchRadius = std::min(10.0 * stepSize, 
                //                       std::max(stepSize, bestPathCost / 10.0));
            }
        }
    }
    
    // 修改后的plan函数,包含Informed RRT*逻辑
    bool plan() {
        for (int i = 0; i < maxIterations; i++) {
            // 1. 生成随机点(现在会根据是否找到路径选择采样策略)
            Point randPoint = generateRandomPoint();
            
            // 2. 找到最近的节点
            auto nearestNode = findNearestNode(randPoint);
            
            // 3. 生成新节点
            Point newPoint = steer(nearestNode->getPosition(), randPoint);
            
            // 4. 检查路径是否无碰撞
            if (!isPathCollisionFree(nearestNode->getPosition(), newPoint)) {
                continue;
            }
            
            // 5. 查找附近的节点
            auto nearNodes = findNearNodes(newPoint, searchRadius);
            
            // 6. 选择最佳父节点
            std::shared_ptr<Node> bestParent = nearestNode;
            double minCost = nearestNode->getCost() + 
                           nearestNode->getPosition().distanceTo(newPoint);
            
            for (auto& node : nearNodes) {
                double newCost = node->getCost() + node->getPosition().distanceTo(newPoint);
                if (newCost < minCost && isPathCollisionFree(node->getPosition(), newPoint)) {
                    minCost = newCost;
                    bestParent = node;
                }
            }
            
            // 7. 创建新节点
            int newNodeId = ++nodeCounter;
            auto newNode = std::make_shared<Node>(newNodeId, newPoint, 
                                                  bestParent->getId(), minCost);
            addNode(newNode);
            
            // 8. 重新连接
            for (auto& node : nearNodes) {
                if (node->getId() == bestParent->getId()) {
                    continue;
                }
                
                double potentialCost = newNode->getCost() + 
                                      newNode->getPosition().distanceTo(node->getPosition());
                
                if (potentialCost < node->getCost() && 
                    isPathCollisionFree(newNode->getPosition(), node->getPosition())) {
                    node->setParentId(newNode->getId());
                    node->setCost(potentialCost);
                    
                    // 如果被重新连接的节点是目标节点的祖先,需要更新目标节点成本
                    if (goalNode && isDescendantOf(node, goalNode)) {
                        updateGoalNodeCost();
                    }
                }
            }
            
            // 9. 检查是否到达目标
            if (newPoint.distanceTo(goal) <= goalRadius) {
                if (isPathCollisionFree(newPoint, goal)) {
                    // 添加目标节点
                    int goalNodeId = ++nodeCounter;
                    double goalCost = newNode->getCost() + newPoint.distanceTo(goal);
                    auto newGoalNode = std::make_shared<Node>(
                        goalNodeId, goal, newNode->getId(), goalCost);
                    addNode(newGoalNode);
                    
                    // 更新最佳路径
                    updateBestPath(newGoalNode);
                    
                    if (!pathFound) {
                        std::cout << "找到第一条路径!迭代次数: " << i + 1 
                                  << ",成本: " << goalCost << std::endl;
                        pathFound = true;
                        bestPathCost = goalCost;
                        goalNode = newGoalNode;
                    }
                }
            }
            
            // 每100次迭代打印进度
            if ((i + 1) % 1000 == 0) {
                std::cout << "迭代 " << i + 1;
                if (pathFound) {
                    std::cout << ",当前最佳路径成本: " << bestPathCost;
                }
                std::cout << std::endl;
            }
        }
        
        if (pathFound) {
            std::cout << "\n最终最佳路径成本: " << bestPathCost << std::endl;
            return true;
        } else {
            std::cout << "未找到路径,达到最大迭代次数" << std::endl;
            return false;
        }
    }
    
    // 新增函数:检查一个节点是否是另一个节点的后代
    bool isDescendantOf(std::shared_ptr<Node> ancestor, std::shared_ptr<Node> descendant) {
        auto current = descendant;
        while (current) {
            if (current->getId() == ancestor->getId()) {
                return true;
            }
            if (current->getParentId() == -1) {
                break;
            }
            current = tree[current->getParentId()];
        }
        return false;
    }
    
    // 新增函数:更新目标节点成本(当重新连接影响目标节点路径时)
    void updateGoalNodeCost() {
        if (!goalNode) return;
        
        double newCost = 0.0;
        auto current = goalNode;
        std::vector<std::shared_ptr<Node>> pathNodes;
        
        // 收集路径上的所有节点
        while (current) {
            pathNodes.push_back(current);
            if (current->getParentId() == -1) {
                break;
            }
            current = tree[current->getParentId()];
        }
        
        // 反向计算累积成本
        std::reverse(pathNodes.begin(), pathNodes.end());
        newCost = 0.0;
        
        for (size_t i = 0; i < pathNodes.size(); i++) {
            if (i > 0) {
                newCost += pathNodes[i-1]->getPosition().distanceTo(pathNodes[i]->getPosition());
                pathNodes[i]->setCost(newCost);
            } else {
                pathNodes[i]->setCost(0.0);
            }
        }
        
        // 如果成本有改善,更新最佳路径
        if (newCost < bestPathCost) {
            bestPathCost = newCost;
            std::cout << "通过重新连接优化路径,新成本: " << bestPathCost << std::endl;
        }
    }
    
    // 修改后的getPath函数
    std::vector<Point> getPath() {
        std::vector<Point> path;
        
        if (!goalNode) {
            // 如果没有设置goalNode,尝试查找
            for (const auto& pair : tree) {
                if (pair.second->getPosition().distanceTo(goal) <= goalRadius) {
                    goalNode = pair.second;
                    break;
                }
            }
        }
        
        if (!goalNode) {
            return path;
        }
        
        // 反向追踪路径
        auto current = goalNode;
        while (current) {
            path.push_back(current->getPosition());
            if (current->getParentId() == -1) {
                break;
            }
            current = tree[current->getParentId()];
        }
        
        std::reverse(path.begin(), path.end());
        return path;
    }
    
    // 新增函数:获取当前最佳路径成本
    double getBestPathCost() const {
        return bestPathCost;
    }
    
    // 新增函数:检查是否已找到路径
    bool isPathFound() const {
        return pathFound;
    }
    
    // 新增函数:获取椭圆采样区域信息(用于可视化)
    void getEllipseInfo(double& a, double& b, double& theta, Point& center) const {
        double cMin = start.distanceTo(goal);
        center = Point((start.x + goal.x) / 2.0, (start.y + goal.y) / 2.0);
        a = bestPathCost / 2.0;
        b = std::sqrt(std::pow(bestPathCost, 2) - std::pow(cMin, 2)) / 2.0;
        theta = std::atan2(goal.y - start.y, goal.x - start.x);
    }
};

主要新增功能和修改:

  1. 新增成员变量

· pathFound: 标记是否已找到路径

· bestPathCost: 当前最佳路径成本

· goalNode: 指向目标节点的指针

  1. 椭圆采样函数

· generateEllipseRandomPoint(): 在连接起点和终点的椭圆内生成随机点

· 椭圆焦点:起点和终点

· 椭圆长轴长度:当前最佳路径长度

· 短轴长度:根据勾股定理计算

  1. 改进的采样策略

· generateRandomPoint(): 在找到路径后,90%的概率使用椭圆采样,10%的概率使用原始采样

· 这确保了在找到路径后,搜索集中在有希望找到更好路径的区域

  1. 路径更新机制

· updateBestPath(): 当找到新的更好路径时更新状态

· updateGoalNodeCost(): 当重新连接影响目标节点路径时更新成本

  1. 进度监控

· 每1000次迭代打印当前状态

· 显示当前最佳路径成本

  1. 辅助函数

· isDescendantOf(): 检查节点间的祖孙关系

· getEllipseInfo(): 获取椭圆采样区域信息(可用于可视化)

使用示例:

cpp 复制代码
int main() {
    Point start(0, 0);
    Point goal(10, 10);
    
    RRTStar planner(start, goal, 0.5, 0.5, 1.5, 20000);
    
    // ... 添加障碍物和设置边界
    
    bool success = planner.plan();
    
    if (success) {
        std::vector<Point> path = planner.getPath();
        std::cout << "最终路径成本: " << planner.getBestPathCost() << std::endl;
        
        // 获取椭圆信息(可用于可视化)
        double a, b, theta;
        Point center;
        if (planner.isPathFound()) {
            planner.getEllipseInfo(a, b, theta, center);
            std::cout << "椭圆采样区域 - 中心: (" << center.x << ", " << center.y 
                      << "), 长轴: " << a << ", 短轴: " << b 
                      << ", 旋转角: " << theta << " rad" << std::endl;
        }
    }
    
    return 0;
}

这个Informed RRT实现能够在找到第一条路径后显著加速优化过程,通过椭圆采样将搜索集中在有希望找到更好路径的区域,从而比标准RRT更快收敛到最优路径。

相关推荐
Swizard4 小时前
告别样本不平衡噩梦:Focal Loss 让你的模型学会“划重点”
算法·ai·训练
CoderCodingNo4 小时前
【GESP】C++一级真题 luogu-B4410 [GESP202509 一级] 金字塔
开发语言·c++
超级大福宝4 小时前
C++中1 << 31 - 1相当于INT_MAX吗?
c语言·c++
亭台4 小时前
【Matlab笔记_23】MATLAB的工具包m_map的m_image和m_pcolor区别
笔记·算法·matlab
李玮豪Jimmy4 小时前
Day39:动态规划part12(115.不同的子序列、583.两个字符串的删除操作、72.编辑距离)
算法·动态规划
alibli4 小时前
一文学会设计模式之结构型模式及最佳实现
c++·设计模式
A7bert7774 小时前
【YOLOv5seg部署RK3588】模型训练→转换RKNN→开发板部署
linux·c++·人工智能·深度学习·yolo·目标检测
历程里程碑5 小时前
C++ 10 模板进阶:参数特化与分离编译解析
c语言·开发语言·数据结构·c++·算法
老秦包你会5 小时前
C++进阶------智能指针和特殊类设计方式
开发语言·c++