// Montgomery Space: modular arithmetic modulo 1e9 + 7 in Montgomery form (R = 2^64).

#include <cstdint>
#include <iostream>
#include <ranges>
#include <tuple>

// Modular multiplication modulo MOD using Montgomery reduction with R = 2^64.
//
// A residue x is stored in "Montgomery form" as x * R mod MOD. Multiplying two
// such values and reducing with REDC yields the Montgomery form of the product
// without ever performing a division by MOD.
//
// NOTE: the public names keep the original "montogomery" spelling so existing
// callers continue to compile.
namespace Montgomery {

// The modulus. It must be odd so that it is invertible modulo 2^64.
static constexpr uint64_t MOD = 1'000'000'007;
static_assert(MOD % 2 == 1, "Montgomery reduction requires an odd modulus");

/// Full 64x64 -> 128 bit unsigned multiplication.
/// @return {high 64 bits, low 64 bits} of x * y.
static constexpr std::tuple<uint64_t, uint64_t> mult(uint64_t x, uint64_t y) {
    constexpr uint64_t LOW32 = 0xFFFFFFFFull;

    // Split both operands into 32-bit halves: x = a * 2^32 + b, y = c * 2^32 + d.
    const uint64_t a = x >> 32;
    const uint64_t b = x & LOW32;
    const uint64_t c = y >> 32;
    const uint64_t d = y & LOW32;

    // Each partial product fits in 64 bits because both factors are < 2^32.
    const uint64_t ac = a * c;
    const uint64_t ad = a * d;
    const uint64_t bc = b * c;
    const uint64_t bd = b * d;

    // x * y = ac * 2^64 + (ad + bc) * 2^32 + bd.
    // The low word is that sum modulo 2^64 (unsigned wrap-around is intended).
    const uint64_t lo = bd + (ad << 32) + (bc << 32);

    // Carry out of bits 32..63: add the three contributions to that column.
    const uint64_t carry = ((bd >> 32) + (ad & LOW32) + (bc & LOW32)) >> 32;
    const uint64_t hi = ac + (ad >> 32) + (bc >> 32) + carry;

    return {hi, lo};
}

// INV = MOD^-1 (mod 2^64), found by Newton iteration. inv = 1 is correct
// modulo 2 (MOD is odd) and every step doubles the number of correct low
// bits: 1 -> 2 -> 4 -> 8 -> 16 -> 32 -> 64, hence six steps.
static constexpr uint64_t INV = [] {
    uint64_t inv = 1;
    for (int step = 0; step < 6; ++step) inv *= 2 - MOD * inv;
    return inv;
}();
static_assert(MOD * INV == 1, "INV must be the inverse of MOD modulo 2^64");

/// Montgomery reduction (REDC) of the 128-bit value T = hi * 2^64 + lo.
/// @pre  hi < MOD (i.e. T < MOD * 2^64); otherwise the result is not reduced.
/// @return T * 2^-64 mod MOD, in the range [0, MOD).
static constexpr uint64_t reduce(uint64_t hi, uint64_t lo) {
    // q * MOD has the same low word as T, so (T - q * MOD) / 2^64 is exactly
    // hi - high(q * MOD); only the high words are needed.
    const uint64_t q_hi = std::get<0>(mult(lo * INV, MOD));

    // The difference lies in (-MOD, MOD); fold negatives back into [0, MOD).
    // When both words are equal the result is 0, not MOD.
    return hi >= q_hi ? hi - q_hi : hi + MOD - q_hi;
}

/// Multiply two values and apply Montgomery reduction: x * y * 2^-64 mod MOD.
static constexpr uint64_t mult_reduce(uint64_t x, uint64_t y) {
    const auto product = mult(x, y);
    return reduce(std::get<0>(product), std::get<1>(product));
}

// R_SQ = R^2 mod MOD with R = 2^64; multiplying by it (with reduction)
// converts a plain residue into Montgomery form.
static constexpr uint64_t R_SQ = [] {
    // (0 - MOD) is 2^64 - MOD in unsigned arithmetic, so this is R mod MOD.
    uint64_t r = (0 - MOD) % MOD;

    // Four modular doublings: r = 2^4 * R mod MOD.
    for (int step = 0; step < 4; ++step) {
        r <<= 1;
        if (r >= MOD) r -= MOD;
    }

    // Each Montgomery squaring maps 2^k * R to 2^(2k) * R, so four of them
    // turn 2^4 * R into 2^64 * R = R^2 (mod MOD).
    for (int step = 0; step < 4; ++step) r = mult_reduce(r, r);
    return r;
}();

/// Convert x to Montgomery form (x * R mod MOD).
/// Any 64-bit x is accepted because R_SQ < MOD keeps the product in range.
uint64_t to_montogomery(uint64_t x) {
    return mult_reduce(x, R_SQ);
}

/// Convert a value out of Montgomery form (x * R^-1 mod MOD).
/// The result is always in [0, MOD); in particular 0 maps to 0.
uint64_t from_montogomery(uint64_t x) {
    return reduce(0, x);
}

/// Multiply two values given in Montgomery form.
/// @pre  x * y < MOD * 2^64 (always true when both operands are < MOD).
/// @return the Montgomery form of the product, in [0, MOD).
uint64_t mult_montogomery(uint64_t x, uint64_t y) {
    return mult_reduce(x, y);
}

}  // namespace Montgomery

// Demo: multiply 31 by 17 modulo MOD via Montgomery form and print the result.
int main() {
    // Move both operands into Montgomery form.
    const auto lhs = Montgomery::to_montogomery(31);
    const auto rhs = Montgomery::to_montogomery(17);

    // Multiply in Montgomery form, then convert the product back.
    const auto product =
        Montgomery::from_montogomery(Montgomery::mult_montogomery(lhs, rhs));

    std::cout << product << std::endl;
}