1+ #pragma once
2+
3+ #include " modint/modint.hpp"
4+ #include " fft/ntt.hpp"
5+ #include " math/util.hpp"
6+
7+ namespace ConvolutionMod2_64 {
8+ using ull = unsigned long long ;
9+ static constexpr ull M1 = 645922817 ;
10+ static constexpr ull M2 = 754974721 ;
11+ static constexpr ull M3 = 880803841 ;
12+ static constexpr ull M4 = 897581057 ;
13+ static constexpr ull M5 = 998244353 ;
14+ static constexpr ull M12M4 = M1 * M2 % M4;
15+ static constexpr ull M12M5 = M1 * M2 % M5;
16+ static constexpr ull M123M5 = M12M5 * M3 % M5;
17+ static constexpr ull M12 = M1 * M2;
18+ static constexpr ull M123 = M12 * M3;
19+ static constexpr ull M1234 = M123 * M4;
20+ static constexpr ull I2 = Math::inv_mod(M1, M2);
21+ static constexpr ull I3 = Math::inv_mod(M1 * M2 % M3, M3);
22+ static constexpr ull I4 = Math::inv_mod(M1 * M2 % M4 * M3 % M4, M4);
23+ static constexpr ull I5 = Math::inv_mod(M1 * M2 % M5 * M3 % M5 * M4 % M5, M5);
24+
25+ using mint1 = ModInt<M1>;
26+ using mint2 = ModInt<M2>;
27+ using mint3 = ModInt<M3>;
28+ using mint4 = ModInt<M4>;
29+ using mint5 = ModInt<M5>;
30+
31+ NTT<mint1> ntt1;
32+ NTT<mint2> ntt2;
33+ NTT<mint3> ntt3;
34+ NTT<mint4> ntt4;
35+ NTT<mint5> ntt5;
36+
37+ template <class mint >
38+ vector<mint> inner_mult (const vector<ull>& a, const vector<ull>& b, NTT<mint>& ntt) {
39+ constexpr unsigned int mod = mint::get_mod ();
40+ vector<mint> a1 (a.size ()), b1 (b.size ());
41+ for (int i = 0 ; i < a.size (); i++) a1[i] = a[i] % mod;
42+ for (int i = 0 ; i < b.size (); i++) b1[i] = b[i] % mod;
43+ mint c = ntt.multiply (a1, b1)[0 ];
44+ return ntt.multiply (a1, b1);
45+ }
46+ template <class mint >
47+ vector<mint> inner_middle_prod (const vector<ull>& a, const vector<ull>& b, NTT<mint>& ntt) {
48+ constexpr unsigned int mod = mint::get_mod ();
49+ vector<mint> a1 (a.size ()), b1 (b.size ());
50+ for (int i = 0 ; i < a.size (); i++) a1[i] = a[i] % mod;
51+ for (int i = 0 ; i < b.size (); i++) b1[i] = b[i] % mod;
52+ return ntt.middle_product (a1, b1);
53+ }
54+ vector<ull> multiply (const vector<ull>& a, const vector<ull>& b) {
55+ if (a.empty () || b.empty ()) return {};
56+ auto c1 = inner_mult (a, b, ntt1);
57+ auto c2 = inner_mult (a, b, ntt2);
58+ auto c3 = inner_mult (a, b, ntt3);
59+ auto c4 = inner_mult (a, b, ntt4);
60+ auto c5 = inner_mult (a, b, ntt5);
61+ vector<ull> c (a.size () + b.size () - 1 , 0 );
62+ for (int i = 0 ; i < c.size (); i++) {
63+ ull y1 = c1[i].val ();
64+ ull y2 = (c2[i].val () + M2 - y1) * I2 % M2;
65+ ull y3 = (c3[i].val () + M3 - (y1 + y2 * M1) % M3) * I3 % M3;
66+ ull y4 = (c4[i].val () + M4 - (y1 + y2 * M1 + y3 * M12M4) % M4) * I4 % M4;
67+ ull y5 = (c5[i].val () + M5 - (y1 + y2 * M1 + y3 * M12M5 + y4 * M123M5) % M5) * I5 % M5;
68+ c[i] = y1 + y2 * M1 + y3 * M12 + y4 * M123 + y5 * M1234;
69+ }
70+ return c;
71+ }
72+ vector<ull> middle_product (const vector<ull>& a, const vector<ull>& b) {
73+ if (b.empty () || a.size () > b.size ()) return {};
74+ auto c1 = inner_middle_prod (a, b, ntt1);
75+ auto c2 = inner_middle_prod (a, b, ntt2);
76+ auto c3 = inner_middle_prod (a, b, ntt3);
77+ auto c4 = inner_middle_prod (a, b, ntt4);
78+ auto c5 = inner_middle_prod (a, b, ntt5);
79+ vector<ull> c (c1.size (), 0 );
80+ for (int i = 0 ; i < c.size (); i++) {
81+ ull y1 = c1[i].val ();
82+ ull y2 = (c2[i].val () + M2 - y1) * I2 % M2;
83+ ull y3 = (c3[i].val () + M3 * 2 - (y1 + y2 * M1 % M3)) * I3 % M3;
84+ ull y4 = (c4[i].val () + M4 * 3 - (y1 + y2 * M1 + y3 * M12M4) % M4) * I4 % M4;
85+ ull y5 = (c5[i].val () + M5 * 4 - (y1 + y2 * M1 + y3 * M12M5 + y4 * M123M5) % M5) * I5 % M5;
86+ c[i] = y1 + y2 * M1 + y3 * M12 + y4 * M123 + y5 * M1234;
87+ }
88+ return c;
89+ }
90+ }; // namespace ConvolutionMod2_64
91+
/**
 * @brief Convolution mod 2^64 (畳み込み mod 2^64)
 */
0 commit comments