Skip to main content

Number Theoretic Transform (NTT)

INCLUDE
namespace Math {
/** @brief: Number Theoretic Transform over the modular-integer type `Mint`.
 *
 * `Mint` is expected to expose a static constant `MOD`, a static constexpr
 * `primitiveRoot()`, construction from an integer, `pow(e)`, `inv()` and the
 * operators `+`, `-`, `*` and `*=` (these are the only members used below).
 * NOTE(review): `primitiveRoot()` is presumably a primitive root modulo `MOD`;
 * confirm against the `Mint` implementation.
 */
template<typename Mint> class NTT {
    static constexpr auto MOD = Mint::MOD;
    static constexpr auto ROOT = Mint::primitiveRoot();

public:
    /** @brief: In-place transform of `a`; assumes a.size() is a power of two.
     *
     * @param a        Values to transform; overwritten with the result.
     *                 An empty vector is left untouched.
     * @param inverse  When true, the inverted root is used and every element
     *                 is finally multiplied by the modular inverse of the size.
     *
     * Asserts that the size is a power of two and that it divides `MOD - 1`
     * (otherwise `ROOT.pow((MOD - 1) / n)` would not have order `n`).
     */
    static void ntt(std::vector<Mint>& a, bool inverse) {
        if (a.empty()) return;
        const auto n = static_cast<std::int32_t>(a.size());
        assert((n ^ (n & -n)) == 0);
        assert((MOD - 1) % n == 0);

        // Bit-reversal permutation. `rev` tracks the bit-reversed value of `j`:
        // incrementing `j` corresponds to a "reversed increment" of `rev`,
        // i.e. flipping bits from the top down until a 0 bit turns into 1.
        std::int32_t rev = 0;
        for (std::int32_t j = 1; j < n - 1; ++j) {
            std::int32_t bit = n >> 1;
            while (bit > (rev ^= bit)) bit >>= 1;
            if (j < rev) std::swap(a[rev], a[j]);
        }

        // h is the n-th root of unity used by the butterflies (inverted for
        // the inverse transform).
        auto h = ROOT.pow((MOD - 1) / n);
        if (inverse) h = h.inv();

        // Butterfly stages: `half` is the half-length of the blocks merged in
        // this stage, `step` is the root of unity of order 2 * half.
        for (std::int32_t half = 1; half < n; half <<= 1) {
            const auto len = half << 1;
            const auto step = h.pow(n / len);
            Mint w = 1;
            // The twiddle factor depends only on the offset inside a block, so
            // iterate offsets in the outer loop and compute each power once.
            for (std::int32_t off = 0; off < half; ++off) {
                for (std::int32_t s = off; s < n; s += len) {
                    const auto u = a[s];
                    const auto d = a[s + half] * w;
                    a[s] = u + d;
                    a[s + half] = u - d;
                }
                w *= step;
            }
        }

        if (inverse) {
            const auto n_inv = Mint(n).inv();
            for (auto& x : a) x *= n_inv;
        }
    }

    /** @brief: Replaces `a` with the linear convolution of `a` and `b`.
     *
     * @param a  First operand; on return holds a.size() + b.size() - 1
     *           coefficients. If either operand is empty the result is empty.
     * @param b  Second operand; taken by value because it is padded and
     *           transformed internally.
     */
    static void conv1D(std::vector<Mint>& a, std::vector<Mint> b) {
        // Guard first: with both operands empty, size() + size() - 1 wraps
        // around in unsigned arithmetic and yields a bogus transform length.
        if (a.empty() || b.empty()) {
            a.clear();
            return;
        }

        const auto conv_size = static_cast<std::int32_t>(a.size() + b.size() - 1);

        // Smallest power of two that can hold the full product.
        std::int32_t n = 1;
        while (n < conv_size) n <<= 1;

        a.resize(n);
        ntt(a, false);
        b.resize(n);
        ntt(b, false);
        for (std::int32_t i = 0; i < n; ++i) a[i] *= b[i];
        ntt(a, true);
        a.resize(conv_size);
    }
};
}

int main() {
vector<Mint> a = {1, 4, 2, 3, 4};
vector<Mint> b = {3, 2, 0, 2, 1, 2};
Math::NTT<Mint>::conv1D(a, b);
for (auto &x : a) cout << x << ' ';
cout << '\n';
}