匈牙利算法【python,算法】

匈牙利算法的主要步骤如下:

  1. 记录原始矩阵为mat

  2. 对原始矩阵进行等效操作,操作方法如下:

    • 对于矩阵的每一行,找出最小值row_min_v,给这一行的每一个元素减去row_min_v
    • 对于矩阵的每一列,找出最小值col_min_v,给这一列的每一个元素减去col_min_v
  3. 主算法:

    记录矩阵的维度为dim,标记 0 元素的个数为zero_cnt

    如果zero_cnt<dim:

    • 对 0 元素矩阵进行划线,如果得到的划线行列总数小于dim,则需要调整矩阵。
    • 划线
      1. 标记不包含被标记的 0 元素的行,并在 non_marked_row 中存储行索引;
      2. 搜索 non_marked_row 元素,并找出相应列中是否有未标记的 0 元素;【找出未标记的独立 0 元素所在的行,加到 non_marked_row,打勾
      3. 将列索引存储在 marked_cols 中;【上述行中把独立 0 元素包含的列都 marked,打勾,这是之后要划的竖线】
      4. 比较存储在 marked_zero 和 marked_cols 中的列索引;【4、5步是找出(3)中划的列线包括的marked_0元素,把这行加到non_marked_row,打勾
      5. 如果存在一个匹配的列索引,那么相应的行索引就会被保存到non_marked_rows中;
      6. 接下来,不在 non_marked_row 中的行索引被保存在 marked_rows 中。
    • 调整矩阵:
      1. 在未划线的元素中,找到最小值。
      2. 对所有未划线的元素减去最小值
      3. 对划线交叉点的元素加上最小值

    如果zero_cnt=dim,在原矩阵中标记出算法选择的元素,即标记 0 元素的位置所对应的元素。

下面通过手撕代码实现了匈牙利算法,并与scipy库的算法进行对比,可以发现手动实现的算法与库函数实现是等效的。

python 复制代码
import copy
from pprint import pprint
from typing import List

import numpy as np
from scipy.optimize import linear_sum_assignment


def hungarian(mat: List):
	"""匈牙利算法

	步骤:
	1. 将矩阵转换为 numpy 数组
	2. 矩阵的等级转换
       - 对于每一行,找到最小值,然后将每个元素减去最小值
       - 对于每一列,找到最小值,然后将每个元素减去最小值
	3. 对矩阵进行划线处理,得到划线的行和列小标
	   - 将矩阵转换为 bool 类型的矩阵,True 表示 0 元素,False 表示其他元素。
	   - 循环划取 0 ,直到 bool 矩阵中没有 0 为止。
	     1. 查找矩阵中含有最少 0 的行,对该行中的第一个 0 所在的行列进行 False 处理,并将这个 0 的行列位置添加到 0 元素列表。
	     2. 通过 0 元素列表得到没有 0 元素的行,划取该行。
	     3. 找到这行中包含 0 的列,对列进行划线。
	     4. 找到这列中包划圈的 0 元素,并对这个元素所在的行划线。
	     5. 重复 3-4 步骤,直到不满足划去条件为止。
	4. 如果划线行列总数小于矩阵的维度,则按照下面的方法调整矩阵:
	   - 在未划线的元素中,找到最小值。
	   - 对所有未划线的元素减去最小值
	   - 对划线交叉点的元素加上最小值
	   执行完成后,跳转到步骤 3。
	5. 如果划线行列总数等于矩阵的维度,按照如下步骤计算结果:
	   - 在原矩阵中标记出最优匹配。
	   - 标记的同时计算最优价值。
	:param mat: 原始矩阵
	:return: 最优矩阵,最优条件下的代价值
	"""
	# 保留原始矩阵
	orig_mat = copy.deepcopy(mat)
	# 转化为 np 矩阵
	mat = np.array(mat)
	# 求矩阵进行等价处理
	reduce_mat = reduce_func(mat)
	dim = mat.shape[0]
	zero_count = 0
	while zero_count < dim:
		select_pos, marked_rows, marked_cols = mark_matrix(reduce_mat)
		zero_count += len(select_pos)
		if zero_count < dim:
			adjust_matrix(reduce_mat, marked_rows, marked_cols)
	else:
		cost, cost_mat = optimize_matrix(orig_mat, select_pos)
		pprint(f"total cost is {cost}!")
		pprint(cost_mat)
		return cost, cost_mat


def mark_matrix(mat):
	"""
	模拟划线过程,具体步骤为:
	# 把所有零元素全部标记
	# 计算最小画线次数,即 marked_rows 为划横线,marked_cols 为划竖线
		1)标记不包含被标记的 0 元素的行,并在 non_marked_row 中存储行索引;
		2)搜索 non_marked_row 元素,并找出相应列中是否有未标记的 0 元素;【找出未标记的独立 0 元素所在的行,加到 non_marked_row,*打勾*】
		3)将列索引存储在 marked_cols 中;【上述行中把独立 0 元素包含的列都 marked,*打勾*,这是之后要划的竖线】
		4)比较存储在 marked_zero 和 marked_cols 中的列索引;【4、5步是找出(3)中划的列线包括的marked_0元素,把这行加到non_marked_row,*打勾*】
		5)如果存在一个匹配的列索引,那么相应的行索引就会被保存到non_marked_rows中;
		6)接下来,不在 non_marked_row 中的行索引被保存在 marked_rows 中
	:param mat:
	:return: (marked_zero, marked_rows, marked_cols)
	【返回没有打勾的行,和打勾的列】
	"""

	# 原矩阵中元素为0的地方标记为True,其他都为False
	cur_mat = mat
	zero_bool_mat = (cur_mat == 0)
	zero_bool_mat_copy = zero_bool_mat.copy()

	# marked_zero 记录了标记0的位置,按顺序存储
	marked_zero = []
	# 模拟划线过程
	while True in zero_bool_mat_copy:
		# 每执行一次min_zero_row()函数
		# 就找到零元素最少的那一行,找到该行第一个零元素
		# 将这个零元素的行和列全部置为False
		# 直到所有零元素都被标记过
		min_zero_row(zero_bool_mat_copy, marked_zero)

	# 记录被标记过的行和列(也就是划过线的行和列)
	marked_zero_row = []
	marked_zero_col = []
	for i in range(len(marked_zero)):
		marked_zero_row.append(marked_zero[i][0])
		marked_zero_col.append(marked_zero[i][1])

	# 找到没被标记过的行(即没有独立 0 元素的行)
	non_marked_row = list(set(range(cur_mat.shape[0])) - set(marked_zero_row))

	marked_cols = []
	check_switch = True
	while check_switch:
		check_switch = False
		for i in range(len(non_marked_row)):
			row_array = zero_bool_mat[non_marked_row[i], :]
			for j in range(row_array.shape[0]):
				# 找到没被标记的行中,是否有没被标记的 0 元素(也就是被迫被划线经过的列)
				# 在没有独立 0 元素的行中,找到所含 0 元素的列,加入到 marked_cols 中
				if row_array[j] == True and j not in marked_cols:
					marked_cols.append(j)
					check_switch = True
		# 对所有 marked_cols 中,独立的 0 元素所在的行取出来加到 non_marked_row 中
		for row_num, col_num in marked_zero:
			# 前面标记的独立 0 元素出现在独立 0 元素所在的列上
			if col_num in marked_cols and row_num not in non_marked_row:
				non_marked_row.append(row_num)
				check_switch = True

	marked_rows = list(set(range(mat.shape[0])) - set(non_marked_row))
	# 最后划线最少的方式是把打勾的列和没打勾的行划出来
	return marked_zero, marked_rows, marked_cols


def min_zero_row(zero_mat, mark_zero):
	"""
		1)找到零元素最少的行,以及该行第一个零元素,记录其坐标(min_row[1], zero_index)
		2)将该元素的行和列全部赋为False
	:param zero_mat:  Bool矩阵
	:param mark_zero: 存储标记的0元素的list
	:return: 没有返回值,直接修改bool矩阵
	"""

	min_row = [99999, -1]

	# 找到零元素最少的行,记为min_row= [0元素个数, 行号]
	for row_num in range(zero_mat.shape[0]):
		if 0 < np.sum(zero_mat[row_num] == True) < min_row[0]:
			min_row = [np.sum(zero_mat[row_num] == True), row_num]

	# np.where()返回零元素最少的行中,第一个零元素的下标
	zero_index = np.where(zero_mat[min_row[1]] == True)[0][0]
	# 存储标记0的位置
	mark_zero.append((min_row[1], zero_index))
	# 该标记0元素的这一行和这一列全部置为False
	zero_mat[min_row[1], :] = False
	zero_mat[:, zero_index] = False


def adjust_matrix(cur_mat, cover_rows, cover_cols):
	"""
	对矩阵进行调整:具体做法为:
		1)找到未被标记的元素中的最小值
		2)未被标记的元素 - 最小值
		3)标记的行和列中相交的元素 + 最小值
	:param mat: 原先操作过的矩阵
	:param cover_rows:  标记的行
	:param cover_cols:  标记的列
	:return: 调整后的矩阵
	"""
	# 找到未被标记的行和列中的最小值
	non_zero_element = []

	# Find the minimum value
	for row in range(len(cur_mat)):
		if row not in cover_rows:
			for i in range(len(cur_mat[row])):
				if i not in cover_cols:
					non_zero_element.append(cur_mat[row][i])
	min_num = min(non_zero_element)

	# 未标记的元素 - 最小值
	for row in range(len(cur_mat)):
		if row not in cover_rows:
			for i in range(len(cur_mat[row])):
				if i not in cover_cols:
					cur_mat[row, i] -= min_num

	# 标记的行和列 相交的元素 + 最小值
	for row in range(len(cover_rows)):
		for col in range(len(cover_cols)):
			cur_mat[cover_rows[row], cover_cols[col]] = cur_mat[cover_rows[row], cover_cols[col]] + min_num


def optimize_matrix(cost_mat, select_pos):
	"""在原矩阵中标记最优选择,并计算最优价值"""
	optimizer_cost = 0
	for i, row in enumerate(cost_mat):
		for j, v in enumerate(row):
			if (i, j) in select_pos:
				optimizer_cost += v
				cost_mat[i][j] = f"{cost_mat[i][j]}T"
	return optimizer_cost, cost_mat


def reduce_func(cost_mat: np.ndarray) -> np.ndarray:
	"""行列统一减去最小的数

	:param mat: 原始矩阵
	:return: 修改之后的矩阵,它与原矩阵的最优解相同
	"""
	col_reduce = cost_mat - np.min(cost_mat, axis=1, keepdims=True)
	row_reduce = col_reduce - np.min(col_reduce, axis=0, keepdims=True)
	return row_reduce


def hungarian_by_third_lib(cost_mat):
	"""通过第三方库的匈牙利算法计算 """
	work_idx_ls, pokeman_idx_ls = linear_sum_assignment(cost_mat)
	cost = 0
	for work_idx, poken_idx in zip(work_idx_ls, pokeman_idx_ls):
		cost += cost_mat[work_idx][poken_idx]
		cost_mat[work_idx][poken_idx] = f"{cost_mat[work_idx][poken_idx]}T"
	pprint(f"total cost is {cost}!")
	pprint(cost_mat)


if __name__ == '__main__':
	mat = [[7, 6, 2, 11],
	       [6, 2, 1, 3],
	       [5, 6, 8, 9],
	       [6, 8, 5, 8]]
	hungarian(copy.deepcopy(mat))
	hungarian_by_third_lib(copy.deepcopy(mat))
相关推荐
全栈开发圈几秒前
新书速览|Java网络爬虫精解与实践
java·开发语言·爬虫
面试鸭4 分钟前
离谱!买个人信息买到网安公司头上???
java·开发语言·职场和发展
小白学大数据5 分钟前
JavaScript重定向对网络爬虫的影响及处理
开发语言·javascript·数据库·爬虫
Python大数据分析@8 分钟前
python操作CSV和excel,如何来做?
开发语言·python·excel
黑叶白树9 分钟前
简单的签到程序 python笔记
笔记·python
Shy96041823 分钟前
Bert完形填空
python·深度学习·bert
上海_彭彭33 分钟前
【提效工具开发】Python功能模块执行和 SQL 执行 需求整理
开发语言·python·sql·测试工具·element
3345543242 分钟前
element动态表头合并表格
开发语言·javascript·ecmascript
沈询-阿里1 小时前
java-智能识别车牌号_基于spring ai和开源国产大模型_qwen vl
java·开发语言
zhongcx011 小时前
使用Python查找大文件的实用脚本
python