Modulo by a Constant

\(\newcommand{\floor}[1]{\left\lfloor #1 \right\rfloor} \newcommand{\bmax}{ {b_\max} } \newcommand{\ceil}[1]{\left\lceil #1 \right\rceil}\) 最近做 PE 的时候也顺便在玩 Rust。我在研究 Rust 编译出的 ASM 的时候,发现一段有趣的代码:

1
2
3
4
#[inline(never)]
fn opt_mod(x: u64) -> u64 {
x % 1000_000_007
}

编译出的 ASM 长这样:

1
2
3
4
5
6
7
8
9
playground::opt_mod:
movabsq $-8543223828751151131, %rcx
movq %rdi, %rax
mulq %rcx
shrq $29, %rdx
imulq $1000000007, %rdx, %rax
subq %rax, %rdi
movq %rdi, %rax
retq

我震惊的发现,这里面居然没有除法、取模?有点意思……

我猜这里的 mulq %rcx 是会溢出的,最后高 64 位进入 %rdx。如果输入是 \(n\),那么输出为那么给出的结果为 \[ \begin{equation} y = n - \floor{\frac{n a}{2^{64 + b}}} d \label{eq:div-to-mod} \end{equation} \] 其中 \(a = 2^{64}-8543223828751151131\) (由于这是个 u64,所以要加上 \(2^{64}\)),\(b = 29\)。这里面的那个取整很像在求 \(\floor{n / d}\),但是巧妙的避免避免除法,改用一次乘法和一次位移来实现。另一件有趣的事是,它非常自信这样得出的结果为 \(\floor{n/d}\) ,不会 offset-by-1,因为这里面甚至没有判断 \(y\) 是否大于 \(d\) 或小于 0,如果大于 \(d\) 或小于 0,应该再来一次减法或加法。看起来这后面还藏着一些数学,我们来研究一下。

Disclaimer: 我对汇编几乎不了解,理解基本上是靠猜。我也不了解每个指令的时间、CPU 的流水线作业、分支预测等因素对齐性能的影响。这里贴出汇编代码只是为了确认没有奇怪的操作。

Notations

这里我们约定,我们想计算的是 \(n \bmod d = n - \floor{x/d} d\),其中 \(n, d\) 都是 u64 内的无符号整形。由于对于同一个 \(d\) ,我们有很多个 \(n\) 需要做,所以我们想对 \(d\) 做预处理,看能不能比直接用 div 这种汇编指令更快。\(N\) 定义为常数 \(2^{64}\)为方便处理,我们假定 \(d\) 不是 2 的幂。

在这篇文章中,\(\bmod\) 运算优先级低于加减乘除取负等运算,但高于括号,且 \(\bmod\) 的结果均为正数,即 \(-1 \bmod 3 = 2\)

Algorithm 1.0

很显然,方程 \(\eqref{eq:div-to-mod}\) 成立的充要条件是 \[ \begin{equation} \floor{\frac{na}{2^b N}} = \floor{n/d}, \quad \Longleftrightarrow \quad \frac{n}{d} \leq \frac{na}{2^b N} < \floor{\frac{n}{d}} + 1, \label{eq:strict} \end{equation} \] 我们将不等号最右边缩到 \(\frac{n+1}{d}\)\[ \begin{equation} \frac{n}{d} \leq \frac{na}{2^b N} < \frac{n+1}{d}. \label{eq:3} \end{equation} \] 注意到不等式的三项都是关于 \(n\) 的线性函数,所以这个不等式在区间 \([0, N - 1]\) 上成立等价于在区间两端成立。当 \(n = 0\) 时,原式显然成立,所以我们只需关注 \(n = N - 1\) 的情况: \[ \frac{N - 1}{d} \leq \frac{a (N-1)}{2^bN} < \frac{N}{d}, \quad \Longleftrightarrow \quad 2^b N \leq ad < \frac{2^bN^2}{N-1}. \] 将等式最后边再次放缩成 \(2^b (N+1)\) 后可得 \[ 2^b N \leq ad \leq 2^bN+2^b. \] 此时可以看出,\(a\) 这个不等式在 \[ \begin{equation} -2^b N \bmod d \leq 2^b \label{eq:req1} \end{equation} \] 时有解 \(\ceil{\frac{2^b N}{d}}\),否则无解。带进去算一下,Rust 选择的这组 \((a, b)\) 满足要求。注意到 \(na\) 不会溢出 u128 ,故 %rdx 这里也不会溢出,难怪如此自信。

我就用 Rust 写了一下:

1
2
3
4
5
6
7
8
9
fn algo1(n: u64) -> u64 {
const MUL: u64 = 9903520244958400485;
const SHR: u32 = 29;
const D: u64 = 1000000007;

let prod_hi = (((x as u128) * (MUL as u128)) >> 64) as u64;
let q = (prod_hi >> SHR);
n - q * D
}

得到的汇编为:(和一开始的几乎一样)

1
2
3
4
5
6
7
8
playground::Algo1::modulo:
movabsq $-8543223828751151131, %rcx
movq %rdi, %rax
mulq %rcx
shrq $29, %rdx
imulq $-1000000007, %rdx, %rax
addq %rdi, %rax
retq

仔细看一下,其实还是有小小不同的……这里面只有 7 个指令,Rust 编译器里面有 8 个,区别在于 algo1 里面都是在 %rax 上操作,最后直接返回了 %rax,而 Rust 编译器里都是在 %rdi 里面操作,最后再 mov %rdi, %rdx。为此我还特意去查了一下 Calling convention(之前学的忘光了),好像是 %rdi 里面存参数,返回值在 %rax 里,而且 %rax/%rdi 以及很多寄存器都是 callee 可以随便改的……

要计算这组 \((a, b)\) 的话,直接枚举 \(b\) 就行了,反正也不大……实验中也发现,Rust 给出的就是最小的 \(b\)。现在问题来了:我们希望 \(a < N\),这样才能在一个 u64 内存下 \(a\)。如果找不到这样的 \(a\) 怎么办?枚举了一下 \([1, 10^9]\) 里面所有奇数(从前的分析来看,这和奇偶数没啥关系),大概有 70% 的数是能找到 \(a\) 的,但是剩下的 30% 也不少。

我们还可以观察到:\(b = b_\max =\lceil \log_2 p\rceil\) 时,一定能满足 \(\eqref{eq:req1}\),但是此时的 \(a = \ceil{\frac{2^\bmax N}{d}} \geq N\) 一定存不进一个 u64,而且由于 \(d \leq 2^{\bmax} < 2d\),可得 \(a = \frac{2^\bmax N}{d} < 2N\) 就多了那么一个 bit,真是尴尬……

Algorithm 1.1

第一种方法,我们干脆再 relax \(\eqref{eq:strict}\) 。如果我们能满足 \[ \frac{n}{d} \leq \frac{na}{2^b N} < \frac{n}{d} + 1, \quad \Longleftrightarrow \quad 0 \leq \frac{na}{2^b N} - \frac{n}{d} < 1, \] 的话,那么能保证 \[ \floor{\frac{n}{d}} \leq \floor{\frac{na}{2^b N}} \leq \floor{\frac{n}{d}} + 1, \] 也就是说,我们能求得一个大概(offset-by-1)的解,带入 \(\eqref{eq:div-to-mod}\) 后求得的 \(y\) 就可能是负数,还需要一个 if 来判断。注意 \(y\) 是用一个 u64 来存,判断 \(y\) 正负需要特殊的技巧。

事实上,AtCoder 里面的 modint 就是这么处理的,不过它用带符号的数。

Algorithm 1.2

另一种方法,我们再看看 Rust 咋整的……我挑了一个 \(10^9 + 93\) 来看看 Rust 生成的 ASM 是什么样:

1
2
3
4
5
6
7
8
9
10
11
12
13
playground::opt_mod:
movabsq $1360294712801925637, %rcx
movq %rdi, %rax
mulq %rcx
movq %rdi, %rax
subq %rdx, %rax
shrq %rax
addq %rdx, %rax
shrq $29, %rax
imulq $1000000093, %rax, %rax
subq %rax, %rdi
movq %rdi, %rax
retq

对比一下之前的,这里多了四行:L5-8,我们去看看这到底是在做啥……还是令 \(a'=1360294712801925637\)\(b = 29\),于是他算的东西为 \[ \begin{equation} y = n -\floor{\frac{\floor{\frac{n - \floor{\frac{na'}{N}}}{2}} + \floor{\frac{na'}{N}}}{2^b}} d = n - \floor{\frac{n(1 + \frac{na'}{N})}{2^{b+1}}}d, \label{eq:sub-shr-add} \end{equation} \] 原来如此……我们之前的问题不是 \(a \geq N\) 出了 u64 吗,但是之前也提到了,必定存在一个 \(N \leq a < 2N\) 满足要求,那我们就令 \(a = N + a'\),其中 \(a' < N\) 就可以用一个 u64 存下来了。这里有地方需要注意:虽然 \(\floor{\frac{n - \floor{\frac{na'}{N}}}{2}} + \floor{\frac{na'}{N}}\) 在数学上等于 \(\floor{\frac{n + \floor{\frac{na'}{N}}}{2}}\) ,但是后者在 \(n\) 大的时候有可能溢出,前者不会。

对比一下 Algorithm 1.1,这里 Rust 用 4 ops 当做一个 if (以及里面的 subq),如果用 y -= (y >= p) as u64 * p 的话多一个乘法,具体会快多少就不知道了……

这里我也给个 Rust 代码:

1
2
3
4
5
6
7
8
9
fn algo1_2(n: u64) -> u64 {
const MUL: u64 = 1360294712801925637;
const SHR: u32 = 29;
const D: u64 = 1000000093;

let hi = (((n as u128) * (MUL as u128)) >> 64) as u64;
let q = (((n - hi) >> 1) + hi) >> SHR;
n - q * D
}

得到的汇编为也和之前的几乎一样:

1
2
3
4
5
6
7
8
9
10
11
12
playground::algo1_2:
movabsq $1360294712801925637, %rcx
movq %rdi, %rax
mulq %rcx
movq %rdi, %rax
subq %rdx, %rax
shrq %rax
addq %rdx, %rax
shrq $29, %rax
imulq $-1000000093, %rax, %rax
addq %rdi, %rax
retq

Algorithm 2

如果还坚持用等式 \(\eqref{eq:strict}\) 来优化,有点优化不动了。这里我们另辟蹊径:我们不再要求 \(\eqref{eq:strict}\) 成立,而是要求 \[ \begin{equation} \floor{\frac{(n+1)a}{2^b N}} = \floor{n/d}, \quad \Longleftrightarrow \quad \frac{n}{d} \leq \frac{(n+1)a}{2^b N} < \floor{\frac{n}{d}} + 1, \label{eq:add-one} \end{equation} \] 成立。我们继续把右边放缩成 \(\frac{n+1}{d}\),然后用之前同样的 trick,所有项都是线性的,则 \([0, N-1]\) 内满足要求等价于区间两端满足要求: \[ 0 \leq \frac{a}{2^b N} < \frac{1}{d} \quad \wedge \quad \frac{N-1}{d} \leq \frac{a}{2^b} < \frac{N}{d}, \] 注意到前者第二个不等号等价于后者第二个不等号,故可以合并起来,化简为 \[ 2^b N - 2^b \leq ad < 2^b N, \] 可以看出,\(a\)\[ 2^b N \bmod d \leq 2^b \] 时存在解 \(\floor{\frac{2^bN}{d}}\)。注意到这个条件和 \(\eqref{eq:req1}\) 很像。事实上,我们可以证明,当 \(b\)\(\bmax - 1\) 时,两者至少有一个满足:首先我们有:对于任意 \(b\),均有 \[ 0 < (2^b N \bmod d) + (-2^b N \bmod d) < 2d, \] 又有对于任意 \(b\) 均有 \(((2^b N \bmod d) + (-2^b N \bmod d)) \bmod d = 0 \bmod d = 0\),故有 \[ (2^{\bmax - 1} N \bmod d) + (-2^{\bmax - 1} N \bmod d) = d \leq 2^\bmax, \] 所以 \((2^{\bmax - 1} N \bmod d)\)\((-2^{\bmax - 1} N \bmod d)\) 中至少有一个不超过 \(2^{\bmax - 1}\)

接下来便不难想到,这个算法可以和 Algorithm 1.0 互补:若 Algorithm 1.0 找不到合适的 \((a, b)\),则 Algorithm 2 一定能找到合适的 \((a, b)\)。剩下的问题就是怎么有效的实现 Algorithm 2 了。

Algorithm 2 有一个问题使得我们无法直接计算 \(\floor{\frac{(n+1)a}{2^b N}}\)\(n+1\) 可能会溢出 u64。这也就是我们即将解决的问题。注意到 \(a < N\),所以 \((n+1)a \leq N(N-1)\) 不会溢出 u128。

Algorithm 2.1

第一种方法,我们使用乘法分配律 \((a + b) c = ac + bc\),将 \((n+1)a\) 拆成 \(na + a\)

1
2
3
4
5
6
7
8
9
10
fn algo2_1(n: u64) -> u64 {
const MUL: u64 = 9903519393255738626;
const SHR: u32 = 29;
const D: u64 = 1000000093;

let prod = (n as u128) * (MUL as u128) + (MUL as u128);
let prod_hi = (prod >> 64) as u64;
let q = prod_hi >> SHR;
n - q * D
}

得到的汇编为:

1
2
3
4
5
6
7
8
9
10
playground::algo2_1:
movabsq $-8543224680453812990, %rcx
movq %rdi, %rax
mulq %rcx
addq %rcx, %rax
adcq $0, %rdx
shrq $29, %rdx
imulq $-1000000093, %rdx, %rax
addq %rdi, %rax
retq

Algorithm 2.2

第二种方法,我们注意到,\(n+1\) 的溢出只会发生在 \(n = N- 1\) 这种情况下,而这种情况下,由于 \(d\) 不是 2 的幂,故有 \(d \nmid N\),即 \[ \floor{\frac{N - 1}{d}} = \floor{\frac{N}{d}}, \] 也就是说,这个时候我用 \(n\) 还是 \(n+1\) 是没有区别的,那我不加不就好了……所以 Algorithm 2.2 很简单,就是将 \(\eqref{eq:add-one}\) 换成 \[ \floor{\frac{\min(n+1, N - 1)a}{2^b N}}, \] 这里的 \(\min(n + 1, N - 1)\) 听说可以直接两个汇编指令搞定,但是我搞不出来,只能手写汇编了……以下便是代码:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20

fn algo2_2(n: u64) -> u64 {
const MUL: u64 = 9903519393255738626;
const SHR: u32 = 29;
const D: u64 = 1000000093;

let mut saturated = n;
unsafe {
asm!(
"add {0}, 1",
"sbb {0}, 0",
inout(reg) saturated,
options(nostack),
);
}
let prod = (saturated as u128) * (MUL as u128);
let prod_hi = (prod >> 64) as u64;
let q = prod_hi >> SHR;
n - q * D
}

汇编代码为:

1
2
3
4
5
6
7
8
9
10
playground::algo2_2:
movq %rdi, %rax
addq $1, %rax
sbbq $0, %rax
movabsq $-8543224680453812990, %rcx
mulq %rcx
shrq $29, %rdx
imulq $-1000000093, %rdx, %rax
addq %rdi, %rax
retq

比较一下 Algorithm 1.2 和 Algorithm 2.2,可以发现 Algorithm 2.2 少用了几个寄存器,而且少几个 op……algo2_2 如果不加 L13 ,就会多出一个 pushq %raxpopq %rcx,我也不知道这是啥,看起来像是保存寄存器,但是不知道为啥是 popq %rcx,希望有懂的人告诉我这是在干啥……

我试图用 n.saturating_add(1) 但是汇编出来的代码不是基于 sbb 而是 cmovne

1
2
3
incq	%rcx
movq $-1, %rax
cmovneq %rcx, %rax

Algorithm 3

以上算法都是基于 \(\eqref{eq:div-to-mod}\),先算了 \(\floor{\frac{n}{d}}\) 后再算 \(n \bmod d\)。能不能绕过 \(\floor{\frac{n}{d}}\) 直接算 \(n \bmod d\) 呢?

这篇 paper 提供了一个思路。在 Algorithm 1.0 里面,如果我们找到了满足条件的 \((a, b)\),我们其实可以用来干一些事:由于 \[ \frac{n}{d} \leq \frac{na}{2^b N} < \frac{n+1}{d}, \quad \Longrightarrow \quad 0 \leq \frac{na}{2^b N} - \frac{n}{d} < \frac{1}{d}, \]\[ n \bmod{d} = d \left\{ \frac{n}{d} \right\} = \floor{d \left\{ \frac{na}{2^b N} \right\}} = \floor{\frac{d(na \bmod 2^b N)}{2^b N} }. \] 搞定。注意到等式右边上的数都很大,为 \(2^bNd\) 级。一般情况下 \(b \approx \bmax\),则上面的数有 \(Nd^2\) 这么大,所以这个算法对 u64 内的数并不实用,只适用于 u32 以内的数……作者提供的代码也是针对 u32 的……

这代码也不是很好写。我这里取 \(b = 30\),那么我们需要做 30 位 u64 和 90 位 u128 的乘法,而且结果也不是取高、低 64 位,写起来很蛋疼。当然我这里是支持了 \(n\) 在 u64 内,如果只支持 u32 的 \(n\) 应该就舒服很多了,MUL也可以在一个 u64 内就存下了。

1
2
3
4
5
6
7
8
fn algo3(n: u64) -> u64 {
const MUL: u128 = 19807038786511477253u128;
const SHR: u32 = 30;
const D: u64 = 1000000093;

let prod = (n as u128) * (MUL as u128);
((prod % (1u128 << (64 + SHR)) * D as u128) >> (64 + SHR)) as u64
}

汇编如下:注意到这里用了 3 个 mul 指令……

1
2
3
4
5
6
7
8
9
10
11
12
playground::algo3:
movabsq $1360294712801925637, %rcx
movq %rdi, %rax
mulq %rcx
addl %edx, %edi
movl $1000000093, %ecx
mulq %rcx
andl $1073741823, %edi
imulq $1000000093, %rdi, %rax
addq %rdx, %rax
shrq $30, %rax
retq

Related Work

这类方式也叫做 Barrett reduction。我在调研的时候,还找到另外一种 reduction 方式:Montgomery reduction,用的比 Barrett reduction 多。它把每个数都搞了个中间表示(\(x \mapsto xR \bmod p\)\(R > p\) 是一个常数,通常选用 2 的幂),并且他有更高效的方法处理乘法。

baihacker 的 PE 库中有一个 NTT 的 benchmark,里面比较了 FLINT/NTL/LibBF 以及 Min_25 的代码,发现 Min_25 的是最快。我去看了一下 Min_25 的代码UnsafeMod 里面的 reduce 函数和 Wikipedia 里面的 REDC 算法很像(但是不是其实就是,只不过没有归一化到 \([0, p)\) 之间罢了,可能这就是为什么这叫 UnsafeMod 的原因吧),不过我也没仔细去看这个函数到底是干啥的了 = = 这个类好奇怪啊,减法里面为啥要加 3 * mod……NTL 的源代码中提到了这篇 paper,我也没看这是在说啥……Division algorithm 的 Wikipedia 中有一小节提到了:

However, unless D itself is a power of two, there is no X and Y that satisfies the conditions above. Fortunately, (N·X)/Y gives exactly the same result as N/D in integer arithmetic even when (X/Y) is not exactly equal to 1/D, but "close enough" that the error introduced by the approximation is in the bits that are discarded by the shift operation.

然后又 ref 了 3 个 link,有一个 link 就是 NTL 提到的 Granlund-Moeller 算法。

Acknowledgement

在 Division algorithm 的 Wikipedia 中,有一个 reference 是一篇博客。我写完之后才发现有这么一个 blog,仔细一看,妈呀写的是我的超集,真是太尴尬了……不过他公式有点丑……想了想,我这也不算抄袭,毕竟这是我自己看源代码看出来的东西。后来我认真读了一下这两篇 post,把他的一些新思路也加了进去,再写了一下自己的理解。我的证明和他稍有不同,但是本质上是一样的。原文中有这么一段

As is well known (and seen in a previous post), compilers optimize unsigned division by constants into multiplication by a "magic number." But not all constants are created equal, and approximately 30% of divisors require magic numbers that are one bit too large, which necessitates special handling.

这个 30% 怎么看起来这么眼熟啊……行吧 orz……

后来我又想,既然有 fast division,那会不会有 fast modulo 呢?于是我找到了这么一篇 blog 以及对应的 paper。我也向这两篇 blog 的作者表示感谢。

Epilogue

写这篇文章用的时间远比我想象中的长,从查资料,读文献,到写代码,猜汇编,到最后整理成文字。虽然每件事都不用话太久,但是加在一起还是花了很多精力。事实上,我还有一些东西想写没写完,例如写一个线性同余发生器 benchmark 一下各个算法,例如看看 Granlund-Moeller 到底是在干啥,例如讲讲 libdivide 是怎么搞的,但是我感觉我花的时间已经够多的了。我也说不清鼓捣这些东西有什么用,对我毕业毫无用处,但是我就是感到快乐。码农的快乐,往往就是这么朴实无华且枯燥。

顺便一提:我现在发现的最方便的看 Rust 如何编译一小段代码的方式就是:独立写一个函数,加 #[inline(never)],然后 cargo rustc --release -- --emit asm 找对应函数。另外,这些优化应该都是 LLVM 做的……以及,NTL/FLINT 这些库应该都做了这些优化吧,何必自己折腾呢 orz……可惜没有好用的 Rust binding……