索引类算子的 shape 规律

设:

  • x.shape = (A, B, C, D)
  • index.shape = (E, F, G)

记号说明:

  • 1 表示该轴使用 index
  • 0 表示该轴不使用 index

例如:

  • 0101 表示 y = x[:, index, :, index]
  • 1110 表示 y = x[index, index, index, :]

规则:

  • 索引轴连续:index.shape 保留在原位置
  • 索引轴不连续:index.shape 提到最前面,其余未索引轴按原顺序放后面

例子( y.shape):

  • 0110 -> y.shape = (A, E, F, G, D)
  • 0101 -> y.shape = (E, F, G, A, C)

什么是合轴

如果相邻两轴满足:

说明这两轴在内存上是连续的,可以合成一轴

以合并最后两轴为例:

  • shape: (d0, d1, d2, d3) -> (d0, d1, d2 * d3)
  • stride: (st0, st1, st2, st3) -> (st0, st1, st3)

例子:

原 shape  = (3, 4, 5, 6)
原 stride = (120, 30, 6, 1)
 
因为:
30 = 6 * 5
6  = 1 * 6
 
所以最后三轴可以继续合并:
 
新 shape  = (3, 120)
新 stride = (120, 1)