作为初学者,我们先从核心概念拆解开始,用最通俗的语言讲清楚「广播」和「不规则张量(RaggedTensor)」,再一步步拆解每个示例的计算过程,最后总结规律。
一、先搞懂3个基础概念
1. 张量的「维度(阶)」
张量的维度数 = 嵌套括号的层数(最外层算1层):
- 标量(比如
3):0维(无括号) - 一维张量(比如
[1,2]):1维(1层括号) - 二维张量(比如
[[1,2],[3]]):2维(2层括号) - 三维张量(比如
[[[1,2],[3,4]],[[5,6]]]):3维(3层括号)
2. 不规则张量(RaggedTensor)
普通张量的每一行元素个数必须相同(比如 [[1,2],[3,4]]),但RaggedTensor允许每行元素个数不同 (比如 [[1,2],[3]],第一行2个、第二行1个),这是理解示例的关键。
3. 广播的核心目的
逐元素运算(加减乘除)要求两个张量的「每个位置都能一一对应」,但实际中张量形状可能不同。广播就是把形状小的张量「复制扩展」,让它和大张量形状兼容,从而能逐元素运算。
二、广播的2个核心步骤(通俗版)
把官方规则翻译成新手能懂的话:
步骤1:补维度(外层补1)
如果两个张量维度数不一样,给「维度少的那个」的最外层 加维度(大小为1),直到两者维度数相同。
👉 比如:标量 3(0维)和二维RaggedTensor [[1,2],[3]](2维),先把标量补成 [[3]](2维,大小1×1),维度数就匹配了。
步骤2:扩维度(复制值)
对每一个维度,检查两个张量的「大小」:
- 如果其中一个张量在这个维度的大小是1 → 把它复制成和另一个张量一样的大小;
- 如果两个张量的大小都不是1,且不相等 → 报错(形状不兼容,无法广播)。
👉 关键:只有「大小为1的维度」能被复制扩展,否则必报错!
三、逐行拆解「能广播」的示例
示例1:标量 + 二维RaggedTensor
python
x = tf.ragged.constant([[1, 2], [3]]) # 2维RaggedTensor,形状:2 × (不规则)(第一维2行,第二行分别是2个、1个)
y = 3 # 0维标量
print(x + y) # 结果:<tf.RaggedTensor [[4, 5], [6]]>
计算步骤:
- 补维度 :y是0维,x是2维 → 给y补2层外层维度,变成
[[3]](2维,大小1×1)。 - 扩维度 :
- 第一维(行数):x的大小是2,y的大小是1 → 把y复制成2行:
[[3],[3]]; - 第二维(每行元素数):x是不规则的(2个、1个),y的大小是1 → 把y每行复制成和x一样的个数:
[[3,3],[3]];
- 第一维(行数):x的大小是2,y的大小是1 → 把y复制成2行:
- 逐元素相加 :
- 第一行:
1+3=4,2+3=5→[4,5]; - 第二行:
3+3=6→[6]。
- 第一行:
示例2:二维RaggedTensor + 二维普通张量
python
x = tf.ragged.constant([[10, 87, 12], [19, 53], [12, 32]]) # 2维,形状3×(不规则)(3行,元素数3、2、2)
y = [[1000], [2000], [3000]] # 2维普通张量,形状3×1(3行,每行1个)
print(x + y) # 结果:<tf.RaggedTensor [[1010, 1087, 1012], [2019, 2053], [3012, 3032]]>
计算步骤:
- 补维度:两者都是2维,不用补。
- 扩维度 :
- 第一维(行数):x和y都是3 → 不用动;
- 第二维(每行元素数):y是1,x是不规则的(3、2、2)→ 把y每行复制成和x一样的个数:
[[1000,1000,1000],[2000,2000],[3000,3000]];
- 逐元素相加 :
- 第一行:
10+1000=1010,87+1000=1087,12+1000=1012; - 第二行:
19+2000=2019,53+2000=2053; - 第三行:
12+3000=3012,32+3000=3032。
- 第一行:
示例3:三维RaggedTensor + 二维普通张量
python
x = tf.ragged.constant([[[1, 2], [3, 4], [5, 6]], [[7, 8]]], ragged_rank=1) # 3维,形状2×(不规则)×2(2组,第一组3个子集,第二组1个;每个子集2个元素)
y = tf.constant([[10]]) # 2维,形状1×1
print(x + y) # 结果:<tf.RaggedTensor [[[11, 12],[13, 14],[15, 16]], [[17, 18]]]>
计算步骤:
- 补维度 :y是2维,x是3维 → 给y补1层外层维度,变成
[[[10]]](3维,大小1×1×1)。 - 扩维度 :
- 第一维(组数):x是2,y是1 → 复制y成2组:
[[[10]], [[10]]]; - 第二维(每组子集数):x是不规则的(3、1),y是1 → 复制y成和x一样的子集数:
[[[10],[10],[10]], [[10]]]; - 第三维(每个子集元素数):x是2,y是1 → 复制y每个子集成2个元素:
[[[10,10],[10,10],[10,10]], [[10,10]]];
- 第一维(组数):x是2,y是1 → 复制y成2组:
- 逐元素相加 :
- 第一组:
1+10=11/2+10=12,3+10=13/4+10=14,5+10=15/6+10=16; - 第二组:
7+10=17/8+10=18。
- 第一组:
示例4:四维RaggedTensor + 一维普通张量
python
x = tf.ragged.constant([[[[1], [2]], [], [[3]], [[4]]], [[[5], [6]], [[7]]]], ragged_rank=2) # 4维,形状2×(不规则)×(不规则)×1(2组,第一组4个子集,第二组2个;子集内元素数不规则;每个元素1个值)
y = tf.constant([10, 20, 30]) # 1维,形状3
print(x + y) # 结果:<tf.RaggedTensor [[[[11,21,31],[12,22,32]], [], [[13,23,33]], [[14,24,34]]], [[[15,25,35],[16,26,36]], [[17,27,37]]]]>
计算步骤:
- 补维度 :y是1维,x是4维 → 给y补3层外层维度,变成
[[[[10,20,30]]]](4维,大小1×1×1×3)。 - 扩维度 :
- 前三维:把y复制成和x的不规则维度完全匹配(2组、对应子集数、对应子子集数);
- 第四维:x是1,y是3 → 把x的每个元素(比如
[1])复制成3个值([1,1,1]);
- 逐元素相加 :
1+10=11/1+20=21/1+30=31→[11,21,31];2+10=12/2+20=22/2+30=32→[12,22,32];- 空子集保持空;
- 以此类推,最终得到示例结果。
四、拆解「不能广播」的示例(核心:维度大小既不相等,也无1)
示例1:二维RaggedTensor + 二维普通张量
python
x = tf.ragged.constant([[1, 2], [3, 4, 5, 6], [7]]) # 2维,形状3×(2、4、1)
y = tf.constant([[1,2,3,4],[5,6,7,8],[9,10,11,12]]) # 2维,形状3×4
# x + y 报错
原因:
- 第一维(行数):都是3 → 没问题;
- 第二维(每行元素数):x是2、4、1,y是4 → 第一行x=2≠4且≠1,无法复制匹配 → 报错。
示例2:两个二维RaggedTensor
python
x = tf.ragged.constant([[1,2,3],[4],[5,6]]) # 2维,形状3×(3、1、2)
y = tf.ragged.constant([[10,20],[30,40],[50]]) # 2维,形状3×(2、2、1)
# x + y 报错
原因:
- 第一维(行数):都是3 → 没问题;
- 第二维(每行元素数):x=3、1、2,y=2、2、1 → 第一行x=3≠y=2且都≠1 → 无法复制匹配 → 报错。
示例3:三维RaggedTensor + 三维RaggedTensor
python
x = tf.ragged.constant([[[1,2],[3,4],[5,6]], [[7,8],[9,10]]]) # 3维,形状2×(3、2)×2
y = tf.ragged.constant([[[1,2,0],[3,4,0],[5,6,0]], [[7,8,0],[9,10,0]]]) # 3维,形状2×(3、2)×3
# x + y 报错
原因:
- 前两维都匹配 → 没问题;
- 第三维(每个子集元素数):x=2,y=3 → 都≠1且不相等 → 无法复制匹配 → 报错。
五、新手必记的广播核心规律
- 补维度只补外层 :比如0维标量补成2维是
[[3]],不是[3,]; - 扩维度只复制「大小为1」的维度:只有维度大小是1时,才能复制成目标大小;
- 不规则张量的兼容条件:不规则维度的「大小」要么和另一个张量相等,要么另一个张量在该维度大小为1;
- 报错唯一原因:某维度上,两个张量的大小既不相等,也没有一个是1 → 无法广播。
记住这4条,再回头看示例,就能清晰理解每一步的计算逻辑了。