Number Theoretic Transform (NTT)
:::note INCLUDE
[ModularArithmetic::ModInt](/cp/math/modint)
:::
namespace Math {

/**
 * @brief Number Theoretic Transform and convolution over a modular integer type.
 *
 * The code below relies on `Mint` providing:
 *   - `Mint::MOD`, the modulus, as an integral constant;
 *   - `Mint::primitiveRoot()`, callable in a constant expression
 *     (presumably a primitive root modulo `MOD` -- confirm against ModInt);
 *   - `pow(exponent)` and `inv()` member functions;
 *   - operators `+`, `-`, `*`, `*=` and implicit construction from an integer.
 */
template<typename Mint>
class NTT {
    static constexpr auto MOD = Mint::MOD;
    static constexpr auto ROOT = Mint::primitiveRoot();

public:
    /**
     * @brief Transforms `a` in place (forward or inverse).
     *
     * @param a        Sequence to transform. Its size must be a power of two
     *                 that divides `MOD - 1`; both conditions are asserted.
     *                 An empty vector is left untouched.
     * @param inverse  `false` for the forward transform; `true` for the inverse,
     *                 which uses the inverse root and scales every element by `1 / n`.
     */
    static void ntt(std::vector<Mint>& a, bool inverse) {
        if (a.empty()) return;

        const auto n = static_cast<std::int32_t>(a.size());
        // n with its lowest set bit removed is zero exactly when n is a power of two.
        assert((n ^ (n & -n)) == 0);
        // An n-th root of unity is obtained as ROOT^((MOD - 1) / n), which needs n | MOD - 1.
        assert((MOD - 1) % n == 0);

        // Bit-reversal permutation: `rev` follows `j` as a counter incremented from the top bit.
        std::int32_t rev = 0;
        for (std::int32_t j = 1; j < n - 1; ++j) {
            // Flip bits from the most significant one downwards until a 0 -> 1 flip occurs.
            for (std::int32_t bit = n >> 1; bit > (rev ^= bit); bit >>= 1) {}
            // Swap each pair only once. std::swap is qualified so the template does not
            // depend on a using-directive being visible at its point of definition.
            if (j < rev) std::swap(a[rev], a[j]);
        }

        auto root = ROOT.pow((MOD - 1) / n);
        if (inverse) root = root.inv();

        // Iterative radix-2 butterflies; `half` is the half-length of the blocks being merged.
        for (std::int32_t half = 1; half < n; half <<= 1) {
            const auto len = half << 1;
            const auto step = root.pow(n / len);  // root of unity of order `len`
            Mint w = 1;
            for (std::int32_t offset = 0; offset < half; ++offset) {
                for (std::int32_t s = offset; s < n; s += len) {
                    const auto u = a[s];
                    const auto d = a[s + half] * w;
                    a[s] = u + d;
                    a[s + half] = u - d;
                }
                w *= step;
            }
        }

        if (inverse) {
            const auto n_inv = Mint(n).inv();
            for (auto& x : a) x *= n_inv;
        }
    }

    /**
     * @brief Replaces `a` with the linear convolution of `a` and `b`.
     *
     * @param a  First operand; overwritten with the result, which has
     *           `a.size() + b.size() - 1` elements when both inputs are non-empty.
     * @param b  Second operand, taken by value because it is padded and transformed.
     *
     * If either operand is empty the result is empty. Without this guard the
     * size computation `a.size() + b.size() - 1` underflows for two empty inputs.
     */
    static void conv1D(std::vector<Mint>& a, std::vector<Mint> b) {
        if (a.empty() || b.empty()) {
            a.clear();
            return;
        }

        const auto conv_size = static_cast<std::int32_t>(a.size() + b.size() - 1);

        // Smallest power of two that can hold the full product.
        std::int32_t n = 1;
        while (n < conv_size) n <<= 1;

        a.resize(n);
        b.resize(n);
        ntt(a, false);
        ntt(b, false);

        // Pointwise product in the transformed domain.
        for (std::int32_t i = 0; i < n; ++i) a[i] *= b[i];

        ntt(a, true);
        a.resize(conv_size);
    }
};

}  // namespace Math
int main() { vector<Mint> a = {1, 4, 2, 3, 4}; vector<Mint> b = {3, 2, 0, 2, 1, 2}; Math::NTT<Mint>::conv1D(a, b); for (auto &x : a) cout << x << ' '; cout << '\n';}