
9.3.4 实现RRT、RRT*和RRT*-FN算法
文件algorithm.py实现了多种RRT路径规划算法,包括RRT、RRT*和RRT*-FN。该算法利用随机采样和树结构构建来探索环境中的可行路径,并通过优化树结构来改进路径的质量。该代码通过图形表示节点和边,并提供了可视化功能来显示算法的执行过程和执行时间。通过迭代的方式,逐步优化树结构,直到找到最优路径或达到最大迭代次数为止。
(1)time_function是一个装饰器,用于计时函数执行时间。当应用于一个函数时,它会在函数执行前记录开始时间,然后在函数执行完毕后记录结束时间,并计算函数执行时间。最后,它会打印出函数执行时间,并返回函数的结果。
python
def time_function(func):
@wraps(func)
def wrapper(*args, **kwargs):
start_time = perf_counter()
args = func(*args, **kwargs)
end_time = perf_counter()
total_time = round(end_time - start_time, 2)
print(f"Execution time of {func.__name__}: {total_time}")
return args
return wrapper
(2)函数new_node用于生成新的节点,并返回新节点的位置、新节点的ID、最近节点的ID以及新节点与最近节点的距离。函数new_node的具体实现流程如下所示。
- 生成一个随机节点 q_rand,检查随机节点是否在障碍物内,如果是,则返回 None 值。
- 创建KD树(2D树)以加速搜索最近节点的过程,并将其传递给 nearest_node_kdtree 函数。
- 寻找离随机节点最近的节点 q_near,并获取其ID id_near。
- 如果找不到最近节点或者随机节点与最近节点重合,则返回 None 值。
- 使用 steer 函数生成新节点 q_new,该函数用于生成从最近节点到新节点的路径,并限制路径长度为 step_length。
- 将新节点添加到图中,并计算新节点与最近节点的距离。
- 返回新节点的位置、新节点的ID、最近节点的ID以及新节点与最近节点的距离。
python
def new_node(G: Graph, map: Map, obstacles: list, step_length: float, bias: float):
q_rand = G.random_node(bias=bias) # 生成一个新的随机节点
if map.is_occupied_c(q_rand): # 如果新节点的位置在障碍物内
return None, None, None, None
potential_vertices_list = list(G.vertices.values())
kdtree = cKDTree(
np.array(potential_vertices_list)) # 创建KD树(2D树)并将其传递给nearest_node函数以加快搜索速度
q_near, id_near = nearest_node_kdtree(G, q_rand, obstacles, kdtree=kdtree)
if q_near is None or (q_rand == q_near).all(): # 无法将随机节点连接到最近节点(可能因为随机节点与最近节点重合)
return None, None, None, None
q_new = steer(q_rand, q_near, step_length)
id_new = G.add_vertex(q_new)
distance = calc_distance(q_new, q_near)
return q_new, id_new, id_near, distance
(3)函数RRT实现了RRT(Rapidly-exploring Random Tree)算法,用于在给定地图上生成路径。该算法通过指定的迭代次数,在图中随机生成节点,并尝试连接到最近的节点,直到达到最大迭代次数或找到路径到达目标节点为止。通过可选参数,可以调整算法的行为,例如允许的最大边长、节点半径、偏向目标节点的程度以及是否实时更新算法过程。
python
@time_function
def RRT(G: Graph, iter_num: int, map: Map, step_length: float, node_radius: int, bias: float = .0, live_update: bool = False):
"""
RRT算法。
:param G: 图
:param iter_num: 算法迭代次数
:param map: 地图
:param step_length: 两个节点之间允许的最大边长
:param node_radius: 节点的半径
:param bias: 0-1之间的偏向目标节点的参数
:param live_update: 布尔值,如果要在图中实时更新算法
:return: 迭代次数
"""
pbar = tqdm(total=iter_num)
obstacles = map.obstacles_c
iter = 0
while iter < iter_num:
q_new, id_new, id_near, distance = new_node(G, map, obstacles, step_length, bias)
if q_new is None:
continue
G.add_edge(id_new, id_near, distance)
if check_solution(G, q_new, node_radius):
path, cost = find_path(G, id_new, G.id_vertex[G.start])
plot_path(G, path, "RRT", cost)
break
pbar.update(1)
iter += 1
if live_update:
plt.pause(0.001)
plt.clf()
plot_graph(G, map.obstacles_c)
plt.xlim((-200, 200))
plt.ylim((-200, 200))
pbar.close()
return iter
(4)函数RRT_star实现了RRT*(Rapidly-exploring Random Tree Star)算法,用于在给定地图上生成路径。该算法通过指定的迭代次数,在图中随机生成节点,并尝试连接到最近的节点,直到达到最大迭代次数或找到路径到达目标节点为止。与标准的RRT算法相比,RRT*算法在每次迭代后尝试重新连接节点以改进路径质量,并通过维护一个以节点为中心的半径范围来进行优化。函数的可选参数允许调整算法的行为,例如允许的最大边长、节点半径、偏向目标节点的程度以及是否实时更新算法过程。执行函数RRT_star后,返回迭代次数以及找到的最佳路径及其代价。
python
@time_function
def RRT_star(G, iter_num, map, step_length, radius, node_radius: int, bias=.0, live_update=False) -> tuple:
"""
RRT*算法。
:param G: 图
:param iter_num: 算法的迭代次数
:param map: 地图
:param step_length: 两个节点之间允许的最大边长
:param radius: 重新连接算法将在其上执行的圆形区域的半径
:param node_radius: 节点的半径
:param bias: 0-1之间,偏向目标节点的程度
:param live_update: 在图上实时显示算法的布尔值
:return: 迭代次数,具有最佳代价的最佳路径
"""
pbar = tqdm(total=iter_num)
obstacles = map.obstacles_c
best_edge = None
solution_found = False # 标志是否已经找到解决方案
best_path = {"path": [], "cost": float("inf")} # 具有最小代价的路径及其代价
finish_nodes_of_path = [] # 找到的路径中最后一个节点的ID
iter = 0
while iter < iter_num:
q_new, id_new, id_near, cost_new_near = new_node(G, map, obstacles, step_length, bias)
if q_new is None:
continue
best_edge = (id_new, id_near, cost_new_near)
G.cost[id_new] = cost_new_near # 计算从最近节点到新节点的代价
G.parent[id_new] = id_near
# KDTree 查询的效率优于暴力搜索
kdtree = cKDTree(list(G.vertices.values()))
choose_parent_kdtree(G, q_new, id_new, best_edge, radius, obstacles, kdtree=kdtree)
# choose_parent(G, q_new, id_new, best_edge, radius, obstacles)
G.add_edge(*best_edge)
# 重新连接
rewire_kdtree(G, q_new, id_new, radius, obstacles, kdtree=kdtree)
# rewire(G, q_new, id_new, radius, obstacles)
# 检查解决方案
if check_solution(G, q_new, node_radius):
path, cost = find_path(G, id_new, G.id_vertex[G.start])
finish_nodes_of_path.append(id_new)
solution_found = True
best_path["path"] = path
best_path["cost"] = cost
# plot_path(G, path, "RRT_STAR", cost)
# break
# 更新路径的代价
for node in finish_nodes_of_path:
path, cost = find_path(G, node, G.id_vertex[G.start])
if cost < best_path["cost"]:
best_path["path"] = path
best_path["cost"] = cost
pbar.update(1)
iter += 1
if live_update:
plt.pause(0.001)
plt.clf()
plot_graph(G, map.obstacles_c)
if solution_found:
plot_path(G, best_path["path"], "RRT_STAR", best_path["cost"])
if solution_found:
plot_path(G, best_path["path"], "RRT_STAR", best_path["cost"])
pbar.close()
return iter, best_path
(5)函数RRT_star_FN实现了RRT*FN(Rapidly-exploring Random Tree with Forced Node removal)算法,用于在给定地图上生成路径。在迭代过程中随机生成节点,并尝试将其连接到最近的节点,同时通过强制删除节点来限制节点数量。可通过调整参数来控制算法行为,如最大迭代次数、允许的最大边长、节点半径、偏向目标节点的程度、是否实时更新算法过程以及最大节点数等。
python
@time_function
def RRT_star_FN(G, iter_num, map, step_length, radius, node_radius: int, max_nodes=200, bias=.0,
live_update: bool = False):
"""
RRT star FN算法。
:param G: 图
:param iter_num: 算法的迭代次数
:param map: 地图
:param step_length: 两个节点之间允许的最大边长
:param radius: 重新连接算法将在其上执行的圆形区域的半径
:param node_radius: 节点的半径
:param max_nodes: 最大节点数
:param bias: 0-1之间,偏向目标节点的程度
:param live_update: 在图上实时显示算法的布尔值
:return: 迭代次数,具有最佳代价的最佳路径
"""
pbar = tqdm(total=iter_num)
obstacles = map.obstacles_c
best_edge = None
n_of_nodes = 1 # 初始只有起始节点
solution_found = False # 标志是否已经找到解决方案
best_path = {"path": [], "cost": float("inf")} # 具有最小代价的路径及其代价
finish_nodes_of_path = [] # 找到的路径中最后一个节点的ID
iter = 0
while iter < iter_num:
q_new, id_new, id_near, cost_new_near = new_node(G, map, obstacles, step_length, bias)
if q_new is None:
continue
best_edge = (id_new, id_near, cost_new_near)
G.cost[id_new] = cost_new_near # 计算从最近节点到新节点的代价
G.parent[id_new] = id_near
n_of_nodes += 1
kdtree = cKDTree(list(G.vertices.values()))
choose_parent_kdtree(G, q_new, id_new, best_edge, radius, obstacles, kdtree=kdtree)
# choose_parent(G, q_new, id_new, best_edge, radius, obstacles)
G.add_edge(*best_edge)
# 重新连接
rewire_kdtree(G, q_new, id_new, radius, obstacles, kdtree=kdtree)
# rewire(G, q_new, id_new, radius, obstacles)
# 如有必要,删除随机的无子节点节点
if n_of_nodes > max_nodes:
id_removed = forced_removal(G, id_new, best_path["path"])
if id_removed in finish_nodes_of_path:
finish_nodes_of_path.remove(id_removed)
n_of_nodes -= 1
# 检查解决方案
if check_solution(G, q_new, node_radius):
path, _ = find_path(G, id_new, G.id_vertex[G.start])
finish_nodes_of_path.append(id_new)
solution_found = True
# break
# 更新路径的代价
for node in finish_nodes_of_path:
path, cost = find_path(G, node, G.id_vertex[G.start])
if cost < best_path["cost"]:
best_path["path"] = path
best_path["cost"] = cost
pbar.update(1)
iter += 1
if live_update:
plt.pause(0.001)
plt.clf()
plot_graph(G, map.obstacles_c)
if solution_found:
plot_path(G, best_path["path"], "RRT_STAR_FN", best_path["cost"])
if solution_found:
plot_path(G, best_path["path"], "RRT_STAR_FN", best_path["cost"])
pbar.close()
return iter, best_path
(6)函数select_branch实现了RRT_*_FND算法中的选择分支策略,用于删除不再位于路径上的节点及其子节点。它接收当前达到的节点以及先前的路径作为输入,并根据路径更新图中的节点和边。随着节点的移除,函数会实时显示图的变化。最后,它返回更新后的路径。
python
def select_branch(G: Graph, current_node: int, path: list, map: Map) -> list:
"""
RRT_*_FND算法中使用的选择分支算法。删除所有不再位于路径上的节点及其子节点。
:param G: 图
:param current_node: 已到达的节点,将删除其父节点及其子节点
:param path: 先前的路径
:param map: 地图
:return: 新路径
"""
G.start = G.vertices[current_node]
parent = G.parent[current_node]
G.parent[current_node] = None # 将当前节点的父节点设置为None
del G.children[parent][
G.children[parent].index(current_node)] # 从先前父节点的子节点中删除当前节点
plt.figure()
plot_graph(G, map.obstacles_c)
plt.show()
new_path = path[: path.index(current_node) + 1]
stranded_nodes = path[path.index(current_node) + 1:]
for stranded_node in reversed(stranded_nodes):
if stranded_node == current_node: break
remove_children(G, stranded_node, path)
plt.figure()
plot_graph(G, map.obstacles_c)
plt.show()
for node in stranded_nodes:
G.remove_vertex(node)
return new_path
(7)函数remove_children用于从图中删除指定节点的所有子节点,条件是这些子节点不在当前路径上。
python
def remove_children(G: Graph, id_node: int, path: list) -> None:
"""
如果子节点不在路径上,则删除指定节点的所有子节点。
:param G: 图
:param id_node: 要删除其子节点的节点的ID
:param path: 当前路径
:return: 无
"""
nodes_children = G.children[id_node].copy()
for id_child in nodes_children:
if id_child in path: continue
if len(G.children[id_child]) != 0:
remove_children(G, id_child, path)
G.remove_vertex(id_child)
(8)函数valid_path用于验证路径的有效性,并在路径中发现与障碍物碰撞的节点时,删除这些节点及其子节点,然后返回新分离的树的根节点的ID。
python
def valid_path(G: Graph, path: list, map: Map, previous_root: int) -> int:
"""
ValidPath算法,用于删除与障碍物碰撞的节点以及这些节点的子节点
:param G: 图(Graph)
:param path: 先前的路径
:param map: 地图(Map)
:param previous_root: 先前树的根节点的ID
:return: 新分离树的根节点的ID
"""
id_separate_root = previous_root
nodes_in_path = [(id_node, G.vertices[id_node]) for id_node in path]
for id_node, pos_node in reversed(nodes_in_path):
if map.is_occupied_c(pos_node): # 如果路径中的节点与障碍物相撞,则删除它
remove_children(G, id_node, path)
id_remaining_child = G.children[id_node][0] # 此节点仅剩的子节点在路径中
G.remove_vertex(id_node) # 删除与障碍物相撞的节点
G.parent[id_remaining_child] = None # 将分离根节点的父节点设为None
id_separate_root = id_remaining_child
plt.pause(0.001)
plt.clf()
plot_graph(G, map.obstacles_c)
return id_separate_root
(9)函数reconnect实现了重新连接算法,用于在给定的图中尝试建立两个部分之间的连接。它搜索分离树的根节点附近的原始树节点,并尝试通过路径中不存在障碍物的直线连接它们。如果成功建立了连接,则返回新的路径及其成本。
python
def reconnect(G: Graph, path: list, map: Map, step_size: float, id_root: int, id_separate_root: int,
last_node: int) -> tuple:
"""
重新连接算法,用于尝试在树的两个部分之间建立连接。
:param G: 图(Graph)
:param path: 当前已知路径(尽管不通向目标,因为SearchBranch和ValidPath已经执行)
:param map: 地图(Map)
:param step_size: 步长
:param id_root: 原始树的根节点
:param id_separate_root: 分离树的根节点
:param last_node: 原始路径的最后一个节点
:return: 路径和成本的元组
"""
reconnected = False
close_root_nodes = get_near_nodes(G, id_separate_root, step_size, path[-1]) # 与分离树的根节点接近的原始树的节点
pos_separate = G.vertices[id_separate_root]
for node in close_root_nodes:
line = Line(G.vertices[node], pos_separate)
if not through_obstacle(line, map.obstacles_c):
cost = calc_distance(G.vertices[node], pos_separate)
G.add_edge(id_separate_root, node, cost)
reconnected = True
break
plt.figure()
plot_graph(G, map.obstacles_c)
plt.show()
path_and_cost = None
if reconnected:
path_and_cost = find_path(G, last_node, id_root)
plot_graph(G, map.obstacles_c)
plot_path(G, path_and_cost[0], "重新连接后", path_and_cost[1])
return path_and_cost
(10)函数check_for_tree_associativity用于检查节点是否属于以给定根节点为根的树,通过沿着节点的父节点链逐步向上检查,直到到达根节点或者没有父节点时停止。如果最终节点等于给定的根节点,则返回True,表示节点属于该树。
python
def check_for_tree_associativity(G: Graph, root_node: int, node_to_check: int) -> bool:
"""
检查节点是否属于以 root_node 为根的树
:param G: 图(Graph)
:param root_node: 树的根节点
:param node_to_check: 将要检查的节点
:return: 如果节点属于树,则返回 True
"""
node = node_to_check
parent = G.parent[node]
while parent is not None:
node = G.parent[node]
parent = G.parent[node]
return node == root_node
(11)下面的test_select_branch函数是一个测试函数,用于测试select_branch函数的功能。该函数首先创建了一个地图和一个图,然后使用 RRT_star_FN(是 RRT* 的一种扩展或变种实现)算法构建了搜索树,并返回了一条从起点到终点的最佳路径。接下来,它在生成的路径上调用了select_branch函数,并模拟了一些节点的移除和添加障碍物的情况。然后,它调用了valid_path函数,并最终通过reconnect函数重新连接了路径。
python
def test_select_branch():
map_width = 200
map_height = 200
start = (50, 50)
goal = (150, 150)
NODE_RADIUS = 5
step_length = 15
my_map = Map((map_width, map_height), start, goal, NODE_RADIUS)
my_map.generate_obstacles(obstacle_count=45, size=7)
G = Graph(start, goal, map_width, map_height)
iteration, best_path = RRT_star_FN(G, iter_num=500, map=my_map, step_length=step_length, radius=15,
node_radius=NODE_RADIUS, max_nodes=30, bias=0)
plot_graph(G, my_map.obstacles_c)
plt.pause(0.001)
last_node = best_path["path"][0]
id_to_remove = list(reversed(best_path["path"]))[4]
best_path["path"] = select_branch(G, id_to_remove, best_path["path"], my_map)
my_map.add_obstacles([[G.vertices[best_path["path"][6]], 7]])
my_map.add_obstacles([[G.vertices[best_path["path"][7]], 7]])
plt.figure()
plot_graph(G, my_map.obstacles_c)
plt.show()
id_separate_root = valid_path(G, best_path["path"], my_map, id_to_remove)
print(f"RRT_star algorithm stopped at iteration number: {iteration}")
plt.figure()
plot_graph(G, my_map.obstacles_c)
plt.show()
root_node = G.id_vertex[G.start]
ret_value = reconnect(G, best_path["path"], my_map, step_length*5, root_node, id_separate_root, last_node)
if ret_value is None:
regrow(G=G, map=my_map, step_length=step_length, radius=step_length, id_root=root_node,
id_separate_root=id_separate_root, last_node=last_node, bias=0.02)
(12)函数regrow用于重新生成路径的算法,在尝试重新连接根树和分离树的过程中,根据随机节点生成新节点,并尝试将其连接到最近的节点,直到达到最大迭代次数或重新连接成功。通过可选参数调整算法的行为,例如允许的最大边长、重连算法执行的圆形区域的半径以及是否对节点生成进行偏置。
python
def regrow(G: Graph, map: Map, step_length: float, radius: float, id_root: int,
id_separate_root: int, last_node: int, bias: float):
"""
Regrow算法用于重新生成路径,试图重新连接根树和分离树。
:param G: 图(Graph)
:param map: 地图(Map)
:param step_length: 允许的最大边长
:param radius: 重连算法执行的圆形区域的半径
:param id_root: 原始树的根节点的ID
:param id_separate_root: 分离树的根节点的ID
:param last_node: 原始路径的最后一个节点
:param bias: 0-1,朝向目标节点的偏置
:return: None
"""
separate_tree = [node for node in G.vertices if check_for_tree_associativity(G, id_separate_root, node)]
iter_num = 500
iter = 0
reconnected = False
n_of_nodes = len(G.vertices)
obstacles = map.obstacles_c
while iter < iter_num and not reconnected:
q_rand = G.random_node(bias=bias) # 生成随机节点
if map.is_occupied_c(q_rand): continue # 如果生成的随机节点在障碍物上,则继续
q_near, id_near = nearest_node(G, q_rand, obstacles, separate_tree) # 找到距离随机节点最近的节点
if q_near is None or q_rand == q_near: continue #如果最近的节点无效或与随机节点重合,则跳过
q_new = steer(q_rand, q_near, step_length) # 获取新节点的位置
if map.is_occupied_c(q_new): continue
id_new = G.add_vertex(q_new) # 获取新节点的ID
n_of_nodes += 1
cost_new_near = calc_distance(q_new, q_near) # 计算从q_new到q_near的距离
best_edge = (id_new, id_near, cost_new_near)
G.cost[id_new] = cost_new_near # 计算从最近节点到新节点的成本
G.parent[id_new] = id_near
choose_parent(G, q_new, id_new, best_edge, radius, obstacles, separate_tree)
G.add_edge(*best_edge)
iter += 1
for node in separate_tree: # 尝试重新连接根树和分离树
line = Line(q_new, G.vertices[node])
if through_obstacle(line, obstacles): continue
if calc_distance(q_new, G.vertices[node]) < step_length:
# 删除祖先节点(G, node)
reconnect_ancestors(G, node)
G.add_edge(node, id_new, calc_distance(q_new, G.vertices[node]))
reconnected = True
break
plt.pause(0.001)
plt.clf()
plot_graph(G, obstacles)
path = find_path(G, last_node, id_root)
plt.figure()
plot_graph(G, obstacles)
plot_path(G, path[0], "after regrow")
plt.show()
(13)函数 reconnect_ancestors 用于重新连接指定节点的祖先。它从给定节点开始,沿着其父节点向上遍历,逐级重新连接父节点和祖父节点,直到根节点。这样可以将断开的祖先节点重新连接到树中,确保树的完整性。
python
def reconnect_ancestors(G: Graph, id_node: int) -> None:
"""
重新连接指定节点的祖先
:param G: 图(Graph)
:param id_node: 要重新连接祖先的节点ID
:return: 无
"""
node = id_node
parent = G.parent[node]
if parent is None:
return
while G.parent[parent] is not None:
parent_parent = G.parent[parent]
G.parent[parent] = node
G.children[node].append(parent)
del G.children[parent][G.children[parent].index(node)]
node = parent
parent = parent_parent
G.parent[parent] = node
(14)函数 delete_ancestors 用于删除指定节点的所有祖先节点以及它们的子节点。它从给定节点开始,沿着其父节点向上遍历,逐级删除父节点和祖父节点,同时删除它们的子节点,直到根节点。
python
def delete_ancestors(G: Graph, id_node: int) -> None:
"""
删除指定节点的所有祖先节点及其子节点
:param G: 图(Graph)
:param id_node: 要删除祖先的节点ID
:return: 无
"""
node = id_node
parent = G.parent[node]
while parent is not None:
node = parent
remove_children(G, node, [])
parent = G.parent[node]
(15)函数intersection_circle用于检查直线段和圆之间是否相交。它接受一条直线和一个圆作为输入,并返回一个布尔值,指示它们是否相交。
python
def intersection_circle(line: Line, circle: list) -> bool:
"""
检查直线段和圆之间是否相交。
:param line: 直线
:param circle: 圆
:return: 如果相交返回True,否则返回False
"""
p1, p2 = line.p1, line.p2
if line.dir == float("inf"): # 当直线为垂直线时,计算delta无效
if abs(circle[0][0] - line.x_pos) <= circle[1]:
return True
else:
return False
delta, a, b = calc_delta(line, circle)
if delta < 0: # 无实数解
return False
ip1, ip2 = delta_solutions(delta, a, b) # 相交点1的x坐标;相交点2的x坐标
ip1[1] = line.calculate_y(ip1[0])
ip2[1] = line.calculate_y(ip2[0])
if is_between(ip1, p1, p2) or is_between(ip2, p1, p2):
return True
return False
(16)函数through_obstacle用于检查直线是否穿过障碍物,它接受一条直线和障碍物列表作为输入,并返回一个布尔值,指示直线是否与任何障碍物相交。
python
def through_obstacle(line: Line, obstacles: list) -> bool: # 目前仅适用于圆形障碍物
"""
检查直线是否穿过障碍物。
:param line: 要检查的直线
:param obstacles: 障碍物列表
:return: 如果与障碍物相撞返回True,否则返回False
"""
for obstacle in obstacles:
if intersection_circle(line, obstacle):
return True
return False
(17)函数delta_solutions用于计算二次方程的解,根据给定的 delta 和二次方程的系数,返回两个解。
python
def delta_solutions(delta: float, a: float, b: float) -> tuple:
"""
根据 delta 和二次方程的系数计算解。
:param delta: delta
:param a: x^2 的系数
:param b: x 的系数
:return: 两个解,x1 和 x2
"""
x1 = [(-b + delta ** 0.5) / (2 * a), None]
x2 = [(-b - delta ** 0.5) / (2 * a), None]
return x1, x2
(18)函数 calc_delta 用于计算直线和圆之间的交点的 delta 值,以及方程的系数,并返回这些值的元组。
python
def calc_delta(line: Line, circle: list) -> tuple:
"""
计算直线和圆之间的交点的 delta 值,以及方程的系数,并返回这些值的元组。
:param line: 直线
:param circle: 圆
:return: 包含判别式及方程系数的三元组 (delta, a, b),其中:
delta 为判别式(用于判断交点数量),
a 为二次方程中 x² 项的系数,
b 为二次方程中 x 项的系数
"""
x0 = circle[0][0] # 圆的中心 x 坐标
y0 = circle[0][1] # 圆的中心 y 坐标
r = circle[1] # 圆的半径
a = (1 + line.dir ** 2)
b = 2 * (-x0 + line.dir * line.const_term - line.dir * y0)
c = -r ** 2 + x0 ** 2 + y0 ** 2 - 2 * y0 * line.const_term + line.const_term ** 2
delta = b ** 2 - 4 * a * c
return delta, a, b
(19)函数 is_between 用于检查点 pb 是否位于线段 (p1, p2) 上。
python
def is_between(pb, p1, p2):
"""
检查点 pb 是否位于线段 (p1, p2) 上。
:param pb: 要检查的点
:param p1: 线段上的第一个点
:param p2: 线段上的第二个点
:return: 如果 pb 在 p1 和 p2 之间则返回 True,否则返回 False
"""
check1 = pb[0] > min(p1[0], p2[0])
check2 = pb[0] < max(p1[0], p2[0])
return check1 and check2
(20)函数plot_graph用于绘制图形及其上的节点、边和障碍物。给定图形对象和障碍物列表,它会绘制图形的节点、起点、终点以及图形之间的边,并在图形上添加圆形障碍物。
python
def plot_graph(graph: Graph, obstacles: list):
"""
绘制图形和障碍物。
:param graph: 要绘制的图形
:param obstacles: 障碍物列表
"""
xes = [pos[0] for id, pos in graph.vertices.items()]
yes = [pos[1] for id, pos in graph.vertices.items()]
plt.scatter(xes, yes, c='gray') # 绘制节点
plt.scatter(graph.start[0], graph.start[1], c='#49ab1f', s=50) # 绘制起点
plt.scatter(graph.goal[0], graph.goal[1], c='red', s=50) # 绘制目标点
edges = [(graph.vertices[id_ver], graph.vertices[child]) for pos_ver, id_ver in graph.id_vertex.items()
for child in graph.children[id_ver]]
for edge in edges:
plt.plot([edge[0][0], edge[1][0]], [edge[0][1], edge[1][1]], c='black', alpha=0.5) # 绘制边
# 绘制障碍物
plt.gca().set_aspect('equal', adjustable='box')
for obstacle in obstacles:
circle = plt.Circle(obstacle[0], obstacle[1], color='black')
plt.gca().add_patch(circle)
plt.xlim(0, graph.width)
plt.ylim(0, graph.height)
(21)函数nearest_node_kdtree用于查找最接近输入节点的节点,并检查是否穿过障碍物。它采用了 KD 树数据结构来快速搜索最近的节点。给定图形对象、节点位置、障碍物列表以及可选的分离树节点列表和 KD 树,函数返回与输入节点最接近且不穿过障碍物的节点的位置和ID。
python
def nearest_node_kdtree(G: Graph, vertex: tuple, obstacles: list, separate_tree_nodes: list = (), kdtree: cKDTree = None):
"""
检查距离输入节点最近的节点,检查是否穿过障碍物。
:param G: 图(Graph)
:param vertex: 要查找其邻居的节点的位置
:param obstacles: 障碍物列表
:param separate_tree_nodes: 分离树中的节点列表
:param kdtree: 除新节点外所有节点创建的KD树
:return: new_vertex, new_id
"""
try:
id = G.id_vertex[vertex]
return np.array(vertex), id
except KeyError:
closest_id = None
closest_pos = None
nn = 1
while True:
d, i = kdtree.query(vertex, k=nn, workers=-1)
if nn == 1:
closest_pos = kdtree.data[i]
else:
closest_pos = kdtree.data[i[-1]]
closest_id = G.id_vertex[closest_pos[0], closest_pos[1]]
line = Line(vertex, closest_pos)
nn += 1
if not through_obstacle(line, obstacles):
break
elif nn > len(G.vertices):
closest_pos = np.array(vertex)
closest_id = None
break
return closest_pos, closest_id
(22)函数nearest_node用于查找给定节点最近的图中的节点,同时检查是否穿过障碍物。如果输入节点已经是图中的节点之一,则直接返回该节点的位置和ID。否则,它会遍历图中的每个节点,排除分离树中的节点,并计算到输入节点的距离。然后,它将返回距离最近的节点的位置和ID。
python
def nearest_node(G: Graph, vertex: tuple, obstacles: list, separate_tree_nodes: list = ()):
"""
检查距离输入节点最近的节点,检查是否穿过障碍物。
:param G: 图(Graph)
:param vertex: 要查找其邻居的节点的位置
:param obstacles: 障碍物列表
:param separate_tree_nodes: 分离树中的节点列表
:return: new_vertex, new_id
"""
try:
id = G.id_vertex[vertex]
return vertex, id
except KeyError:
min_distance = float("inf")
new_id = None
new_vertex = None
for ver_id, ver in G.vertices.items():
if ver_id in separate_tree_nodes: continue
line = Line(ver, vertex)
if through_obstacle(line, obstacles): continue
distance = calc_distance(ver, vertex)
if distance < min_distance:
min_distance = distance
new_id = ver_id
new_vertex = ver
return new_vertex, new_id
(23)函数steer用于确定从一个给定点(父节点)到另一个给定点(目标节点)的方向,并返回新节点的位置,该位置位于给定长度的边上。如果两个点之间的距离大于最大允许长度,则返回的新节点将在指定方向上与父节点相距最大长度;否则,返回目标节点的位置。
python
def steer(to_vertex: tuple, from_vertex: tuple, max_length: float) -> tuple:
"""
返回新节点的位置。从顶点到目标顶点的方向,给定长度。
:param to_vertex: 标记方向的顶点的位置
:param from_vertex: 父顶点的位置
:param max_length: 两个节点之间允许的最大边长
:return: 新节点的位置
"""
distance = calc_distance(to_vertex, from_vertex)
x_vect_norm = (to_vertex[0] - from_vertex[0]) / distance
y_vect_norm = (to_vertex[1] - from_vertex[1]) / distance
x_pos = from_vertex[0] + x_vect_norm * max_length
y_pos = from_vertex[1] + y_vect_norm * max_length
if distance > max_length:
return x_pos, y_pos
return to_vertex
(24)函数check_solution用于检查是否已找到解决方案,即检查新节点是否足够接近目标节点,以判断是否达到了目标。
python
def check_solution(G: Graph, q_new: tuple, node_radius: int) -> bool:
"""
检查是否已找到解决方案(节点是否足够接近目标节点)。
:param G: 图(Graph)
:param q_new: 要检查的节点
:param node_radius: 节点的半径
:return: 如果找到解决方案,则返回 True,否则返回 False
"""
dist_to_goal = calc_distance(q_new, G.goal) # 检查是否到达目标点
if dist_to_goal < 2 * node_radius:
return True
return False
(25)函数plot_path用于绘制路径,参数分别一个图形对象 G,一个路径的节点ID列表 path,可选的标题字符串 title 和路径的成本 cost。函数plot_path通过连接路径中相邻节点的线段来绘制路径,并在标题中显示路径的成本。
python
def plot_path(G: Graph, path: list, title: str = "", cost: float = float("inf")):
"""
绘制路径。
:param G: 图(Graph)
:param path: 路径中节点的ID列表
:param title: 图的标题
:param cost: 路径的成本
"""
prev_node = G.goal
for point in path:
plt.plot((prev_node[0], G.vertices[point][0]), (prev_node[1], G.vertices[point][1]), c='#057af7', linewidth=2)
prev_node = G.vertices[point]
plt.title(title + f" cost: {round(cost, 2)}")
(26)函数find_path用于从图中的起始节点找到路径,直到到达指定的根节点。它返回从起始节点到根节点的路径列表以及路径的总成本。
python
def find_path(G: Graph, from_node: int, root_node: int) -> tuple:
"""
从起始节点找到路径。
:param G: 图
:param from_node: 起始节点
:param root_node: 根节点
:return: 路径,成本
"""
path = []
node = from_node
cost = 0
try:
while node != root_node:
path.append(node)
cost += G.cost[node]
node = G.parent[node]
path.append(root_node)
except Exception:
pass
return path, cost
(27)函数forced_removal用于从图中删除一个随机的无子节点的节点,它分别接受一个图对象(Graph)、不会被删除的节点的ID以及路径中的节点列表作为输入参数,并返回被删除的节点的ID。
python
def forced_removal(G: Graph, id_new: int, path: list) -> int:
"""
从图中删除一个随机的无子节点的节点。
:param G: 图(Graph)
:param id_new: 不会被删除的节点的ID
:param path: 节点列表中的路径
:return: 被删除的节点的ID
"""
id_last_in_path = -1
if path:
id_last_in_path = path[0]
childless_nodes = [node for node, children in G.children.items() if len(children) == 0] # and node != id_new
id_ver = random.choice(childless_nodes)
while id_ver == id_new or id_ver == id_last_in_path:
id_ver = random.choice(childless_nodes)
G.remove_vertex(id_ver)
return id_ver
(28)函数choose_parent_kdtree用于在给定的搜索半径内查找最优的父节点,以使从起始节点到新节点的成本最小化。它通过在KD树中查找半径内的所有节点,并计算它们到新节点的距离来实现此目的。然后,它检查每个候选节点是否通过障碍物,并比较其到新节点的成本是否比当前最佳边的成本更低。
python
def choose_parent_kdtree(G: Graph, q_new: tuple, id_new: int, best_edge: tuple,
radius: float, obstacles: list, separate_tree_nodes: list = (), kdtree: cKDTree = None) -> tuple:
"""
在搜索半径范围内,选择一个能以最小代价连接到新节点的父节点
:param G: 图(Graph)
:param q_new: 新节点的位置
:param id_new: 新节点的ID
:param best_edge: 到目前为止最佳边
:param radius: 搜索区域的半径
:param obstacles: 障碍物列表
:param separate_tree_nodes: 在分离树中的节点ID列表
:param kdtree: 除新节点之外的所有节点创建的kdtree
:return: 最佳节点的ID
"""
i = kdtree.query_ball_point(q_new, r=radius, workers=-1)
in_radius_pos = kdtree.data[i] #通过 KDTree 搜索得到的半径邻域内的点的坐标(位置)
#根据坐标 (pos[0], pos[1]),从 G.id_vertex 映射获取半径邻域内点对应的顶点 ID
in_radius_ids = [G.id_vertex[pos[0], pos[1]] for pos in in_radius_pos]
costs = np.linalg.norm(in_radius_pos - q_new, axis=1)
# new_costs = [G.get_cost(id_in_radius) + costs[] for id_in_radius in in_radius_ids]
for id_ver, vertex, cost in zip(in_radius_ids, in_radius_pos, costs):
line = Line(vertex, q_new)
if through_obstacle(line, obstacles): continue
if G.get_cost(id_new) > G.get_cost(id_ver) + cost:
G.cost[id_new] = cost
best_edge = (id_new, id_ver, cost)
return best_edge
(29)函数choose_parent用于在给定的搜索半径内查找最优的父节点,以使从起始节点到新节点的成本最小化。它通过迭代图中的所有节点来实现此目的,并计算每个节点到新节点的距离。然后,它检查每个候选节点是否通过障碍物,并比较其到新节点的成本是否比当前最佳边的成本更低。
python
def choose_parent(G: Graph, q_new: tuple, id_new: int, best_edge: tuple,
radius: float, obstacles: list, separate_tree_nodes: list = ()) -> tuple:
"""
寻找到起始节点成本最优的节点。
:param G: 图(Graph)
:param q_new: 新节点的位置
:param id_new: 新节点的ID
:param best_edge: 到目前为止最优的边
:param radius: 搜索区域的半径
:param obstacles: 障碍物列表
:param separate_tree_nodes: 分离树中节点的ID列表
:return: 最优节点的ID
"""
for id_ver, vertex in G.vertices.items(): # 遍历所有顶点
if id_ver == id_new: continue
distance_new_vert = calc_distance(q_new, vertex) # 计算新节点到顶点节点的距离
if round(distance_new_vert, 3) > radius: continue # 如果距离大于搜索半径,则继续
line = Line(vertex, q_new) # 创建从新节点到顶点的直线对象
if through_obstacle(line, obstacles): continue # 如果直线穿过障碍物,则继续
if G.get_cost(id_new) > G.get_cost(id_ver) + distance_new_vert: # 如果从新节点到顶点的成本小于当前成本,则重置顶点到新节点的成本
G.cost[id_new] = distance_new_vert
best_edge = (id_new, id_ver, distance_new_vert)
return best_edge
(30)函数rewire_kdtree实现了RRT_STAR算法的重连过程,用于更新图中节点的连接关系和成本。它根据给定的新节点位置和搜索半径,在搜索范围内查找节点,并检查是否可以通过将这些节点重新连接到新节点来降低其成本。如果发现可以通过重新连接来降低节点的成本,则更新图中相应节点的父节点和子节点,并更新其成本。
python
def rewire_kdtree(G: Graph, q_new: tuple, id_new: int, radius: float, obstacles: list, kdtree: cKDTree = None):
"""
RRT_STAR算法的重连过程。
:param G: 图(Graph)
:param q_new: 新节点的位置
:param id_new: 新节点的ID
:param radius: 搜索区域的半径
:param obstacles: 障碍物列表
:param kdtree: 除新节点外的所有节点构建的kdtree
"""
i = kdtree.query_ball_point(q_new, r=radius, workers=-1)
# 计算新节点与搜索半径内节点之间的距离
in_radius_pos = kdtree.data[i] # 搜索半径内的点的位置
in_radius_ids = [G.id_vertex[pos[0], pos[1]] for pos in in_radius_pos] # 搜索半径内的点的ID列表
costs = np.linalg.norm(in_radius_pos - q_new, axis=1)
# new_costs = [G.get_cost(id_in_radius) + costs[] for id_in_radius in in_radius_ids]
for id_ver, vertex, cost in zip(in_radius_ids, in_radius_pos, costs):
line = Line(vertex, q_new)
if through_obstacle(line, obstacles): continue
if G.get_cost(id_ver) > G.get_cost(id_new) + cost:
parent = G.parent[id_ver] # 重连节点的父节点
del G.children[parent][G.children[parent].index(id_ver)] # 从其父节点的子节点列表中删除重连节点
G.parent[id_ver] = id_new # 将重连节点的父节点设置为新节点
G.children[id_new].append(id_ver) # 将重连节点添加到新节点的子节点列表中
G.cost[id_ver] = cost
(31)函数rewire实现了RRT_STAR算法的重连过程,用于更新图中节点的连接关系和成本。它遍历所有的节点,排除起始节点和新节点,并在指定的搜索半径内查找节点。对于在搜索范围内的每个节点,它计算新节点与该节点之间的距离,并检查是否可以通过将这些节点重新连接到新节点来降低其成本。如果发现可以通过重新连接来降低节点的成本,则更新图中相应节点的父节点和子节点,并更新其成本。
python
def rewire(G: Graph, q_new: tuple, id_new: int, radius: float, obstacles: list):
"""
RRT_STAR算法的重连过程。
:param G: 图(Graph)
:param q_new: 新节点的位置
:param id_new: 新节点的ID
:param radius: 搜索区域的半径
:param obstacles: 障碍物列表
"""
for id_ver, vertex in G.vertices.items():
if id_ver == G.id_vertex[G.start]: continue
if id_ver == id_new: continue
distance_new_vert = calc_distance(q_new, vertex)
if distance_new_vert > radius: continue
line = Line(vertex, q_new)
if through_obstacle(line, obstacles): continue
if G.get_cost(id_ver) > G.get_cost(id_new) + distance_new_vert:
parent = G.parent[id_ver] # 重连节点的父节点
del G.children[parent][G.children[parent].index(id_ver)] # 从父节点的子节点中删除重连的节点
G.parent[id_ver] = id_new # 将重连节点的父节点设置为新节点
G.children[id_new].append(id_ver) # 将重连节点添加到新节点的子节点中
G.cost[id_ver] = distance_new_vert
(33)函数get_distance_dict用于计算给定节点与图中其他节点之间的距离,并返回以节点 ID 为
python
键,距离为值的字典
def get_distance_dict(G: Graph, node_to_check: int, indeces_to_check: list[int]) -> dict:
pos = G.vertices[node_to_check]
tree_points_list = [vertex for id_ver, vertex in G.vertices.items()]
tree_points = np.array(tree_points_list)
new_point = np.array(pos).reshape(-1, 2)
x2 = np.sum(tree_points ** 2, axis=1).reshape(-1, 1)
y2 = np.sum(new_point ** 2, axis=1).reshape(-1, 1)
xy = 2 * np.matmul(tree_points, new_point.T)
dists = np.sqrt(x2 - xy + y2.T)
distances = {id_ver: id_and_cost[0] for id_ver, id_and_cost in zip(G.vertices, dists)}
return distances
(34)函数calc_distance用于计算两点之间的欧几里得距离。它接受表示两点坐标的两个元组作为输入,并返回两点之间的距离作为浮点数值。
python
def calc_distance(p1: tuple, p2: tuple) -> float:
"""
计算两个点之间的距离。
:param p1: 点 1
:param p2: 点 2
:return: 点之间的距离
"""
return np.linalg.norm(np.array(p1) - np.array(p2))