Random 算子范式

所有 random 算子本质上都是同一条流水线的变体:

① 确定性种子 → ② Philox RNG → ③ 归一化随机数 → ④ 数学变换 → 目标分布

RNG = Random Number Generator,随机数生成器

Philox 是 counter-based RNG,不像传统 RNG 有状态依赖链。每个核只需知道自己的 counter 起始值(通过 Skip 跳转),就能独立生成不重叠的随机数序列。核间零通信。算子之间的唯一本质区别就是最后一步”数学变换”。

① 确定性种子

// seed → key(拆成两个 uint32)
key[0] = static_cast<uint32_t>(seed);
key[1] = static_cast<uint32_t>(seed >> 32);
 
// offset → counter(拆成四个 uint32)
counter[0] = static_cast<uint32_t>(offset_lo);
counter[1] = static_cast<uint32_t>(offset_lo >> 32);
counter[2] = static_cast<uint32_t>(offset_hi);
counter[3] = static_cast<uint32_t>(offset_hi >> 32);
 
// offset 只有一个时:counterTemp = { 0, offset }
counter[0] = 0;
counter[1] = 0;
counter[2] = static_cast<uint32_t>(offset);
counter[3] = static_cast<uint32_t>(offset >> 32);

② Philox RNG

  • 输入:key[2](来自 seed)+ counter[4](来自 offset)
  • 过程:10 轮 Feistel 网络,每轮对 counter 做乘法 + 异或 key,轮间更新 key
  • 输出:4 个 uint32 伪随机数(一次调用产生 4 个)

③ 归一化随机数

方法一:位拼接法(mantissa stuffing)

// 取低23位当尾数,拼上指数,减1 → [0, 1)
uint32_t man = x & 0x7fffff;
uint32_t val = (127 << 23) | man;
float result = reinterpret_cast<float>(val) - 1.0f;
  • 只用了 uint32 的低 23 位,丢弃高 9 位
  • 产生 2^23 个均匀值
  • 可能产生 0.0

方法二:除法法(division by 2^32)

// 整个 uint32 除以 2^32,再加半步偏移 → 约 (0, 1)
constexpr float INV = 2.3283064e-10f;  // = 1 / 2^32
float result = x * INV + INV / 2.0f;
  • 用了 uint32 的全部 32 位
  • 加了半步偏移,永远不会产生 0.0 和 1.0

为什么要两种

方法二是给 Box-Muller 变换(正态分布)准备的——Box-Muller 里有 ln(u),如果 u=0 就会炸。加了偏移保证 u 永远不为零。

方法精度范围用于
位拼接23 位[0, 1)uniform 直接输出
除法32 位(0, 1)normal、dropout、exponential 等需要 log 的场景

④ 数学变换 → 目标分布

U1、U2 是归一化之后的结果:

// Z1 = sqrt(-2 * ln(U1)) * cos(2 * PI * U2)
// Z2 = sqrt(-2 * ln(U1)) * sin(2 * PI * U2)

截断正态:Box-Muller + 截断筛选 |Z| < 2 → 接受/丢弃

Skip 跳转

Philox 是个纯函数:给定 key 和 counter,输出完全确定。counter 不同,输出就不同。

所以多核并行时,只要让每个核用不同的 counter 段就行:

假设总共要生成 400 个随机数,分给 4 个核,每核 100 个:

核0: counter = 0,  1,  2,  ..., 24    → 随机数 [0,  99]
核1: counter = 25, 26, 27, ..., 49    → 随机数 [100,199]
核2: counter = 50, 51, 52, ..., 74    → 随机数 [200,299]
核3: counter = 75, 76, 77, ..., 99    → 随机数 [300,399]

Skip 做的事就是让每个核的 counter 跳到自己的起点:

// 核1 要跳过核0 的 25 组(每组产生 4 个随机数)
Skip(25);  // counter += 25

不需要核间通信,每个核算一下自己是第几个核、跳多少,就能独立工作,天然不重复。

本质就是对 counter 数组做 128 位加法(4 个 uint32 级联进位),即 counter += count:

void Skip(const uint64_t count) {
    const uint32_t countLo = static_cast<uint32_t>(count);
    uint32_t countHi = static_cast<uint32_t>(count >> 32);
 
    counter_[0] += countLo;
    if (counter_[0] < countLo) ++countHi;      // 溢出进位
 
    counter_[1] += countHi;
    if (counter_[1] < countHi) {                // 溢出进位
        if (++counter_[2] == 0) ++counter_[3];  // 继续进位
    }
}

CANN ops-math/random算子家族

https://gitcode.com/cann/ops-math/tree/master/random

家族算子变换方法
Uniform(均匀分布)stateless_random_uniform_v2/v3、random_uniform_v2、random_uniform_int_v2、dsa_random_uniformIEEE754 位拼接 → [0,1),直接输出
Normal(正态分布)stateless_random_normal_v2/v3、random_standard_normal_v2、truncated_normal_v2、dsa_random_normalBox-Muller:两个均匀数 → 一对正态数,截断版多拒绝采样
Bernoulli(伯努利)stateless_bernoulli阈值判断:uniform ≤ p → 1,否则 → 0
Dropout(随机丢弃)stateless_drop_out_gen_mask、drop_out_do_mask、drop_out_do_mask_v3/v3_d、drop_out_v3、dsa_gen_bit_mask同伯努利,gen_mask 生成 bit mask,do_mask 应用 mask
其他stateless_randperm、stateless_random_choice_with_mask、sim_thread_exponential排列(排序取下标)、带mask随机选择、指数分布(−ln(U))