逐行讲解python实现A*路径规划

目录

搜索步骤

A*路径规划是一种广度优先搜索算法,需要在栅格地图上进行搜索。其主要搜索步骤如下:

  1. 得到栅格地图,确定起点和终点位置;
  2. 计算起点周围点的代价 ,放到开集合中,可以周围4个点(四连通),也可以8个点(八连通);
  3. 开集合 中提取代价 最小的点,将其放到闭集合 中,并计算这个点周围点的代价 ,此时要判断周围点是否已经在开集合闭集合 中,并做特殊处理(后面详细说明),若不在两个集合中,则将其放到开集合中;
  4. 循环执行步骤3,直到弹出的点是终点。

关键点

开集合和闭集合

  • 开集合:待走的位置。相当于候选点集合,每到一个点,将周围点放入候选集合中,往后有可能往候选点去走。
  • 闭集合:已经走过的位置,不会再被考虑。

复杂度优化

  1. 上述第三步中 "从开集合中提取代价最小的点 " ,常规方法会循环遍历开集合,找到代价值最小的索引来提取,但循环结构太耗费时间。
    改进:用小顶堆 来解决(一种特殊的数据结构,若不了解可自行百度),小顶堆可以一直保持索引0的点为代价最小的点,它的插入和自更新时间复杂度为O(nlogn),相对循环更快。

    python 复制代码
    import heapq
    
    # 创建一个列表用来存储
    open_set = [15]
    # 往列表中压入一个新的数,并自行调整小顶堆的顺序
    heapq.heappush(open_set, 10)
    heapq.heappush(open_set, 20)
    heapq.heappush(open_set, 30)
    heapq.heappush(open_set, 14)
    # 弹出堆顶,永远是最小的数
    a = heapq.heappop(open_set)
    print(a)
  2. 上述第三步中 "判断周围点是否已经在开集合或闭集合中 ",常规方法会遍历开集合或闭集合,判断该点是否在其中,同样循环太耗费时间。
    改进:在创建地图的时候,地图上的每一个点都是一个对象,对应的类中新创建一个属性来记录当前点是否在开集合或闭集合中。用空间换时间,每次判断可直接索引查找,复杂度降维O(1)。

    python 复制代码
    # 地图上的每一个点都是一个Point对象,用于记录该点的类别、代价等信息
    class Point:
        def __init__(self, x=0, y=0):
            self.x = x
            self.y = y
            self.val = 0  # 0代表可通行,1代表障碍物
            self.cost_g = 0
            self.cost_h = 0
            self.cost_f = 0
            self.parent = None  # 父节点
            self.is_open = 0  # 0:不在开集合  1:在开集合  -1:在闭集合
    
        # 用于heapq小顶堆的比较
        def __lt__(self, other):
            return self.cost_f < other.cost_f

    注:heapq操作的列表里的元素默认都是数值,可以直接比较大小的,但现在存储的是类对象,所以需要额外写__lt__方法,用于比较大小。
    注:父节点,把当前点设为周围扩散点的父节点,用于标识该点来源于哪里,用于回溯完整路径。

代价

代价,即走到某个位置需要花费的成本。代价越小,这条路径就越好。

每个点的代价分为3种:

  1. g 代价:当前点到起点的代价,等于当前点与父节点之间的距离+父节点的 g 代价;
  2. h 代价:当前点到终点的代价,等于当前点与终点之间的距离;
  3. f 代价:当前点总代价,等于g 代价 + h 代价;

父节点替换

当前点在扩散时需要将周围的8个点都放入开集合中(假设8连通),这也意味着某一个点可能由周围8个点扩散过来。

那这个点到底归谁呢(决定它由谁扩散而来),当然是谁代价小就选谁。

对于一个点,它的h值是不变的,永远是与终点的距离,但是它的g值是不一样的,g值跟走过的路径有关,走的越长代价越大,所以选择走的最少的点作为自己的父节点。

所以具体实现步骤如下:

  1. 对于扩散到的点,首先判断是否有效(越出边界或障碍物),以及是否已经在闭集合中(已经走过的就不用再走一遍了)
  2. 若通过上一步,先将来源点(谁扩散来的)作为父节点,去计算g值(g值需要依靠父节点g值计算)
  3. 如果这个点不在开集合中,都好说,直接加入开集合就好
  4. 如果在开集合中,说明之前有别的点扩散到这过,比较这两个来源点到这的g值大小,选择小的那个作为自己的父节点。

下图为增加这个操作和不增加的对比,左图如果不增加,它会一直保留第一次扩散过来的点作为自己的父节点,右图切换父节点后就会选择更近的路。

该部分代码如下

python 复制代码
    def diffusion_point(self, x, y, parent):
        # 无效点或者在闭集合中,跳过
        if not self.is_valid_point(x, y) or self.map.map[x][y].is_open == -1:
            return
        p = self.map.map[x][y]
        pre_parent = p.parent
        p.parent = parent
        # 先计算出当前点的总代价
        cost_g = self.g_cost(p)
        cost_h = self.h_cost(p)
        cost_f = cost_g + cost_h
        # 如果在开集合中,判断当前点和开集合中哪个点代价小,换成小的,相同x,y的点h值相同,g值不一定相同
        if p.is_open == 1:
            if cost_f < p.cost_f:
                # 如果从当前parent遍历过来的代价更小,替换成当前的代价和父节点
                p.cost_g, p.cost_h, p.cost_f = cost_g, cost_h, cost_f
            else:
                # 如果从之前父节点遍历过来的代价更小,保持之前的代价和父节点
                p.parent = pre_parent
        else:
            # 如果不在开集合中,说明之间没遍历过,直接加到开集合里就好
            p.cost_g, p.cost_h, p.cost_f = cost_g, cost_h, cost_f
            heapq.heappush(self.open_set, p)
            p.is_open = 1

距离

若使用不同距离,可在完整代码中替换对应部分

  1. 欧式距离(欧几里得距离):两个点之间的直线距离

    python 复制代码
        def g_cost(self, p):
            '''
            计算 g 代价,当前点与父节点的距离 + 父节点的 g 代价(欧氏距离)
            :param p: 当前扩散的节点
            :return: p 的 g 代价
            '''
            x_dis = abs(p.parent.x - p.x)
            y_dis = abs(p.parent.y - p.y)
            return np.sqrt(x_dis ** 2 + y_dis ** 2) + p.parent.cost_g
    
        def h_cost(self, p):
            '''
            计算 h 代价,当前点与终点之间的距离(欧氏距离)
            :param p: 当前扩散的节点
            :return: p 的 h 代价
            '''
            x_dis = abs(self.end_point.x - p.x)
            y_dis = abs(self.end_point.y - p.y)
            return np.sqrt(x_dis ** 2 + y_dis ** 2)
  2. 曼哈顿距离:计算x轴和y轴之差的绝对值,适合在四连通时使用(只上下左右走),计算简单

    python 复制代码
        def g_cost(self, p):
            '''
            计算 g 代价,当前点与父节点的距离 + 父节点的 g 代价(曼哈顿距离)
            :param p: 当前扩散的节点
            :return: p 的 g 代价
            '''
            x_dis = abs(p.parent.x - p.x)
            y_dis = abs(p.parent.y - p.y)
            return x_dis + y_dis + p.parent.cost_g
    
        def h_cost(self, p):
            '''
            计算 h 代价,当前点与终点之间的距离(曼哈顿距离)
            :param p: 当前扩散的节点
            :return: p 的 h 代价
            '''
            x_dis = abs(self.end_point.x - p.x)
            y_dis = abs(self.end_point.y - p.y)
            return x_dis + y_dis
  3. 对角距离(待续)

地图设置

创建一个地图类,具体的地图用二维列表存储,每个元素都是Point对象。set_obstacle方法可手动设置障碍物。

python 复制代码
class Map:
    def __init__(self, map_size):
        self.map_size = map_size
        self.width = map_size[0]
        self.height = map_size[1]
        self.map = [[Point(x, y) for y in range(self.map_size[1])] for x in range(self.map_size[0])]

    # 手动设置障碍物,可多次调用设置地图
    # 由于地图方向不同,这里的topleft并不总是左上角,topleft代表x和y全都较小的点
    def set_obstacle(self, topleft, width, height):
        for x in range(topleft[0], topleft[0] + width):
            for y in range(topleft[1], topleft[1] + height):
                self.map[x][y].val = 1

完整代码

python 复制代码
#!/usr/bin/env python
# -*- coding:utf-8 -*-
# @Author      : Cao Zejun
# @Time        : 2024/4/7 17:36
# @File        : Astar_blog.py
# @Software    : Pycharm
# @description : 用于CSDN上的Astar算法演示

import time
import numpy as np
from matplotlib.patches import Rectangle
import matplotlib.pyplot as plt
import heapq


# 地图上的每一个点都是一个Point对象,用于记录该点的类别、代价等信息
class Point:
    def __init__(self, x=0, y=0):
        self.x = x
        self.y = y
        self.val = 0  # 0代表可通行,1代表障碍物
        self.cost_g = 0  # 三个代价
        self.cost_h = 0
        self.cost_f = 0
        self.parent = None  # 父节点
        self.is_open = 0  # 0:不在开集合  1:在开集合  -1:在闭集合

    # 用于heapq小顶堆的比较
    def __lt__(self, other):
        return self.cost_f < other.cost_f


class Map:
    def __init__(self, map_size):
        self.map_size = map_size
        self.width = map_size[0]
        self.height = map_size[1]
        self.map = [[Point(x, y) for y in range(self.map_size[1])] for x in range(self.map_size[0])]

    # 手动设置障碍物,可多次调用设置地图
    # 由于地图方向不同,这里的topleft并不总是左上角,topleft代表x和y全都较小的点
    def set_obstacle(self, topleft, width, height):
        for x in range(topleft[0], topleft[0] + width):
            for y in range(topleft[1], topleft[1] + height):
                self.map[x][y].val = 1


class AStar:
    def __init__(self, map, start_point, end_point, connect_num=8, ax=None, print_diffusion_point=False):
        self.map: Map = map
        self.start_point = start_point
        self.end_point = end_point
        self.open_set = [self.start_point]  # 开集合,先放入起点,从起点开始遍历

        self.start_point.is_open = 1  #
        self.connect_num = connect_num  # 连通数,目前支持4连通或8连通
        self.diffuse_dir = [(1, 0), (-1, 0), (0, 1), (0, -1), (1, 1), (-1, 1), (1, -1), (-1, -1)]  # 遍历的8个方向,只需取出元组,加到x和y上就可以

    def g_cost(self, p):
        '''
        计算 g 代价,当前点与父节点的距离 + 父节点的 g 代价(欧氏距离)
        :param p: 当前扩散的节点
        :return: p 的 g 代价
        '''
        x_dis = abs(p.parent.x - p.x)
        y_dis = abs(p.parent.y - p.y)
        return np.sqrt(x_dis ** 2 + y_dis ** 2) + p.parent.cost_g

    def h_cost(self, p):
        '''
        计算 h 代价,当前点与终点之间的距离(欧氏距离)
        :param p: 当前扩散的节点
        :return: p 的 h 代价
        '''
        x_dis = abs(self.end_point.x - p.x)
        y_dis = abs(self.end_point.y - p.y)
        return np.sqrt(x_dis ** 2 + y_dis ** 2)

    def is_valid_point(self, x, y):
        # 无效点:超出地图边界或为障碍物
        if x < 0 or x >= self.map.width:
            return False
        if y < 0 or y >= self.map.height:
            return False
        return self.map.map[x][y].val == 0

    def search(self):
        self.start_time = time.time()  # 用于记录搜索时间
        p = self.start_point
        # p 为当前遍历节点,等于终点停下
        while not (p == self.end_point):
            # 弹出代价最小的开集合点,若开集合为空,说明没有路径
            try:
                p = heapq.heappop(self.open_set)
            except:
                raise 'No path found, algorithm failed!!!'

            p.is_open = -1

            # 遍历周围点
            for i in range(self.connect_num):
                dir_x, dir_y = self.diffuse_dir[i]
                self.diffusion_point(p.x + dir_x, p.y + dir_y, p)
        return self.build_path(p)  # p = self.end_point

    def diffusion_point(self, x, y, parent):
        # 无效点或者在闭集合中,跳过
        if not self.is_valid_point(x, y) or self.map.map[x][y].is_open == -1:
            return
        p = self.map.map[x][y]
        pre_parent = p.parent
        p.parent = parent
        # 先计算出当前点的总代价
        cost_g = self.g_cost(p)
        cost_h = self.h_cost(p)
        cost_f = cost_g + cost_h
        # 如果在开集合中,判断当前点和开集合中哪个点代价小,换成小的,相同x,y的点h值相同,g值不一定相同
        if p.is_open == 1:
            if cost_f < p.cost_f:
                # 如果从当前parent遍历过来的代价更小,替换成当前的代价和父节点
                p.cost_g, p.cost_h, p.cost_f = cost_g, cost_h, cost_f
            else:
                # 如果从之前父节点遍历过来的代价更小,保持之前的代价和父节点
                p.parent = pre_parent
        else:
            # 如果不在开集合中,说明之间没遍历过,直接加到开集合里就好
            p.cost_g, p.cost_h, p.cost_f = cost_g, cost_h, cost_f
            heapq.heappush(self.open_set, p)
            p.is_open = 1

    def build_path(self, p):
        print('search time: ', time.time() - self.start_time, ' seconds')
        # 回溯完整路径
        path = []
        while p != self.start_point:
            path.append(p)
            p = p.parent
        print('search time: ', time.time() - self.start_time, ' seconds')
        # 打印开集合、闭集合的数量
        print('open set count: ', len(self.open_set))
        close_count = 0
        for x in range(self.map.width):
            for y in range(self.map.height):
                close_count += 1 if self.map.map[x][y] == -1 else 0
        print('close set count: ', close_count)
        print('total count: ', close_count + len(self.open_set))
        # path = path[::-1]  # path为终点到起点的顺序,可使用该语句翻转
        return path

if __name__ == '__main__':
    map = Map((50, 50))

    # 用于显示plt图
    ax = plt.gca()
    ax.set_xlim([0, map.width])
    ax.set_ylim([0, map.height])
    plt.tight_layout()

    # 设置障碍物
    map.set_obstacle([10, 27], 20, 4)
    # 将障碍物显示到plt上
    ax.add_patch(Rectangle([10, 27], width=20, height=4, color='gray'))

    # 设置起始点和终点,并创建astar对象
    start_point = map.map[5][5]
    end_point = map.map[20][40]
    astar = AStar(map, start_point, end_point)
    path = astar.search()

    # 搜索之后打印,功能拓展时不影响搜索时间
    # 打印开集合点和闭集合点,可视化扩散点数量
    for x in range(map.width):
        for y in range(map.height):
            if map.map[x][y].is_open == -1:
                ax.add_patch(Rectangle([x, y], width=1, height=1, color='green'))
            if map.map[x][y].is_open == 1:
                ax.add_patch(Rectangle([x, y], width=1, height=1, color='blue'))
    # 可视化起点到终点完整路径
    for p in path:
        ax.add_patch(Rectangle([p.x, p.y], width=1, height=1, color='red'))

    # plt.savefig('./output/tmp.jpg')  # 可选择将其保存为本地图片
    plt.show()

备注

若有错误,可在评论区指出,我会及时修改

相关推荐
阿华的代码王国43 分钟前
【JavaEE】——文件IO的应用
开发语言·python
电饭叔1 小时前
《python语言程序设计》2018版第8章19题几何Rectangle2D类(下)-头疼的几何和数学
开发语言·python
程序猿小D2 小时前
第二百六十七节 JPA教程 - JPA查询AND条件示例
java·开发语言·前端·数据库·windows·python·jpa
杰哥在此3 小时前
Python知识点:如何使用Multiprocessing进行并行任务管理
linux·开发语言·python·面试·编程
zaim15 小时前
计算机的错误计算(一百一十四)
java·c++·python·rust·go·c·多项式
PythonFun9 小时前
Python批量下载PPT模块并实现自动解压
开发语言·python·powerpoint
炼丹师小米10 小时前
Ubuntu24.04.1系统下VideoMamba环境配置
python·环境配置·videomamba
GFCGUO10 小时前
ubuntu18.04运行OpenPCDet出现的问题
linux·python·学习·ubuntu·conda·pip
985小水博一枚呀11 小时前
【深度学习基础模型】神经图灵机(Neural Turing Machines, NTM)详细理解并附实现代码。
人工智能·python·rnn·深度学习·lstm·ntm
萧鼎13 小时前
Python调试技巧:高效定位与修复问题
服务器·开发语言·python