np.argmax(prob_scores, axis=1)
并不是取每一列的最大值。
而是找出每行的最大值 ,返回最大值的列索引
prob_scores=\[ 0.97875 0.021254
0.99951 0.00049028
0.98667 0.01333\]
python
pred_index = np.argmax(prob_scores, axis=1)
得到:0,0,0
pred_score = np.max(prob_scores, axis=1)
取每行的最大值