算子开发 Overwrite 覆盖/替换模式 scanter_nd ,Accumulate 累加模式,性能对比 memset错误 bat_alloc 错误 scanter_nd_add
scanter_nd_add Accumulate 累加模式,如果没有设置 memset,会导致 bat_alloc 错误
这两个模式(Overwrite 覆盖模式 和 Accumulate 累加模式 )是 ScatterND 算子在处理"如何将更新数据写入到输出张量"时的两种不同语义。
简单来说:覆盖就是"替换",累加就是"相加"。
下面为你详细拆解这两种模式的原理、应用场景,以及为什么累加模式会引发你遇到的 MemSet 报错。
1. Overwrite 模式 (scatterMode = 0) ------ 覆盖/替换 scanter_nd
原理:
无论目标位置原来是什么值,直接用新的 updates 数据把它替换掉 。如果多个更新指向同一个位置,最后执行的更新会覆盖前面的(或者具体行为依赖于硬件执行顺序,但语义上是"赋值")。
举个通俗的例子:
假设输出张量初始是 [0, 0, 0, 0]。
你要把数值 5 写入到索引 1 的位置。
- 覆盖模式执行后:
[0, 5, 0, 0]。
不管原来索引 1 的位置是 0 还是别的什么,现在强行变成 5。
代码与硬件层面的表现:
因为是直接写入,所以它通常使用普通的内存拷贝或直接写入指令。
不需要预先清零:既然是覆盖,目标张量里原来的垃圾数据是什么根本无所谓,反正会被直接覆盖掉。
2. Accumulate 模式 (scatterMode = 1) ------ 累加
原理:
将新的 updates 数据与目标位置原有的值进行相加 。如果多个更新指向同一个位置,这些更新的值会被累加起来。
举个通俗的例子:
假设输出张量初始是 [0, 0, 0, 0]。
你要把数值 5 加到索引 1 的位置。
- 累加模式执行后:
[0, 5, 0, 0](0 + 5 = 5)。
如果此时你再把数值3加到索引1的位置: - 结果变成:
[0, 8, 0, 0](5 + 3 = 8)。
代码与硬件层面的表现(重点!这就是你报错的根源):
- 必须清零 :既然是累加,输出张量在累加开始前,绝对必须全是 0。如果输出内存里是随机乱码(比如未初始化的显存),那加出来的结果就完全错乱了。
- 原子操作:在 NPU/GPU 并行计算时,多个核可能同时往同一个地址加数据。为了保证不丢数据,必须使用硬件提供的"原子加"指令。
3. 为什么 Accumulate 模式会触发 MemSet 报错?
结合你提供的代码,我们再看这个流程:
当 scatterMode = 1 (Accumulate) 时:
cpp
// kernel 代码 (arch910b)
if (scatterMode_ == 1) {
// Phase 0: 清零 output (accumulate 语义需要 output 初始为 0)
if (clrBlockNum_ > 0 && blockIdx_ < clrBlockNum_) {
// 调用 InitOutput 将输出显存清零
InitOutput<PARAMS_T>(outputGm[clrBlockOffSet_], currClrBlockTilingSize_, static_cast<PARAMS_T>(0));
}
SyncAll(); // 等待所有核清零完毕
}
// 后续才会去执行累加操作...
逻辑链条:
- 910b 强制设定
scatterMode = 1。 - 算子 kernel 发现是累加模式,为了保证计算正确,第一步必须先把输出显存全部写成 0。
InitOutput这个函数底层,在 910b 架构上,实际上是调用了 NPU 硬件的 MemSet(内存清零)指令。- Ascend C 编译器在编译算子时,发现你调用了
InitOutput,就会去 OPP 包里找MemSet的底层二进制文件(MemSet_dynamic_AtomicAddrClean_1_ascend910b.o)。 - 因为你的环境/包里没有这个二进制文件,编译或运行时就会报
NnopbaseGetMemsetBinInfo找不到的错误。
反观 950 (arch35) 当scatterMode = 0(Overwrite) 时:
cpp
// kernel 代码 (arch35)
if (scatterMode == 1) { // 0 != 1, 条件不成立
// 这里面的 InitOutput 根本不会执行
InitOutput<...>(...);
SyncAll();
}
因为覆盖模式不需要初始值为 0,所以它跳过了清零步骤,自然就不会去调用 MemSet 指令,也就不会去找那个缺失的二进制文件,从而不报错。
总结
- Overwrite (0) : 直接把新值怼进去,不管旧值是啥。不需要清零,不调
MemSet,不报错。 - Accumulate (1) : 把新值和旧值相加。为了保证正确,必须先把旧值(显存垃圾数据)清零。调用了
InitOutput-> 触发了MemSet指令 -> 找不到底层.o文件 -> 报错。
所以你现在的矛盾点在于:910b 想要实现功能正确的累加,就必须清零,但环境里偏偏缺了清零需要的底层工具。要么补齐工具,要么改写算子用别的方式清零(比如自己写个 for 循环强行写 0,但性能会差很多)。
你问到了一个非常核心的问题!替换模式通常的性能更高 ,但"哪个更好"完全取决于你的业务需求。下面我为你详细解
析这两种模式的性能差异、原因以及如何选择。
🆚 两种模式的核心性能差异
为了让你快速把握核心,我用一个表格来对比它们的主要区别:
| 特性维度 | 替换模式 | 累加模式 |
|---|---|---|
| 计算流程 | 直接写入新值 | 先清零 ,再执行原子加 |
| 硬件指令 | 普通内存写入指令 | 原子加指令 (如Ascend的AtomicAdd) + 内存清零指令 |
| 内存访问 | 1次写入操作 | 1次清零 + 1次原子加 = 2次操作 |
| 性能开销 | 更低 | 更高(约1.5-3倍,取决于场景和硬件) |
| 是否需要初始化 | 否(直接覆盖) | 是(必须先清零,否则结果错误) |
| 并发安全性 | 需要上层逻辑保证 | 硬件保证原子性,天然支持安全并发 |
| 典型应用场景 | 参数更新、权重初始化、非重叠索引更新 | 梯度累加、计数器、统计直方图、需要聚合多个更新 |
🔍 深入理解性能差异的原因
1. 替换模式:简单直接
替换模式的逻辑非常简单:直接将新的数据写入到指定的位置,完全覆盖原有值。这只需要一条普通的内存写入指令。
- 硬件执行 :CPU/NPU只需执行一个
Store指令,将数据从寄存器写入内存。 - 无额外开销:不需要读取旧值,不需要复杂的同步操作,也不需要预先清零内存。
- 性能优势:在单线程或索引不冲突的多线程场景下,这是最高效的方式。
2. 累加模式:复杂且开销大
累加模式的逻辑是:将新数据与原有值相加后再写入 。为了保证在多线程并发环境下的正确性,这个"读-改-写"过程必须是原子操作。
- 硬件执行 :需要执行一条原子加指令 (如x86的
LOCK XADD,或Ascend的AtomicAdd)。这条指令的代价远高于普通加法指令。 - 缓存一致性协议的代价 :当多个核心同时尝试修改同一内存地址时,硬件需要通过缓存一致性协议(如MESI)进行协调,这会导致缓存行在核心间频繁传递,产生巨大的延迟和总线开销。这是原子操作在多核下变慢的根本原因。
- 必须先清零的额外开销 :在ScatterND的累加语义中,为了保证输出张量的初始状态正确(通常是0),必须在执行原子加之前,先对整个输出张量进行一次清零操作。这相当于额外多了一次全内存的写入操作,进一步拉低了性能。
#mermaid-svg-5BPy3QG7i911Dnod{font-family:"trebuchet ms",verdana,arial,sans-serif;font-size:16px;fill:#333;}@keyframes edge-animation-frame{from{stroke-dashoffset:0;}}@keyframes dash{to{stroke-dashoffset:0;}}#mermaid-svg-5BPy3QG7i911Dnod .edge-animation-slow{stroke-dasharray:9,5!important;stroke-dashoffset:900;animation:dash 50s linear infinite;stroke-linecap:round;}#mermaid-svg-5BPy3QG7i911Dnod .edge-animation-fast{stroke-dasharray:9,5!important;stroke-dashoffset:900;animation:dash 20s linear infinite;stroke-linecap:round;}#mermaid-svg-5BPy3QG7i911Dnod .error-icon{fill:#552222;}#mermaid-svg-5BPy3QG7i911Dnod .error-text{fill:#552222;stroke:#552222;}#mermaid-svg-5BPy3QG7i911Dnod .edge-thickness-normal{stroke-width:1px;}#mermaid-svg-5BPy3QG7i911Dnod .edge-thickness-thick{stroke-width:3.5px;}#mermaid-svg-5BPy3QG7i911Dnod .edge-pattern-solid{stroke-dasharray:0;}#mermaid-svg-5BPy3QG7i911Dnod .edge-thickness-invisible{stroke-width:0;fill:none;}#mermaid-svg-5BPy3QG7i911Dnod .edge-pattern-dashed{stroke-dasharray:3;}#mermaid-svg-5BPy3QG7i911Dnod .edge-pattern-dotted{stroke-dasharray:2;}#mermaid-svg-5BPy3QG7i911Dnod .marker{fill:#333333;stroke:#333333;}#mermaid-svg-5BPy3QG7i911Dnod .marker.cross{stroke:#333333;}#mermaid-svg-5BPy3QG7i911Dnod svg{font-family:"trebuchet ms",verdana,arial,sans-serif;font-size:16px;}#mermaid-svg-5BPy3QG7i911Dnod p{margin:0;}#mermaid-svg-5BPy3QG7i911Dnod .label{font-family:"trebuchet ms",verdana,arial,sans-serif;color:#333;}#mermaid-svg-5BPy3QG7i911Dnod .cluster-label text{fill:#333;}#mermaid-svg-5BPy3QG7i911Dnod .cluster-label span{color:#333;}#mermaid-svg-5BPy3QG7i911Dnod .cluster-label span p{background-color:transparent;}#mermaid-svg-5BPy3QG7i911Dnod .label text,#mermaid-svg-5BPy3QG7i911Dnod span{fill:#333;color:#333;}#mermaid-svg-5BPy3QG7i911Dnod .node rect,#mermaid-svg-5BPy3QG7i911Dnod .node circle,#mermaid-svg-5BPy3QG7i911Dnod .node ellipse,#mermaid-svg-5BPy3QG7i911Dnod .node polygon,#mermaid-svg-5BPy3QG7i911Dnod .node path{fill:#ECECFF;stroke:#9370DB;stroke-width:1px;}#mermaid-svg-5BPy3QG7i911Dnod .rough-node .label text,#mermaid-svg-5BPy3QG7i911Dnod .node .label text,#mermaid-svg-5BPy3QG7i911Dnod .image-shape .label,#mermaid-svg-5BPy3QG7i911Dnod .icon-shape .label{text-anchor:middle;}#mermaid-svg-5BPy3QG7i911Dnod .node .katex path{fill:#000;stroke:#000;stroke-width:1px;}#mermaid-svg-5BPy3QG7i911Dnod .rough-node .label,#mermaid-svg-5BPy3QG7i911Dnod .node .label,#mermaid-svg-5BPy3QG7i911Dnod .image-shape .label,#mermaid-svg-5BPy3QG7i911Dnod .icon-shape .label{text-align:center;}#mermaid-svg-5BPy3QG7i911Dnod .node.clickable{cursor:pointer;}#mermaid-svg-5BPy3QG7i911Dnod .root .anchor path{fill:#333333!important;stroke-width:0;stroke:#333333;}#mermaid-svg-5BPy3QG7i911Dnod .arrowheadPath{fill:#333333;}#mermaid-svg-5BPy3QG7i911Dnod .edgePath .path{stroke:#333333;stroke-width:2.0px;}#mermaid-svg-5BPy3QG7i911Dnod .flowchart-link{stroke:#333333;fill:none;}#mermaid-svg-5BPy3QG7i911Dnod .edgeLabel{background-color:rgba(232,232,232, 0.8);text-align:center;}#mermaid-svg-5BPy3QG7i911Dnod .edgeLabel p{background-color:rgba(232,232,232, 0.8);}#mermaid-svg-5BPy3QG7i911Dnod .edgeLabel rect{opacity:0.5;background-color:rgba(232,232,232, 0.8);fill:rgba(232,232,232, 0.8);}#mermaid-svg-5BPy3QG7i911Dnod .labelBkg{background-color:rgba(232, 232, 232, 0.5);}#mermaid-svg-5BPy3QG7i911Dnod .cluster rect{fill:#ffffde;stroke:#aaaa33;stroke-width:1px;}#mermaid-svg-5BPy3QG7i911Dnod .cluster text{fill:#333;}#mermaid-svg-5BPy3QG7i911Dnod .cluster span{color:#333;}#mermaid-svg-5BPy3QG7i911Dnod div.mermaidTooltip{position:absolute;text-align:center;max-width:200px;padding:2px;font-family:"trebuchet ms",verdana,arial,sans-serif;font-size:12px;background:hsl(80, 100%, 96.2745098039%);border:1px solid #aaaa33;border-radius:2px;pointer-events:none;z-index:100;}#mermaid-svg-5BPy3QG7i911Dnod .flowchartTitleText{text-anchor:middle;font-size:18px;fill:#333;}#mermaid-svg-5BPy3QG7i911Dnod rect.text{fill:none;stroke-width:0;}#mermaid-svg-5BPy3QG7i911Dnod .icon-shape,#mermaid-svg-5BPy3QG7i911Dnod .image-shape{background-color:rgba(232,232,232, 0.8);text-align:center;}#mermaid-svg-5BPy3QG7i911Dnod .icon-shape p,#mermaid-svg-5BPy3QG7i911Dnod .image-shape p{background-color:rgba(232,232,232, 0.8);padding:2px;}#mermaid-svg-5BPy3QG7i911Dnod .icon-shape .label rect,#mermaid-svg-5BPy3QG7i911Dnod .image-shape .label rect{opacity:0.5;background-color:rgba(232,232,232, 0.8);fill:rgba(232,232,232, 0.8);}#mermaid-svg-5BPy3QG7i911Dnod .label-icon{display:inline-block;height:1em;overflow:visible;vertical-align:-0.125em;}#mermaid-svg-5BPy3QG7i911Dnod .node .label-icon path{fill:currentColor;stroke:revert;stroke-width:revert;}#mermaid-svg-5BPy3QG7i911Dnod :root{--mermaid-font-family:"trebuchet ms",verdana,arial,sans-serif;} 累加模式
计算新值
执行原子加
AtomicAdd指令
读旧值+加新值+写回
完成
预先清零输出内存
InitOutput
替换模式
计算新值
直接写入内存
Store指令
完成
上图直观展示了两种模式的操作步骤差异。累加模式多了一次清零操作 ,并且其核心的原子加操作在硬件层面远比普通写入昂贵。
3. 一个重要的例外:高并发冲突场景
如果多个更新操作频繁地指向同一个内存地址(即索引高度冲突),那么:
- 替换模式 :最终结果取决于线程执行的时序,可能后写覆盖先写,结果不确定,通常不是我们想要的。
- 累加模式 :由于原子操作保证每次加法都不会丢失 ,最终结果是所有更新值的总和 ,这是正确且确定的。
在这种高冲突并发 场景下,累加模式虽然单次操作更慢,但它是保证正确性的唯一选择。替换模式则会导致数据竞争和结果错误。
🧭 如何选择:决策流程图
你可以根据以下流程图,根据你的具体场景来做出选择:
#mermaid-svg-w8TdBpvPBoaHKgzA{font-family:"trebuchet ms",verdana,arial,sans-serif;font-size:16px;fill:#333;}@keyframes edge-animation-frame{from{stroke-dashoffset:0;}}@keyframes dash{to{stroke-dashoffset:0;}}#mermaid-svg-w8TdBpvPBoaHKgzA .edge-animation-slow{stroke-dasharray:9,5!important;stroke-dashoffset:900;animation:dash 50s linear infinite;stroke-linecap:round;}#mermaid-svg-w8TdBpvPBoaHKgzA .edge-animation-fast{stroke-dasharray:9,5!important;stroke-dashoffset:900;animation:dash 20s linear infinite;stroke-linecap:round;}#mermaid-svg-w8TdBpvPBoaHKgzA .error-icon{fill:#552222;}#mermaid-svg-w8TdBpvPBoaHKgzA .error-text{fill:#552222;stroke:#552222;}#mermaid-svg-w8TdBpvPBoaHKgzA .edge-thickness-normal{stroke-width:1px;}#mermaid-svg-w8TdBpvPBoaHKgzA .edge-thickness-thick{stroke-width:3.5px;}#mermaid-svg-w8TdBpvPBoaHKgzA .edge-pattern-solid{stroke-dasharray:0;}#mermaid-svg-w8TdBpvPBoaHKgzA .edge-thickness-invisible{stroke-width:0;fill:none;}#mermaid-svg-w8TdBpvPBoaHKgzA .edge-pattern-dashed{stroke-dasharray:3;}#mermaid-svg-w8TdBpvPBoaHKgzA .edge-pattern-dotted{stroke-dasharray:2;}#mermaid-svg-w8TdBpvPBoaHKgzA .marker{fill:#333333;stroke:#333333;}#mermaid-svg-w8TdBpvPBoaHKgzA .marker.cross{stroke:#333333;}#mermaid-svg-w8TdBpvPBoaHKgzA svg{font-family:"trebuchet ms",verdana,arial,sans-serif;font-size:16px;}#mermaid-svg-w8TdBpvPBoaHKgzA p{margin:0;}#mermaid-svg-w8TdBpvPBoaHKgzA .label{font-family:"trebuchet ms",verdana,arial,sans-serif;color:#333;}#mermaid-svg-w8TdBpvPBoaHKgzA .cluster-label text{fill:#333;}#mermaid-svg-w8TdBpvPBoaHKgzA .cluster-label span{color:#333;}#mermaid-svg-w8TdBpvPBoaHKgzA .cluster-label span p{background-color:transparent;}#mermaid-svg-w8TdBpvPBoaHKgzA .label text,#mermaid-svg-w8TdBpvPBoaHKgzA span{fill:#333;color:#333;}#mermaid-svg-w8TdBpvPBoaHKgzA .node rect,#mermaid-svg-w8TdBpvPBoaHKgzA .node circle,#mermaid-svg-w8TdBpvPBoaHKgzA .node ellipse,#mermaid-svg-w8TdBpvPBoaHKgzA .node polygon,#mermaid-svg-w8TdBpvPBoaHKgzA .node path{fill:#ECECFF;stroke:#9370DB;stroke-width:1px;}#mermaid-svg-w8TdBpvPBoaHKgzA .rough-node .label text,#mermaid-svg-w8TdBpvPBoaHKgzA .node .label text,#mermaid-svg-w8TdBpvPBoaHKgzA .image-shape .label,#mermaid-svg-w8TdBpvPBoaHKgzA .icon-shape .label{text-anchor:middle;}#mermaid-svg-w8TdBpvPBoaHKgzA .node .katex path{fill:#000;stroke:#000;stroke-width:1px;}#mermaid-svg-w8TdBpvPBoaHKgzA .rough-node .label,#mermaid-svg-w8TdBpvPBoaHKgzA .node .label,#mermaid-svg-w8TdBpvPBoaHKgzA .image-shape .label,#mermaid-svg-w8TdBpvPBoaHKgzA .icon-shape .label{text-align:center;}#mermaid-svg-w8TdBpvPBoaHKgzA .node.clickable{cursor:pointer;}#mermaid-svg-w8TdBpvPBoaHKgzA .root .anchor path{fill:#333333!important;stroke-width:0;stroke:#333333;}#mermaid-svg-w8TdBpvPBoaHKgzA .arrowheadPath{fill:#333333;}#mermaid-svg-w8TdBpvPBoaHKgzA .edgePath .path{stroke:#333333;stroke-width:2.0px;}#mermaid-svg-w8TdBpvPBoaHKgzA .flowchart-link{stroke:#333333;fill:none;}#mermaid-svg-w8TdBpvPBoaHKgzA .edgeLabel{background-color:rgba(232,232,232, 0.8);text-align:center;}#mermaid-svg-w8TdBpvPBoaHKgzA .edgeLabel p{background-color:rgba(232,232,232, 0.8);}#mermaid-svg-w8TdBpvPBoaHKgzA .edgeLabel rect{opacity:0.5;background-color:rgba(232,232,232, 0.8);fill:rgba(232,232,232, 0.8);}#mermaid-svg-w8TdBpvPBoaHKgzA .labelBkg{background-color:rgba(232, 232, 232, 0.5);}#mermaid-svg-w8TdBpvPBoaHKgzA .cluster rect{fill:#ffffde;stroke:#aaaa33;stroke-width:1px;}#mermaid-svg-w8TdBpvPBoaHKgzA .cluster text{fill:#333;}#mermaid-svg-w8TdBpvPBoaHKgzA .cluster span{color:#333;}#mermaid-svg-w8TdBpvPBoaHKgzA div.mermaidTooltip{position:absolute;text-align:center;max-width:200px;padding:2px;font-family:"trebuchet ms",verdana,arial,sans-serif;font-size:12px;background:hsl(80, 100%, 96.2745098039%);border:1px solid #aaaa33;border-radius:2px;pointer-events:none;z-index:100;}#mermaid-svg-w8TdBpvPBoaHKgzA .flowchartTitleText{text-anchor:middle;font-size:18px;fill:#333;}#mermaid-svg-w8TdBpvPBoaHKgzA rect.text{fill:none;stroke-width:0;}#mermaid-svg-w8TdBpvPBoaHKgzA .icon-shape,#mermaid-svg-w8TdBpvPBoaHKgzA .image-shape{background-color:rgba(232,232,232, 0.8);text-align:center;}#mermaid-svg-w8TdBpvPBoaHKgzA .icon-shape p,#mermaid-svg-w8TdBpvPBoaHKgzA .image-shape p{background-color:rgba(232,232,232, 0.8);padding:2px;}#mermaid-svg-w8TdBpvPBoaHKgzA .icon-shape .label rect,#mermaid-svg-w8TdBpvPBoaHKgzA .image-shape .label rect{opacity:0.5;background-color:rgba(232,232,232, 0.8);fill:rgba(232,232,232, 0.8);}#mermaid-svg-w8TdBpvPBoaHKgzA .label-icon{display:inline-block;height:1em;overflow:visible;vertical-align:-0.125em;}#mermaid-svg-w8TdBpvPBoaHKgzA .node .label-icon path{fill:currentColor;stroke:revert;stroke-width:revert;}#mermaid-svg-w8TdBpvPBoaHKgzA :root{--mermaid-font-family:"trebuchet ms",verdana,arial,sans-serif;} 是
是
否
是
否
否
是
否
开始选择ScatterND模式
业务语义需要
聚合多个更新吗?
(如梯度累加、计数器)
存在多个更新操作
指向同一索引吗?
(高并发冲突)
选择 **累加模式**
这是保证正确性的唯一方式
对性能要求极高
且能保证索引不冲突吗?
选择 **替换模式**
性能更优
选择 **累加模式**
更安全,语义更通用
更新操作是独立的吗?
(如参数初始化、权重更新)
选择 **替换模式**
最直接高效
需要重新审视算法设计
或考虑其他算子
最终选择:累加模式
最终选择:替换模式
⚙️ 针对你当前910b环境的建议
回到你最初的问题,在Ascend 910B上:
- 如果业务允许 :优先尝试将算子配置为替换模式 。这能避免
MemSet报错,并且通常能获得更好的性能。你需要修改tiling代码,恢复对attrs的读取,并将默认值设为0。 - 如果必须使用累加模式 (例如你的业务确实需要梯度累加或计数器语义):
- 必须确保
MemSet二进制文件存在 :检查OPP包中是否存在MemSet_dynamic_AtomicAddrClean_1_ascend910b.o文件,并且memset.json配置正确。 - 接受性能开销:这是为正确性和功能完整性支付的必要成本。
- 尝试减少冲突:如果可能,优化索引生成逻辑,尽量减少多个更新指向同一地址的概率,这能在一定程度上缓解原子操作的冲突开销。
- 必须确保
💡 小贴士 :在深度学习训练中,梯度累加 是累加模式的典型应用。它通过在多个micro-batch上累加梯度,来模拟大batch size训练,有效解决显存不足的问题。这种场景下,累加模式是不可或缺的。
标准 PyTorch 官方 API 中确实没有直接命名为 torch.scatter_nd 的函数。
这是很多从 TensorFlow 迁移到 PyTorch 的开发者常遇到的困惑。PyTorch 和 TensorFlow 在设计哲学上有所不同,PyTorch 更倾向于使用统一的高级索引 机制来处理多维度的稀疏更新,而不是单独封装一个 scatter_nd。
虽然没有直接同名的函数,但 PyTorch 提供了功能完全等价甚至更灵活的替代方案。而且,这正好能和你最初问的 Overwrite(覆盖)和 Accumulate(累加)模式 完美对应起来!
下面为你梳理如何在 PyTorch 中实现 scatter_nd 的功能。
🎯 PyTorch 中的 scatter_nd 等价替代方案
在 PyTorch 中,要实现 TensorFlow tf.scatter_nd 的效果(根据坐标点更新张量),主要有两种方式:
- 使用 torch.Tensor.index_put_ (最接近的等价物) ⭐️⭐️⭐️⭐️⭐️
这是官方推荐的、语义最接近 scatter_nd 的方法。它支持通过坐标元组进行索引,并且通过参数直接支持 Overwrite 和 Accumulate 模式!
函数签名:
Tensor.index_put_(indices, values, accumulate=False)
indices: 一个包含多个 LongTensor 的元组,每个张量代表一个维度的坐标。这正好对应 TF 中 indices 的每一列。
values: 要写入的更新值。
accumulate: 这就是你之前问的开关!
False (默认): Overwrite 覆盖模式。直接替换原值。
True: Accumulate 累加模式。将新值加到原值上(内部使用原子加,支持并发安全)。
代码对比:实现相同的 ScatterND 操作
假设我们要把值 10, 20, 30 写入到 4x4 矩阵的 (0,0), (1,2), (3,1) 位置。
TensorFlow 的做法:
import tensorflow as tf
indices = tf.constant(\[0, 0, 1, 2, 3, 1]) # 形状 3, 2
updates = tf.constant(10, 20, 30)
shape = tf.constant(4, 4)
默认是覆盖模式
result = tf.scatter_nd(indices, updates, shape)
PyTorch 的等价做法:
import torch
1. 准备目标张量 (PyTorch 需要预先存在张量,TF 是创建新张量)
target = torch.zeros(4, 4)
2. 准备索引 (需要拆分成元组,对应每个维度)
TF 的 \[0, 0, 1, 2, 3, 1] 在 PyTorch 中要拆成:
row_idx = torch.tensor(0, 1, 3)
col_idx = torch.tensor(0, 2, 1)
indices_tuple = (row_idx, col_idx) # 作为元组传入
3. 准备更新值
updates = torch.tensor(10.0, 20.0, 30.0)
--- 覆盖模式 ---
target.index_put_(indices_tuple, updates, accumulate=False)
print(target)
输出:
tensor(\[10., 0., 0., 0.,
0., 0., 20., 0.,
0., 0., 0., 0.,
0., 30., 0., 0.])
--- 累加模式 (对应你之前问的 Accumulate) ---
假设原值不是0,我们要累加
target2 = torch.ones(4, 4) # 初始全为1
target2.index_put_(indices_tuple, updates, accumulate=True)
print(target2)
输出: (1 + 10 = 11, 1 + 20 = 21, 1 + 30 = 31)
tensor(\[11., 1., 1., 1.,
1., 1., 21., 1.,
1., 1., 1., 1.,
1., 31., 1., 1.])
- 使用原生 scatter_ (适用于单维度操作) ⭐️⭐️⭐️
如果你只是沿着单一维度进行散射(最常见的是 One-Hot 编码或 embedding 的反向传播),PyTorch 原生的 scatter_ 依然非常高效。
生成 One-Hot 编码
label = torch.tensor(2, 0, 1)
one_hot = torch.zeros(3, 3).scatter_(1, label.unsqueeze(1), 1) # 沿着维度1散射
- 使用第三方库 torch_scatter (针对复杂聚合) ⭐️⭐️⭐️⭐️
如果你不仅需要累加,还需要求平均、求最大值等复杂聚合操作,特别是在图神经网络(GNN)中,PyTorch 原生的 index_put_(accumulate=True) 只支持加法。这时你需要安装 torch_scatter 库。
from torch_scatter import scatter_sum
src: 源数据, index: 索引, dim=0 表示按行聚合
output = scatter_sum(src, index, dim=0)
💡 总结:回到你最初的疑问
你正在探索的 scatter_nd_update 及其相关操作,是 TensorFlow 中处理稀疏数据更新和不规则索引赋值的核心工具。它们在实现如One-Hot编码、梯度聚合、图神经网络中的消息传递等场景中至关重要。
下面我将为你梳理 TensorFlow 中这些操作的作用、核心差异、典型应用,并提供选择建议。
🆚 核心操作对比:TensorFlow 中的 Scatter 家族
TensorFlow 提供了一系列 scatter 操作,它们的核心思想都是"根据索引,将源数据分散写入到目标张量的特定位置",但在目标张量是否预先存在、是否需要清零、重复索引如何处理等方面存在关键差异。
图解:TensorFlow Scatter 家族的操作逻辑
flowchart LR
subgraph TF_scatter_nd tf.scatter_nd
A1["indices
形状: num_updates, 2"] --> A2["创建全新的零张量
形状: shape"]
A3["updates
形状: num_updates"] --> A2
A2 --> A4"结果: 零张量被部分填充"
end
subgraph TF_scatter_nd_update tf.scatter_nd_update
B1["目标张量
必须是 tf.Variable"] --> B2["按坐标点分散
每行是一个完整的坐标"]
B3["updates
形状: num_updates"] --> B2
B2 --> B4["结果: 原地修改后的 Variable
形状不变"]
end
subgraph TF_tensor_scatter_nd_update tf.tensor_scatter_nd_update
C1["目标张量
tf.Variable 或 tf.Tensor"] --> C2["按坐标点分散
每行是一个完整的坐标"]
C3["updates
形状: num_updates"] --> C2
C2 --> C4["结果: 返回新张量
原张量不变"]
end
示例对比:实现相同效果
假设我们希望将值 10, 20, 30 分别写入到一个形状为 4, 4 的张量中的 (0, 0), (1, 2), (3, 1) 这三个位置。
在 TensorFlow 中(使用 tf.scatter_nd):
import tensorflow as tf
坐标点,每个点是一个长度为2的向量(因为是2D张量)
indices = tf.constant(\[0, 0, 1, 2, 3, 1]) # 形状 3, 2
更新值
updates = tf.constant(10, 20, 30) # 形状 3
输出张量的形状
shape = tf.constant(4, 4)
直接得到结果张量
result = tf.scatter_nd(indices, updates, shape)
result:
\[10, 0, 0, 0,
0, 0, 20, 0,
0, 0, 0, 0,
0, 30, 0, 0]
在 TensorFlow 中(使用 tf.tensor_scatter_nd_update):
import tensorflow as tf
目标张量(可以是 Variable 或 Tensor)
tensor = tf.Variable(tf.zeros(4, 4, dtype=tf.int32))
坐标点
indices = tf.constant(\[0, 0, 1, 2, 3, 1])
更新值
updates = tf.constant(10, 20, 30, dtype=tf.int32)
返回一个新张量,原tensor不变
result = tf.tensor_scatter_nd_update(tensor, indices, updates)
如果想原地更新,可以:
tensor.scatter_nd_update(indices, updates) # 对于Variable
或
tensor.assign(tf.tensor_scatter_nd_update(tensor, indices, updates)) # 对于Variable
💡 关键洞察:tf.tensor_scatter_nd_update 是最通用的形式,它接受任何张量并返回一个新张量。tf.scatter_nd_update 是其特例,要求输入是 tf.Variable 并且通常用于原地更新。tf.scatter_nd 则从零开始构建新张量。
🧩 扩展操作与高级用法
TensorFlow 还提供了其他 scatter 变体来处理不同场景:
tf.scatter_nd_sub: 对已存在张量进行减法更新。
tf.scatter_nd_add: 对已存在张量进行加法更新(需注意,这通常需要张量预先初始化为零,否则可能累加到垃圾值上)。
tf.tensor_scatter_nd_max/min: 对已存在张量进行取最大/最小值更新。
这些操作可以接受一个已存在的张量作为输入,并返回一个更新后的新张量。
🎯 如何选择与使用建议
根据你的具体需求来选择合适的工具:
flowchart LR
A开始选择Scatter操作 --> B{目标张量是否存在?};
B -- 否 --> C[使用 tf.scatter_nd
创建全新的零张量并填充];
B -- 是 --> D{需要保留原张量数据吗?};
D -- 否 --> E[使用 tf.scatter_nd_update
原地更新 Variable];
D -- 是 --> F[使用 tf.tensor_scatter_nd_update
返回新张量];
C --> G[✅ 适用于从稀疏数据构建稠密张量
初始化特定点];
E --> H[✅ 适用于更新模型参数
变量状态];
F --> I[✅ 适用于任何张量的更新
灵活性最高];
G --> J💡 注意 indices 形状与 updates 形状的匹配;
H --> K[💡 注意目标必须是 Variable
且索引不能越界];
I --> K;
- 当你需要 从稀疏数据构建稠密张量 时
首选 tf.scatter_nd。它语义清晰,从零开始构建,避免了处理原有数据的麻烦。# 从稀疏的索引和值构建稠密的 one-hot 矩阵
indices = tf.constant(\[0, 2, 4]) # 三个样本的类别索引
updates = tf.constant(1, 1, 1) # 值为1
shape = tf.constant(5, 3) # 5个样本,3个类别
dense_matrix = tf.scatter_nd(indices, updates, shape) - 当你需要 更新模型参数或变量状态 时
使用 tf.scatter_nd_update。它要求输入是 tf.Variable,适用于需要持久化更新的场景。# 假设有一个Variable存储了模型参数
params = tf.Variable(tf.random.normal(100, 10))
我们只想更新其中第5, 12, 88行
indices = tf.constant(\[5, 12, 88])
updates = tf.random.normal(3, 10)
params.scatter_nd_update(indices, updates) # 原地更新
- 当你需要 更新任何张量(包括常数) 时
使用 tf.tensor_scatter_nd_update。它最灵活,接受 Tensor 或 Variable,并返回新张量,不修改原张量。# 更新一个常数张量(通常用于测试或计算中间结果)
tensor = tf.constant(\[1, 1, 1, 1])
indices = tf.constant(\[0, 1, 1, 0])
updates = tf.constant(9, 9)
new_tensor = tf.tensor_scatter_nd_update(tensor, indices, updates)
new_tensor: \[1, 9, 9, 1], tensor 保持不变
- 当你需要 聚合操作(如累加、取最大值) 时
TensorFlow 提供了 tf.scatter_nd_add 等操作,但务必确保目标张量已正确初始化(通常为零),否则结果会错误地累加到初始值上。# 正确的梯度累加示例
gradients = tf.Variable(tf.zeros(10, 5)) # 必须初始化为0
new_grads = tf.constant(\[0.1, 0.2, 0.3, 0.4])
indices = tf.constant(\[0, 0, 1, 1])
将新梯度累加到现有梯度上
gradients.scatter_nd_add(indices, new_grads)
⚠️ 重要注意事项与常见陷阱
索引形状与更新形状的匹配:这是最常见错误来源。确保 indices 的形状为 num_updates, num_dims,而 updates 的形状为 num_updates 或与目标张量对应维度匹配的形状。例如,更新一个 4, 4 张量的单个元素,indices 应为 \[0, 0],updates 应为 10 或 \[10]。
数据类型一致性:updates 的数据类型必须与目标张量一致,否则会报错。
索引越界:indices 中的值不能超出目标张量的边界,否则会引发运行时错误。
Variable 的限制:tf.scatter_nd_update 只能用于 tf.Variable,不能用于 tf.constant 或 tf.Tensor。
重复索引的覆盖语义:所有 scatter 操作在遇到重复索引时,都是后面的更新覆盖前面的,而不是累加。如果需要累加,必须使用专门的 add 变体并确保初始值为零。
💎 总结与跨框架对照
🚨 核心记忆点:在 TensorFlow 中,scatter_nd_update 和 tensor_scatter_nd_update 都是覆盖语义。如果需要累加,必须使用 scatter_nd_add 等操作,并且务必确保目标张量已初始化为零。
希望这份梳理能帮助你更好地理解和选择 TensorFlow 中的 scatter 操作!你目前是在处理一个具体的模型还是算法呢?比如是 GNN 还是其他需要稀疏更新的场景?