Skip to content

Commit 257ff14

Browse files
Chilleesimonlindholm
authored andcommitted
Updated ModMul (#71)
1 parent 84fca96 commit 257ff14

File tree

5 files changed

+170
-29
lines changed

5 files changed

+170
-29
lines changed

content/number-theory/ModMulLL.h

Lines changed: 16 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -1,30 +1,24 @@
11
/**
2-
* Author: Lukas Polacek
3-
* Date: 2010-01-26
2+
* Author: chilli, Ramchandra Apte, Noam527
3+
* Date: 2019-04-24
44
* License: CC0
5-
* Source: TopCoder tutorial
6-
* Description: Calculate $a\cdot b\bmod c$ (or $a^b \bmod c$) for large $c$.
7-
* Time: O(64/bits \cdot \log b), where $bits = 64-k$, if we want to deal with
8-
* $k$-bit numbers.
5+
* Source: https://github.com/RamchandraApte/OmniTemplate/blob/master/modulo.h
6+
* Proof of correctness is in doc/modmul-proof.md.
7+
* Description: Calculate $a\cdot b\bmod c$ (or $a^b \bmod c$) for $0 \le a, b < c < 2^{63}$.
8+
* Time: O(1) for \texttt{mod\_mul}, O(\log b) for \texttt{mod\_pow}
9+
* Status: fuzz-tested, proven correct
910
*/
1011
#pragma once
1112

1213
typedef unsigned long long ull;
13-
const int bits = 10;
14-
// if all numbers are less than 2^k, set bits = 64-k
15-
const ull po = 1 << bits;
16-
ull mod_mul(ull a, ull b, ull c) {
17-
ull x = a * (b & (po - 1)) % c;
18-
while ((b >>= bits) > 0) {
19-
a = (a << bits) % c;
20-
x += (a * (b & (po - 1))) % c;
21-
}
22-
return x % c;
14+
typedef long double ld;
15+
ull mod_mul(ull a, ull b, ull M) {
16+
ll ret = a * b - M * ull(ld(a) * ld(b) / ld(M));
17+
return ret + M * (ret < 0) - M * (ret >= (ll)M);
2318
}
24-
ull mod_pow(ull a, ull b, ull mod) {
25-
if (b == 0) return 1;
26-
ull res = mod_pow(a, b / 2, mod);
27-
res = mod_mul(res, res, mod);
28-
if (b & 1) return mod_mul(res, a, mod);
29-
return res;
19+
ull mod_pow(ull b, ull e, ull mod) {
20+
ull ans = 1;
21+
for (; e; b = mod_mul(b, b, mod), e /= 2)
22+
if (e & 1) ans = mod_mul(ans, b, mod);
23+
return ans;
3024
}

content/number-theory/ModPow.h

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
/**
2-
* Author: Simon Lindholm
3-
* Date: 2016-09-10
2+
* Author: Noam527
3+
* Date: 2019-04-24
44
* License: CC0
55
* Source: folklore
66
* Description:
@@ -9,8 +9,10 @@
99
#pragma once
1010

1111
const ll mod = 1000000007; // faster if const
12-
ll modpow(ll a, ll e) {
13-
if (e == 0) return 1;
14-
ll x = modpow(a * a % mod, e >> 1);
15-
return e & 1 ? x * a % mod : x;
12+
13+
ll modpow(ll b, ll e) {
14+
ll ans = 1;
15+
for (; e; b = b * b % mod, e /= 2)
16+
if (e & 1) ans = ans * b % mod;
17+
return ans;
1618
}

doc/modmul-proof.md

Lines changed: 114 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,114 @@
1+
Proof for why the following:
2+
3+
```cpp
4+
typedef uint64_t u64;
5+
typedef int64_t i64;
6+
typedef long double ld;
7+
8+
u64 mod_mul(u64 a, u64 b, u64 c) {
9+
i64 ret = a * b - c * u64(ld(a) * ld(b) / ld(c));
10+
return ret + c * (ret < 0) - c * (ret >= (i64)c);
11+
}
12+
```
13+
14+
correctly computes `a * b % c` for all 0 ≤ a, b < c < 2^63.
15+
16+
---
17+
18+
19+
The algorithm consists of two parts: first approximately reducing `a * b % c`, into `[-c, 2c)`,
20+
then reducing that further into `[0, c)`. Note the algorithm only ever adds/subtracts multiples
21+
of `c`, so it is clear that the end result is congruent to `ab (mod c)`. The main difficulties
22+
in proving it correct lie in showing that the floating point calculation reduces it into
23+
the right interval, and then separately in showing that using limited-precision integers is fine.
24+
We start with the first point.
25+
26+
27+
Let `round(x)` denote rounding `x` to the nearest long double, and `frac(x) = x - floor(x)`.
28+
Note that long doubles are 80-bit floats, with 1+63-bit mantissa (the first bit always being 1).
29+
They can thus represent all integers in `[0, 2^64]`, but not `2^64 + 1`.
30+
Multiplication and division are both precise to within 0.5 ulp -- they round to the float
31+
nearest to their mathematically exact result, and in case of ties, to the float with a trailing zero bit.
32+
33+
34+
We shall start by proving that the integer `a * b - c * u64(ld(a) * ld(b) / ld(c))` lies within the range `[-c, 2c)`.
35+
36+
Because of how rounding works, we have for each positive x (in our exponent range) that `|round(x) - x| ≤ x * 2^-64`.
37+
Thus we can write each `round(x)` as `x * (1 + ε/2^64)` where `|ε| ≤ 1`.
38+
Also, let us write `floor(x)` as `x - ε` with `0 ≤ ε < 1`.
39+
40+
Then we can incrementally rewrite/strengthen the inequality:
41+
`ab - floor(round(round(ab) / c))*c ∈ [-c, 2c)`
42+
`floor(round(round(ab) / c)) ∈ ab/c + (-2, 1]`
43+
`round(round(ab) / c) - ε ∈ ab/c + (-2, 1]`
44+
⇐ `round(round(ab) / c) ∈ ab/c + [-1, 1]`
45+
`ab / c (1+ε₁/2^64)(1+ε₂/2^64) ∈ ab/c + [-1, 1]`
46+
`ab / c (ε₁/2^64 + ε₂/2^64 + ε₁ε₂/2^128) ∈ [-1, 1]`
47+
`ab / c |ε₁/2^64 + ε₂/2^64 + ε₁ε₂/2^128| ≤ 1`
48+
⇐ `ab / c (2/2^64 + 1/2^128) ≤ 1`
49+
`ab / c ≤ 1 / (1/2^63 + 1/2^128)`
50+
⇐ `(c-1)(c-1)/c ≤ 1 / (1/2^63 + 1/2^128)`
51+
`c - 2 + 1/c ≤ 1 / (1/2^63 + 1/2^128)`
52+
`c - 1 ≤ 1 / (1/2^63 + 1/2^128)`
53+
`c - 1 ≤ 2^63 * (1 - 2^-65 / (1 + 2^-65))`
54+
⇐ `c - 1 ≤ 2^63 * (1 - 2^-65)`
55+
`c - 1 ≤ 2^63 - 2^-2`
56+
which holds for `c ≤ 2^63`.
57+
58+
59+
Given the above, the algorithm works if we treat `ret` as an arbitrary-precision integer.
60+
However, it is not. The computation of `ret`, apart from the floating-point steps, will be
61+
performed with 64-bit unsigned integers, and then converted to a 64-bit signed integer.
62+
This corresponds to performing arithmetic modulo 2^64, and then taking representatives
63+
in `[-2^63, 2^63)`. If we can show that the range we reduce the result into fits in
64+
`[-2^63, 2^63)`, then the modular arithmetic does not destroy anything, and the algorithm
65+
works as expected.
66+
67+
If `c ≤ 2^62`, `[-c, 2c)` does fit entirely within that range, and so the algorithm works.
68+
The interesting case happens when `2^62 < c < 2^63`. The range `[-c, c)` still fits within
69+
i64; however, `[c, 2*c)` does not -- `2*c` overflows. We shall prove, however, that
70+
the approximation in fact never returns a value greater than or equal to 2^63, and so
71+
the algorithm works even for `c` in this larger range. (It would actually work for `c = 2^63`
72+
as well, if it weren't for the fact that "(i64)c" then overflows.)
73+
74+
75+
What we will prove is that `ab - floor(round(round(ab) / c))*c < 2^63` always holds.
76+
Let's assume the opposite for sake of contradiction.
77+
78+
If `round(ab) / c ≥ ab / c` we're fine:
79+
`floor(round(round(ab) / c)) ≥ floor(round(ab / c)) ≥ floor(ab / c)`, since floor and round
80+
are monotonic, and since `ab / c < c^2/c ≤ 2^63` the integer `floor(ab / c)` is
81+
representable, so `round` can't skip past it. Thus,
82+
`ab - floor(round(round(ab) / c))*c ≤ ab - floor(ab / c)*c = ab % c < c ≤ 2^63`, contradiction.
83+
84+
Otherwise, `round(ab) / c < ab / c < c^2/c ≤ 2^63`.
85+
Let k be such that `2^k ≤ round(ab) / c < 2^(k+1)`; then `k ≤ 62`.
86+
87+
If `round(round(ab) / c)` rounds upward to an integer, then
88+
`ab - floor(round(round(ab) / c))*c =
89+
ab - round(round(ab) / c)*c ≤
90+
ab - round(ab) / c * c =
91+
ab - round(ab) ≤
92+
ab * 2^-64 < c^2/2^64 < 2^63`, contradiction.
93+
94+
Otherwise, since `round(round(ab) / c)` only has 64 bits of precision and `round`
95+
rounds to nearest, breaking ties towards even, `frac(round(ab) / c) < 1 - 2^(k-64)`.
96+
(This is the magic part.)
97+
98+
Division can't move us below an integer, so `floor(round(round(ab) / c)) = floor(round(ab) / c)`.
99+
Thus we can rewrite our inequality as
100+
101+
`ab - floor(round(ab) / c)*c ≥ 2^63`
102+
⇔ `ab - ((round(ab) / c) - frac(round(ab) / c))*c ≥ 2^63`
103+
⇔ `ab - round(ab) ≥ 2^63 - frac(round(ab) / c)*c`
104+
105+
Since `round(ab) < c * 2^(k+1) ≤ 2^k * 2^64`, `round(ab)` rounds to an ulp of at most `2^k`,
106+
and as such `ab - round(ab) ≤ 2^(k-1)`. Combining with the above, we derive
107+
108+
`2^(k-1) ≥ ab - round(ab) ≥ 2^63 - frac(round(ab) / c)*c > 2^63 - (1 - 2^(k-64))*c`
109+
which rearranges to
110+
`2c + (1 - c/2^63)*2^k > 2^64`
111+
112+
Noting that `1 - c/2^63` is always non-negative, and using `k < 63`, this implies
113+
`2c + (1 - c/2^63)*2^63 > 2^64`
114+
or `c > 2^63`, which is a contradiction.

fuzz-tests/number-theory/MillerRabin.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -67,5 +67,5 @@ int main() {
6767
assert(false);
6868
}
6969
}
70-
cout<<"Tests passed"<<endl;
70+
cout << "Tests passed!" << endl;
7171
}
Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
#include <bits/stdc++.h>
2+
using namespace std;
3+
4+
#define rep(i, a, b) for (int i = a; i < int(b); ++i)
5+
#define trav(a, v) for (auto &a : v)
6+
#define all(x) x.begin(), x.end()
7+
#define sz(x) (int)(x).size()
8+
9+
typedef long long ll;
10+
typedef pair<int, int> pii;
11+
typedef vector<int> vi;
12+
13+
#include "../../content/number-theory/ModMulLL.h"
14+
15+
ull int128_mod_mul(ull a, ull b, ull m) { return (ull)((__uint128_t)a * b % m); }
16+
mt19937_64 rng(1);
17+
uniform_int_distribution<ull> uni(1, (1ull << 63) - 1);
18+
const int ITERS = 1e7;
19+
int main() {
20+
for (int i = 0; i < ITERS; i++) {
21+
ull c = uni(rng), a = uni(rng) % c, b = uni(rng) % c;
22+
ull l = int128_mod_mul(a, b, c);
23+
ull r = mod_mul(a, b, c);
24+
if (l != r) {
25+
cout << a << ' ' << b << ' ' << c << endl;
26+
cout << l << ' ' << r << endl;
27+
}
28+
assert(l == r);
29+
}
30+
cout << "Tests passed!" << endl;
31+
}

0 commit comments

Comments
 (0)