路径规划——RRT-Connect算法

路径规划------RRT-Connect算法

算法原理

RRT-Connect算法是在RRT算法的基础上进行的扩展,引入了双树生长,分别以起点和目标点为树的根节点同时扩展随机树从而实现对状态空间的快速搜索。在此算法中以两棵随机树建立连接为路径规划成功的条件。并且,在搜索过程中使用了贪婪搜索的方法,在搜索的过程中,两棵树是交替扩展的,与RRT算法不同的是,RRT-Connect算法并不是每次扩展都会进行随机采样,而是第一棵树先随机采样进而扩展一个新的节点node_new,然后第二棵树利用node_new节点往相同的方向进行多次扩展直到扩展失败才会开始下一轮的交替扩展或者与另一棵树能够建立连接了从而满足路径规划完成的条件。

这种双向的RRT算法比原始RRT算法的搜索速度更快,因为除了使用双树扩展搜索,两棵树在扩展时还是朝着对方的方向进行扩展的,并不是完全随机的。

具体的算法流程可结合上述原理以及RRT算法的实现流程。

算法实现

python 复制代码
"""
  @File: rrt-connect.py
  @Brief: RRT-Connect algorithm for pathplanning
 
  @Author: Benxiaogu
  @Github: https://github.com/Benxiaogu
  @CSDN: https://blog.csdn.net/weixin_51995147?type=blog
 
  @Date: 2024-11-13
"""
import numpy as np
import random
import math
from itertools import combinations
import matplotlib.pyplot as plt
import matplotlib.animation as animation
import matplotlib.patches as patches

class RRTConnect:
    def __init__(self,start,goal,obstacles,board_size,max_try,max_dist,goal_sample_rate,env) -> None:
        self.start = self.Node(start,None,0)
        self.goal = self.Node(goal,None,0)
        self.obstacles = obstacles
        self.board_size = board_size
        self.max_try = max_try # Number of iterations
        self.max_dist = max_dist # Maximum sampling distance
        self.goal_sample_rate = goal_sample_rate
        self.env = env
        self.inflation = 1
        self.searched = []

    class Node:
        def __init__(self,position,parent,cost) -> None:
            self.position = position
            self.parent = parent
            self.cost = cost


    def run(self):
        cost,path,expand = self.plan()
        self.searched = expand
        self.visualize(cost,path)

    def plan(self):
        nodes_forward = {self.start.position: self.start}
        nodes_back = {self.goal.position: self.goal}

        for iter in range(self.max_try):
            # Generate a random node
            node_rand = self.get_random_node()
            # Get the nearest neighbor node
            node_near = self.get_nearest_neighbor(list(nodes_forward.values()),node_rand)
            # Get the new node
            node_new = self.get_new_node(node_rand,node_near)
            
            if node_new:
                nodes_forward[node_new.position] = node_new
                node_near_b = self.get_nearest_neighbor(list(nodes_back.values()), node_new)
                node_new_b = self.get_new_node(node_new,node_near_b)

                if node_new_b:
                    nodes_back[node_new_b.position] = node_new_b
                    # Greedy extending
                    while True:
                        for node_position, node in nodes_back.items():
                            if node.position == node_new.position:
                                print("final")
                                cost, path = self.extractPath(node_new, nodes_back, nodes_forward)
                                expand = self.get_expand(list(nodes_back.values()), list(nodes_forward.values()))
                                print("Exploring {} nodes.".format(iter))
                                return cost, path, expand

                        node_new_b2 = self.get_new_node(node_new,node_new_b)

                        if node_new_b2:
                            nodes_back[node_new_b2.position] = node_new_b2
                            node_new_b = node_new_b2
                        else:
                            break

            if len(nodes_back) < len(nodes_forward):
                nodes_forward, nodes_back = nodes_back, nodes_forward

        return 0, None, None

    def get_random_node(self):
        """
            Return a random node.
        """
        if random.random()>self.goal_sample_rate:
            node = self.Node(
                (random.uniform(0,self.env.height),random.uniform(0,self.env.width)),None,0)
        else:
            node = self.goal
            
        return node

    def get_nearest_neighbor(self,node_list,node):
        """
            Return node that is nearest to 'node' in node_list
        """
        dist = [self.distance(node, n) for n in node_list]
        node_near = node_list[int(np.argmin(dist))]
        return node_near
    
    def get_new_node(self,node_rand,node_near):
        """
            Return node found based on node_near and node_rand.
        """
        dx = node_rand.position[0] - node_near.position[0]
        dy = node_rand.position[1] - node_near.position[1]
        dist = math.hypot(dx,dy)
        theta = math.atan2(dy, dx)

        d = min(self.max_dist,dist)
        position = ((node_near.position[0]+d*math.cos(theta)),node_near.position[1]+d*math.sin(theta))

        node_new = self.Node(position,node_near,node_near.cost+d)

        if self.isCollision(node_new, node_near):
            return None
        return node_new
    
    def isCollision(self,node1,node2):
        """
            Judge collision from node1 to node2 
        """
        if self.isInObstacles(node1) or self.isInObstacles(node2):
            return True
        
        for rect in self.env.obs_rectangle:
            if self.isInterRect(node1,node2,rect):
                return True
        
        for circle in self.env.obs_circle:
            if self.isInterCircle(node1,node2,circle):
                return True
        
        return False
    
    def distance(self,node1,node2):
        dx = node2.position[0] - node1.position[0]
        dy = node2.position[1] - node1.position[1]

        return math.hypot(dx,dy)
        
    def isInObstacles(self,node):
        """
            Determine whether it is in obstacles or not.
        """
        x,y = node.position[0],node.position[1]
        for (ox,oy,w,h) in self.env.boundary:
            if ox-self.inflation<x<ox+w+self.inflation and oy-self.inflation<y<oy+h+self.inflation:
                return True
        
        for (ox,oy,w,h) in self.env.obs_rectangle:
            if ox-self.inflation<x<ox+w+self.inflation and oy-self.inflation<y<oy+h+self.inflation:
                return True
            
        for (ox,oy,r) in self.env.obs_circle:
            if math.hypot(x-ox,y-oy)<=r+self.inflation:
                return True
        
        return False
    
    def isInterRect(self,node1,node2,rect):
        """"
            Judge whether it will cross the rectangle when moving from node1 to node2
        
        """
        ox,oy,w,h = rect
        vertex = [[ox-self.inflation,oy-self.inflation],
                  [ox+w+self.inflation,oy-self.inflation],
                  [ox+w+self.inflation,oy+h+self.inflation],
                  [ox-self.inflation,oy+h+self.inflation]]
        
        x1,y1 = node1.position
        x2,y2 = node2.position

        def cross(p1,p2,p3):
            x1 = p2[0]-p1[0]
            y1 = p2[1]-p1[1]
            x2 = p3[0]-p1[0]
            y2 = p3[1]-p1[0]
            return x1*y2 - x2*y1
        
        for v1,v2 in combinations(vertex,2):
            if max(x1,x2) >= min(v1[0],v2[0]) and min(x1,x2)<=max(v1[0],v2[0]) and \
                max(y1,y2) >= min(v1[1],v2[1]) and min(y1,y2) <= max(v1[1],v2[1]):
                if cross(v1,v2,node1.position) * cross(v1,v2,node2.position)<=0 and \
                    cross(node1.position,node2.position,v1) * cross(node1.position,node2.position,v2)<=0:
                    return True
    
        return False
    
    def isInterCircle(self,node1,node2,circle):
        """
            Judge whether it will cross the circle when moving from node1 to node2
        """
        ox,oy,r = circle

        dx = node2.position[0] - node1.position[0]
        dy = node2.position[1] - node1.position[1]

        # print("isInterCircle-dx:",dx)
        # print("isInterCircle-dy:",dy)
        d = dx * dx + dy * dy
        if d==0:
            return False
            
        # Projection
        t = ((ox - node1.position[0]) * dx + (oy - node1.position[1]) * dy) / d
        
        # The projection point is on line segment AB
        if 0 <= t <= 1:
            closest_x = node1.position[0] + t * dx
            closest_y = node1.position[1] + t * dy
        
            # Distance from center of the circle to line segment AB
            distance = math.hypot(ox-closest_x,oy-closest_y)
            
            return distance <= r+self.inflation
        
        return False

    def extractPath(self, node_middle, nodes_back, nodes_forward):
        """"
            Extract the path based on the closed set.
        """
        if self.start.position in nodes_back:
            nodes_forward, nodes_back = nodes_back, nodes_forward

        # forward
        node = nodes_forward[node_middle.position]
        path_forward = [node.position]
        cost = node.cost
        while node.position != self.start.position:
            node_parent = nodes_forward[node.parent.position]
            node = node_parent
            path_forward.append(node.position)

        # backward
        node = nodes_back[node_middle.position]
        path_back = []
        cost += node.cost
        while node.position != self.goal.position:
            node_parent = nodes_back[node.parent.position]
            node = node_parent
            path_back.append(node.position)
        
        path = list(reversed(path_forward))+path_back

        return cost, path
    
    def get_expand(self, nodes_back, nodes_forward):
        expand = []
        tree_size = max(len(nodes_forward), len(nodes_back))
        for tr in range(tree_size):
            if tr < len(nodes_forward):
                expand.append(nodes_forward[tr])
            if tr < len(nodes_back):
                expand.append(nodes_back[tr])

        return expand
    
    def visualize(self, cost, path):
        """
            Plot the map.
        """
        ...

完整代码:PathPlanning

相关推荐
小猿_002 小时前
C语言程序设计十大排序—插入排序
c语言·算法·排序算法
熊文豪4 小时前
深入解析人工智能中的协同过滤算法及其在推荐系统中的应用与优化
人工智能·算法
siy23337 小时前
[c语言日寄]结构体的使用及其拓展
c语言·开发语言·笔记·学习·算法
吴秋霖7 小时前
最新百应abogus纯算还原流程分析
算法·abogus
灶龙8 小时前
浅谈 PID 控制算法
c++·算法
菜还不练就废了8 小时前
蓝桥杯算法日常|c\c++常用竞赛函数总结备用
c++·算法·蓝桥杯
金色旭光8 小时前
目标检测高频评价指标的计算过程
算法·yolo
he101018 小时前
1/20赛后总结
算法·深度优先·启发式算法·广度优先·宽度优先
Kent_J_Truman9 小时前
【回忆迷宫——处理方法+DFS】
算法
paradoxjun9 小时前
落地级分类模型训练框架搭建(1):resnet18/50和mobilenetv2在CIFAR10上测试结果
人工智能·深度学习·算法·计算机视觉·分类