问题重新描述如下,设计程序
\[False False False True True True
True False True False False True
False False False True True True
True False False False False False
False False False False False False
True False False False False False
True True False True False False
False False False True False False
True False True False False True
False False False True True True
True True False True False False
True False True False False True
True True False True False False\]
使用 np.where查找前面布尔矩阵中,
1、如果1,4,6列有两个True,则给我返回1,4,6列值为True的行列索引,放在前两个位置(行对索引应一个矩阵,列索引对应一个矩阵),剩下的那一个没有True的则返回该元素原始索引,放在第三个位置。
2、2,3,5列中最多有一个True,返回它的行列索引,放在返回数组的第四个位置。如果1,4,6列中True的个数小于2,就返回 1,4,6列的原始行列索引(放在返回数组的1,2,3列),以及第二列的索引(放在返回数组的第4列)。
3、最终返回两个索引数组,一个是行索引数组,另一个是列索引数组,每一行的索引必须与布尔矩阵中的行相对应
python
def elegant_solution(arr):
"""
修正的向量化版本,无循环
"""
n_rows = arr.shape[0]
# 1. 行索引数组
row_array = np.repeat(np.arange(n_rows)[:, None], 4, axis=1)
# 2. 预定义列
cols_146 = np.array([0, 3, 5])
cols_235 = np.array([1, 2, 4])
# 3. 初始化列索引数组
col_array = np.zeros((n_rows, 4), dtype=int)
# 4. 计算第1,4,6列True数量
counts_146 = arr[:, cols_146].sum(axis=1)
# 5. 区分两种情况
mask_eq2 = counts_146 == 2
mask_lt2 = counts_146 < 2
# 6. 处理mask_lt2为True的情况(True个数小于2)
# 设置默认值:第1,4,6列和第2列
col_array[mask_lt2] = np.array([0, 3, 5, 1])
# 7. 处理mask_eq2为True的情况(True个数等于2)
if np.any(mask_eq2):
rows_eq2 = np.where(mask_eq2)[0]
n_eq2 = len(rows_eq2)
if n_eq2 > 0:
# 提取这些行的第1,4,6列
arr_146_sub = arr[rows_eq2][:, cols_146]
# 方法:使用argsort分离True和False
# 对每行进行排序,True(1)会排在False(0)后面
sorted_indices = np.argsort(arr_146_sub, axis=1) # 默认升序,False在前
# 提取True的位置(最后两个)和False的位置(第一个)
true_indices = sorted_indices[:, -2:] # 形状: (n_eq2, 2)
false_indices = sorted_indices[:, 0:1] # 形状: (n_eq2, 1)
# 转换为原始列索引
true_cols = cols_146[true_indices]
false_cols = cols_146[false_indices]
# 填充前三个位置
col_array[rows_eq2, 0] = true_cols[:, 0]
col_array[rows_eq2, 1] = true_cols[:, 1]
col_array[rows_eq2, 2] = false_cols[:, 0]
# 处理第2,3,5列
arr_235_sub = arr[rows_eq2][:, cols_235]
# 找到每行第一个True的位置
# 使用cumsum找到第一个True
cumsum_arr = np.cumsum(arr_235_sub, axis=1)
first_true_mask = (cumsum_arr == 1) & arr_235_sub
# 获取列索引
# argmax会返回第一个最大值的位置,我们确保每行至多一个True
first_true_idx = np.argmax(first_true_mask, axis=1)
# 检查是否找到True
has_true = np.any(first_true_mask, axis=1)
# 转换并填充
# 注意:当没有True时,argmax返回0,但我们需要-1
# 所以使用where根据has_true选择值
col_235_vals = np.where(has_true, cols_235[first_true_idx], -1)
col_array[rows_eq2, 3] = col_235_vals
return row_array, col_array