描述
pytorch 基于masking对元素进行替换. 代码如下. 先展平再赋值.
代码
# map.shape [64,60,128]
# infill.shape [64,17,128]
# mask_indices.shape [64,60]
map = map.reshape(
map.shape[0] * map.shape[1],
map.shape[2]) [mask_indices.reshape(mask_indices.shape[0]*mask_indices.shape[1])] \
= fillin.reshape(fillin.shape[0]*fillin.shape[1], fillin.shape[2])