目录
搜索步骤
A*路径规划是一种广度优先搜索算法,需要在栅格地图上进行搜索。其主要搜索步骤如下:
- 得到栅格地图,确定起点和终点位置;
- 计算起点周围点的代价 ,放到开集合中,可以周围4个点(四连通),也可以8个点(八连通);
- 从开集合 中提取代价 最小的点,将其放到闭集合 中,并计算这个点周围点的代价 ,此时要判断周围点是否已经在开集合 或闭集合 中,并做特殊处理(后面详细说明),若不在两个集合中,则将其放到开集合中;
- 循环执行步骤3,直到弹出的点是终点。
关键点
开集合和闭集合
- 开集合:待走的位置。相当于候选点集合,每到一个点,将周围点放入候选集合中,往后有可能往候选点去走。
- 闭集合:已经走过的位置,不会再被考虑。
复杂度优化
-
上述第三步中 "从开集合中提取代价最小的点 " ,常规方法会循环遍历开集合,找到代价值最小的索引来提取,但循环结构太耗费时间。
改进:用小顶堆 来解决(一种特殊的数据结构,若不了解可自行百度),小顶堆可以一直保持索引0的点为代价最小的点,它的插入和自更新时间复杂度为O(nlogn),相对循环更快。pythonimport heapq # 创建一个列表用来存储 open_set = [15] # 往列表中压入一个新的数,并自行调整小顶堆的顺序 heapq.heappush(open_set, 10) heapq.heappush(open_set, 20) heapq.heappush(open_set, 30) heapq.heappush(open_set, 14) # 弹出堆顶,永远是最小的数 a = heapq.heappop(open_set) print(a)
-
上述第三步中 "判断周围点是否已经在开集合或闭集合中 ",常规方法会遍历开集合或闭集合,判断该点是否在其中,同样循环太耗费时间。
改进:在创建地图的时候,地图上的每一个点都是一个对象,对应的类中新创建一个属性来记录当前点是否在开集合或闭集合中。用空间换时间,每次判断可直接索引查找,复杂度降维O(1)。python# 地图上的每一个点都是一个Point对象,用于记录该点的类别、代价等信息 class Point: def __init__(self, x=0, y=0): self.x = x self.y = y self.val = 0 # 0代表可通行,1代表障碍物 self.cost_g = 0 self.cost_h = 0 self.cost_f = 0 self.parent = None # 父节点 self.is_open = 0 # 0:不在开集合 1:在开集合 -1:在闭集合 # 用于heapq小顶堆的比较 def __lt__(self, other): return self.cost_f < other.cost_f
注:heapq操作的列表里的元素默认都是数值,可以直接比较大小的,但现在存储的是类对象,所以需要额外写__lt__方法,用于比较大小。
注:父节点,把当前点设为周围扩散点的父节点,用于标识该点来源于哪里,用于回溯完整路径。
代价
代价,即走到某个位置需要花费的成本。代价越小,这条路径就越好。
每个点的代价分为3种:
- g 代价:当前点到起点的代价,等于当前点与父节点之间的距离+父节点的 g 代价;
- h 代价:当前点到终点的代价,等于当前点与终点之间的距离;
- f 代价:当前点总代价,等于g 代价 + h 代价;
父节点替换
当前点在扩散时需要将周围的8个点都放入开集合中(假设8连通),这也意味着某一个点可能由周围8个点扩散过来。
那这个点到底归谁呢(决定它由谁扩散而来),当然是谁代价小就选谁。
对于一个点,它的h值是不变的,永远是与终点的距离,但是它的g值是不一样的,g值跟走过的路径有关,走的越长代价越大,所以选择走的最少的点作为自己的父节点。
所以具体实现步骤如下:
- 对于扩散到的点,首先判断是否有效(越出边界或障碍物),以及是否已经在闭集合中(已经走过的就不用再走一遍了)
- 若通过上一步,先将来源点(谁扩散来的)作为父节点,去计算g值(g值需要依靠父节点g值计算)
- 如果这个点不在开集合中,都好说,直接加入开集合就好
- 如果在开集合中,说明之前有别的点扩散到这过,比较这两个来源点到这的g值大小,选择小的那个作为自己的父节点。
下图为增加这个操作和不增加的对比,左图如果不增加,它会一直保留第一次扩散过来的点作为自己的父节点,右图切换父节点后就会选择更近的路。
该部分代码如下
python
def diffusion_point(self, x, y, parent):
# 无效点或者在闭集合中,跳过
if not self.is_valid_point(x, y) or self.map.map[x][y].is_open == -1:
return
p = self.map.map[x][y]
pre_parent = p.parent
p.parent = parent
# 先计算出当前点的总代价
cost_g = self.g_cost(p)
cost_h = self.h_cost(p)
cost_f = cost_g + cost_h
# 如果在开集合中,判断当前点和开集合中哪个点代价小,换成小的,相同x,y的点h值相同,g值不一定相同
if p.is_open == 1:
if cost_f < p.cost_f:
# 如果从当前parent遍历过来的代价更小,替换成当前的代价和父节点
p.cost_g, p.cost_h, p.cost_f = cost_g, cost_h, cost_f
else:
# 如果从之前父节点遍历过来的代价更小,保持之前的代价和父节点
p.parent = pre_parent
else:
# 如果不在开集合中,说明之间没遍历过,直接加到开集合里就好
p.cost_g, p.cost_h, p.cost_f = cost_g, cost_h, cost_f
heapq.heappush(self.open_set, p)
p.is_open = 1
距离
若使用不同距离,可在完整代码中替换对应部分
-
欧式距离(欧几里得距离):两个点之间的直线距离
pythondef g_cost(self, p): ''' 计算 g 代价,当前点与父节点的距离 + 父节点的 g 代价(欧氏距离) :param p: 当前扩散的节点 :return: p 的 g 代价 ''' x_dis = abs(p.parent.x - p.x) y_dis = abs(p.parent.y - p.y) return np.sqrt(x_dis ** 2 + y_dis ** 2) + p.parent.cost_g def h_cost(self, p): ''' 计算 h 代价,当前点与终点之间的距离(欧氏距离) :param p: 当前扩散的节点 :return: p 的 h 代价 ''' x_dis = abs(self.end_point.x - p.x) y_dis = abs(self.end_point.y - p.y) return np.sqrt(x_dis ** 2 + y_dis ** 2)
-
曼哈顿距离:计算x轴和y轴之差的绝对值,适合在四连通时使用(只上下左右走),计算简单
pythondef g_cost(self, p): ''' 计算 g 代价,当前点与父节点的距离 + 父节点的 g 代价(曼哈顿距离) :param p: 当前扩散的节点 :return: p 的 g 代价 ''' x_dis = abs(p.parent.x - p.x) y_dis = abs(p.parent.y - p.y) return x_dis + y_dis + p.parent.cost_g def h_cost(self, p): ''' 计算 h 代价,当前点与终点之间的距离(曼哈顿距离) :param p: 当前扩散的节点 :return: p 的 h 代价 ''' x_dis = abs(self.end_point.x - p.x) y_dis = abs(self.end_point.y - p.y) return x_dis + y_dis
-
对角距离(待续)
地图设置
创建一个地图类,具体的地图用二维列表存储,每个元素都是Point对象。set_obstacle方法可手动设置障碍物。
python
class Map:
def __init__(self, map_size):
self.map_size = map_size
self.width = map_size[0]
self.height = map_size[1]
self.map = [[Point(x, y) for y in range(self.map_size[1])] for x in range(self.map_size[0])]
# 手动设置障碍物,可多次调用设置地图
# 由于地图方向不同,这里的topleft并不总是左上角,topleft代表x和y全都较小的点
def set_obstacle(self, topleft, width, height):
for x in range(topleft[0], topleft[0] + width):
for y in range(topleft[1], topleft[1] + height):
self.map[x][y].val = 1
完整代码
python
#!/usr/bin/env python
# -*- coding:utf-8 -*-
# @Author : Cao Zejun
# @Time : 2024/4/7 17:36
# @File : Astar_blog.py
# @Software : Pycharm
# @description : 用于CSDN上的Astar算法演示
import time
import numpy as np
from matplotlib.patches import Rectangle
import matplotlib.pyplot as plt
import heapq
# 地图上的每一个点都是一个Point对象,用于记录该点的类别、代价等信息
class Point:
def __init__(self, x=0, y=0):
self.x = x
self.y = y
self.val = 0 # 0代表可通行,1代表障碍物
self.cost_g = 0 # 三个代价
self.cost_h = 0
self.cost_f = 0
self.parent = None # 父节点
self.is_open = 0 # 0:不在开集合 1:在开集合 -1:在闭集合
# 用于heapq小顶堆的比较
def __lt__(self, other):
return self.cost_f < other.cost_f
class Map:
def __init__(self, map_size):
self.map_size = map_size
self.width = map_size[0]
self.height = map_size[1]
self.map = [[Point(x, y) for y in range(self.map_size[1])] for x in range(self.map_size[0])]
# 手动设置障碍物,可多次调用设置地图
# 由于地图方向不同,这里的topleft并不总是左上角,topleft代表x和y全都较小的点
def set_obstacle(self, topleft, width, height):
for x in range(topleft[0], topleft[0] + width):
for y in range(topleft[1], topleft[1] + height):
self.map[x][y].val = 1
class AStar:
def __init__(self, map, start_point, end_point, connect_num=8, ax=None, print_diffusion_point=False):
self.map: Map = map
self.start_point = start_point
self.end_point = end_point
self.open_set = [self.start_point] # 开集合,先放入起点,从起点开始遍历
self.start_point.is_open = 1 #
self.connect_num = connect_num # 连通数,目前支持4连通或8连通
self.diffuse_dir = [(1, 0), (-1, 0), (0, 1), (0, -1), (1, 1), (-1, 1), (1, -1), (-1, -1)] # 遍历的8个方向,只需取出元组,加到x和y上就可以
def g_cost(self, p):
'''
计算 g 代价,当前点与父节点的距离 + 父节点的 g 代价(欧氏距离)
:param p: 当前扩散的节点
:return: p 的 g 代价
'''
x_dis = abs(p.parent.x - p.x)
y_dis = abs(p.parent.y - p.y)
return np.sqrt(x_dis ** 2 + y_dis ** 2) + p.parent.cost_g
def h_cost(self, p):
'''
计算 h 代价,当前点与终点之间的距离(欧氏距离)
:param p: 当前扩散的节点
:return: p 的 h 代价
'''
x_dis = abs(self.end_point.x - p.x)
y_dis = abs(self.end_point.y - p.y)
return np.sqrt(x_dis ** 2 + y_dis ** 2)
def is_valid_point(self, x, y):
# 无效点:超出地图边界或为障碍物
if x < 0 or x >= self.map.width:
return False
if y < 0 or y >= self.map.height:
return False
return self.map.map[x][y].val == 0
def search(self):
self.start_time = time.time() # 用于记录搜索时间
p = self.start_point
# p 为当前遍历节点,等于终点停下
while not (p == self.end_point):
# 弹出代价最小的开集合点,若开集合为空,说明没有路径
try:
p = heapq.heappop(self.open_set)
except:
raise 'No path found, algorithm failed!!!'
p.is_open = -1
# 遍历周围点
for i in range(self.connect_num):
dir_x, dir_y = self.diffuse_dir[i]
self.diffusion_point(p.x + dir_x, p.y + dir_y, p)
return self.build_path(p) # p = self.end_point
def diffusion_point(self, x, y, parent):
# 无效点或者在闭集合中,跳过
if not self.is_valid_point(x, y) or self.map.map[x][y].is_open == -1:
return
p = self.map.map[x][y]
pre_parent = p.parent
p.parent = parent
# 先计算出当前点的总代价
cost_g = self.g_cost(p)
cost_h = self.h_cost(p)
cost_f = cost_g + cost_h
# 如果在开集合中,判断当前点和开集合中哪个点代价小,换成小的,相同x,y的点h值相同,g值不一定相同
if p.is_open == 1:
if cost_f < p.cost_f:
# 如果从当前parent遍历过来的代价更小,替换成当前的代价和父节点
p.cost_g, p.cost_h, p.cost_f = cost_g, cost_h, cost_f
else:
# 如果从之前父节点遍历过来的代价更小,保持之前的代价和父节点
p.parent = pre_parent
else:
# 如果不在开集合中,说明之间没遍历过,直接加到开集合里就好
p.cost_g, p.cost_h, p.cost_f = cost_g, cost_h, cost_f
heapq.heappush(self.open_set, p)
p.is_open = 1
def build_path(self, p):
print('search time: ', time.time() - self.start_time, ' seconds')
# 回溯完整路径
path = []
while p != self.start_point:
path.append(p)
p = p.parent
print('search time: ', time.time() - self.start_time, ' seconds')
# 打印开集合、闭集合的数量
print('open set count: ', len(self.open_set))
close_count = 0
for x in range(self.map.width):
for y in range(self.map.height):
close_count += 1 if self.map.map[x][y] == -1 else 0
print('close set count: ', close_count)
print('total count: ', close_count + len(self.open_set))
# path = path[::-1] # path为终点到起点的顺序,可使用该语句翻转
return path
if __name__ == '__main__':
map = Map((50, 50))
# 用于显示plt图
ax = plt.gca()
ax.set_xlim([0, map.width])
ax.set_ylim([0, map.height])
plt.tight_layout()
# 设置障碍物
map.set_obstacle([10, 27], 20, 4)
# 将障碍物显示到plt上
ax.add_patch(Rectangle([10, 27], width=20, height=4, color='gray'))
# 设置起始点和终点,并创建astar对象
start_point = map.map[5][5]
end_point = map.map[20][40]
astar = AStar(map, start_point, end_point)
path = astar.search()
# 搜索之后打印,功能拓展时不影响搜索时间
# 打印开集合点和闭集合点,可视化扩散点数量
for x in range(map.width):
for y in range(map.height):
if map.map[x][y].is_open == -1:
ax.add_patch(Rectangle([x, y], width=1, height=1, color='green'))
if map.map[x][y].is_open == 1:
ax.add_patch(Rectangle([x, y], width=1, height=1, color='blue'))
# 可视化起点到终点完整路径
for p in path:
ax.add_patch(Rectangle([p.x, p.y], width=1, height=1, color='red'))
# plt.savefig('./output/tmp.jpg') # 可选择将其保存为本地图片
plt.show()
备注
若有错误,可在评论区指出,我会及时修改