Skip to main content

Modular Integer Class

namespace Math {
/**
 * @brief Integer modulo the compile-time constant MOD_, always stored as the
 *        canonical residue in [0, MOD_).
 *
 * Addition, subtraction and multiplication work for every positive modulus that
 * fits in int32_t. Division / inv() additionally require the operand to be
 * coprime with the modulus, and primitiveRoot() requires the modulus to be prime.
 *
 * @tparam MOD_ the modulus; must be positive.
 */
template<int32_t MOD_> class Mint {
    static_assert(MOD_ > 0, "MOD must be positive");

private:
    int32_t v;  // canonical residue, invariant: 0 <= v < MOD

public:
    static constexpr int32_t MOD = MOD_;

    /** @brief Constructs the residue 0. */
    constexpr Mint() : v(0) {}
    /** @brief Constructs from any 32-bit value; negative inputs are mapped into [0, MOD). */
    constexpr Mint(int32_t v_) : v(v_ % MOD) { if (v < 0) v += MOD; }
    /** @brief Constructs from any 64-bit value; negative inputs are mapped into [0, MOD). */
    constexpr Mint(int64_t v_) : v(static_cast<int32_t>(v_ % MOD)) { if (v < 0) v += MOD; }

    /** @brief Returns the canonical residue; explicit so that Mint never silently decays to int. */
    explicit operator int() const { return v; }

    /** @brief Writes the canonical residue as a decimal integer. */
    friend std::ostream& operator << (std::ostream& os, const Mint& n) { return os << int(n); }
    /** @brief Reads a 64-bit integer and stores it reduced modulo MOD. */
    friend std::istream& operator >> (std::istream& is, Mint& n) { int64_t v_; is >> v_; n = Mint(v_); return is; }

    friend bool operator == (const Mint& a, const Mint& b) { return a.v == b.v; }
    friend bool operator != (const Mint& a, const Mint& b) { return a.v != b.v; }

    /** @brief Unary plus: returns a copy. */
    Mint operator + () const { return *this; }
    /** @brief Additive inverse: MOD - v, with 0 mapped to 0 so the result stays canonical. */
    Mint operator - () const { Mint n; n.v = v ? MOD - v : 0; return n; }

    Mint& operator ++ () { ++v; if (v == MOD) v = 0; return *this; }
    Mint& operator -- () { if (v == 0) v = MOD; --v; return *this; }
    Mint operator ++ (int) { Mint n(*this); operator++(); return n; }
    Mint operator -- (int) { Mint n(*this); operator--(); return n; }

    Mint& operator += (const Mint& o) {
        // v + o.v can exceed INT32_MAX once MOD > 2^30, so subtract the complement
        // (MOD - o.v) instead; the intermediate then stays within [-MOD, MOD - 2].
        v -= MOD - o.v;
        if (v < 0) v += MOD;
        return *this;
    }
    Mint& operator -= (const Mint& o) { if ((v -= o.v) < 0) v += MOD; return *this; }
    Mint& operator *= (const Mint& o) {
        // Both factors are below 2^31, so the product fits comfortably in 64 bits.
        const int64_t r = static_cast<int64_t>(v) * o.v;
        v = static_cast<int32_t>((r >= MOD) ? r % MOD : r);
        return *this;
    }
    /** @brief Multiplies by o.inv(); only meaningful when o is coprime with MOD. */
    Mint& operator /= (const Mint& o) { return *this *= o.inv(); }

    friend Mint operator + (Mint a, const Mint& b) { return a += b; }
    friend Mint operator - (Mint a, const Mint& b) { return a -= b; }
    friend Mint operator * (Mint a, const Mint& b) { return a *= b; }
    friend Mint operator / (Mint a, const Mint& b) { return a /= b; }

    /**
     * @brief Raises this value to the power b by binary exponentiation.
     *
     * The exponent is deliberately NOT reduced modulo MOD - 1: that shortcut is
     * only valid for a prime modulus and a non-zero base, and it turns
     * 0^(MOD-1) into 1 and divides by zero when MOD == 1.
     *
     * @param b exponent; a negative b yields inv().pow(-b), which is only
     *          meaningful when this value is coprime with MOD.
     * @return this^b modulo MOD, with x^0 == 1 (reduced modulo MOD) for every x.
     */
    Mint pow(int64_t b) const {
        Mint base(*this);
        uint64_t e = static_cast<uint64_t>(b);
        if (b < 0) {
            base = inv();
            e = 0 - e;  // magnitude computed in unsigned arithmetic, safe even for INT64_MIN
        }
        Mint r(1);  // goes through the constructor so that MOD == 1 yields 0
        for (; e > 0; e >>= 1) {
            if (e & 1) r *= base;
            base *= base;
        }
        return r;
    }

    /**
     * @brief Modular multiplicative inverse via the iterative extended Euclidean algorithm.
     *
     * (a1, a2) runs the Euclidean algorithm on (MOD, v) while (v1, v2) tracks the
     * coefficient of v, so on exit v * v1 == gcd(v, MOD) (mod MOD).
     *
     * @return x with (*this) * x == 1 when gcd(v, MOD) == 1. If the value is not
     *         coprime with MOD (including 0) no inverse exists and the returned
     *         value is not an inverse.
     */
    Mint inv() const {
        int32_t a1 = MOD, a2 = v, v1 = 0, v2 = 1;
        while (a2) {
            const int32_t q = a1 / a2;
            a1 -= q * a2; std::swap(a1, a2);
            v1 -= q * v2; std::swap(v1, v2);
        }
        return Mint(v1);
    }

    /**
     * @brief Returns the smallest primitive root g >= 2 modulo MOD (1 when MOD <= 2).
     *
     * Precondition: MOD is prime. g generates the multiplicative group iff
     * g^((MOD-1)/q) != 1 for every prime factor q of MOD - 1, so MOD - 1 is
     * factored by trial division and each candidate is tested against all factors.
     *
     * @return the primitive root, or 0 if no candidate passes (possible only when
     *         the precondition is violated).
     */
    static Mint primitiveRoot() {
        if (MOD == 998244353) return 3;  // common NTT modulus, answer known in advance
        if (MOD <= 2) return Mint(1);

        const int32_t phi = MOD - 1;
        int32_t prime_factors[10];  // a value below 2^31 has at most 9 distinct prime factors
        int factor_count = 0;
        int32_t rest = phi;
        for (int32_t p = 2; static_cast<int64_t>(p) * p <= rest; ++p) {
            if (rest % p != 0) continue;
            prime_factors[factor_count++] = p;
            while (rest % p == 0) rest /= p;
        }
        if (rest > 1) prime_factors[factor_count++] = rest;

        for (int32_t g = 2; g < MOD; ++g) {
            const Mint candidate(g);
            bool is_generator = true;
            for (int i = 0; i < factor_count && is_generator; ++i)
                is_generator = candidate.pow(phi / prime_factors[i]) != 1;
            if (is_generator) return candidate;
        }
        return Mint();
    }
};
}

constexpr int MOD = 998244353;
using Mint = Math::Mint<MOD>;