索引类算子的 shape 规律
设:
x.shape = (A, B, C, D)index.shape = (E, F, G)
记号说明:
1表示该轴使用index0表示该轴不使用index
例如:
0101表示y = x[:, index, :, index]1110表示y = x[index, index, index, :]
规则:
- 索引轴连续:
index.shape保留在原位置 - 索引轴不连续:
index.shape提到最前面,其余未索引轴按原顺序放后面
例子( y.shape):
0110 -> y.shape = (A, E, F, G, D)0101 -> y.shape = (E, F, G, A, C)
什么是合轴
如果相邻两轴满足:
说明这两轴在内存上是连续的,可以合成一轴
以合并最后两轴为例:
shape: (d0, d1, d2, d3) -> (d0, d1, d2 * d3)stride: (st0, st1, st2, st3) -> (st0, st1, st3)
例子:
原 shape = (3, 4, 5, 6)
原 stride = (120, 30, 6, 1)
因为:
30 = 6 * 5
6 = 1 * 6
所以最后三轴可以继续合并:
新 shape = (3, 120)
新 stride = (120, 1)