Skip to main content

mod_mul

Function mod_mul 

Source
const fn mod_mul(x: u32, y: u32, p: u32, r: u32) -> u32
Examples found in repository?
crates/competitive/src/math/number_theoretic_transform.rs (line 42)
39const fn mod_pow(mut x: u32, mut y: u32, p: u32, r: u32, mut z: u32) -> u32 {
40    while y > 0 {
41        if y & 1 == 1 {
42            z = mod_mul(z, x, p, r);
43        }
44        x = mod_mul(x, x, p, r);
45        y >>= 1;
46    }
47    z
48}
49
50pub trait Montgomery32NttModulus: Sized + MontgomeryReduction32 {
51    const PRIMITIVE_ROOT: u32 = {
52        let mut g = 3u32;
53        loop {
54            let mut ok = true;
55            let mut d = 1u32;
56            while d * d < Self::MOD {
57                if (Self::MOD - 1) % d == 0 {
58                    let ds = [d, (Self::MOD - 1) / d];
59                    let mut i = 0;
60                    while i < 2 {
61                        ok &= ds[i] == Self::MOD - 1
62                            || mod_pow(
63                                reduce(g as u64 * Self::N2 as u64, Self::MOD, Self::R),
64                                ds[i],
65                                Self::MOD,
66                                Self::R,
67                                Self::N1,
68                            ) != Self::N1;
69                        i += 1;
70                    }
71                }
72                d += 1;
73            }
74            if ok {
75                break;
76            }
77            g += 2;
78        }
79        g
80    };
81    const RANK: u32 = (Self::MOD - 1).trailing_zeros();
82    const INFO: NttInfo = NttInfo::new::<Self>();
83}
84
/// Precomputed root-of-unity tables for one NTT modulus, filled in by `NttInfo::new`.
///
/// Each table has 32 slots, which covers every possible rank of a 32-bit modulus
/// (`trailing_zeros(MOD - 1) <= 31`). Slots beyond what the modulus' rank allows are `0`.
/// Values are stored in the representation used by the modular helpers (presumably
/// Montgomery form — confirm against `mod_mul`/`reduce`).
///
/// The type is plain data (arrays of `u32`), so full equality and cloning are derived.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct NttInfo {
    /// `root[i]`: a `2^i`-th root of unity, for `i <= rank`.
    root: [u32; 32],
    /// `inv_root[i]`: the multiplicative inverse of `root[i]`.
    inv_root: [u32; 32],
    /// `rate2[i]`: `root[i + 2]` times the inverses `inv_root[2..=i + 1]` (radix-2 step factor).
    rate2: [u32; 32],
    /// `inv_rate2[i]`: the inverse-direction counterpart of `rate2[i]`.
    inv_rate2: [u32; 32],
    /// `rate3[i]`: `root[i + 3]` times the inverses `inv_root[3..=i + 2]` (radix-4 step factor).
    rate3: [u32; 32],
    /// `inv_rate3[i]`: the inverse-direction counterpart of `rate3[i]`.
    inv_rate3: [u32; 32],
}
94impl NttInfo {
95    const fn new<M>() -> Self
96    where
97        M: Montgomery32NttModulus,
98    {
99        let mut root = [0; 32];
100        let mut inv_root = [0; 32];
101        let mut rate2 = [0; 32];
102        let mut inv_rate2 = [0; 32];
103        let mut rate3 = [0; 32];
104        let mut inv_rate3 = [0; 32];
105        let rank = M::RANK as usize;
106
107        let g = reduce(M::PRIMITIVE_ROOT as u64 * M::N2 as u64, M::MOD, M::R);
108        root[rank] = mod_pow(g, (M::MOD - 1) >> rank, M::MOD, M::R, M::N1);
109        inv_root[rank] = mod_pow(root[rank], M::MOD - 2, M::MOD, M::R, M::N1);
110        let mut i = rank - 1;
111        loop {
112            root[i] = mod_mul(root[i + 1], root[i + 1], M::MOD, M::R);
113            inv_root[i] = mod_mul(inv_root[i + 1], inv_root[i + 1], M::MOD, M::R);
114            if i == 0 {
115                break;
116            }
117            i -= 1;
118        }
119
120        let (mut i, mut prod, mut inv_prod) = (0, M::N1, M::N1);
121        while i < rank - 1 {
122            rate2[i] = mod_mul(root[i + 2], prod, M::MOD, M::R);
123            inv_rate2[i] = mod_mul(inv_root[i + 2], inv_prod, M::MOD, M::R);
124            prod = mod_mul(prod, inv_root[i + 2], M::MOD, M::R);
125            inv_prod = mod_mul(inv_prod, root[i + 2], M::MOD, M::R);
126            i += 1;
127        }
128
129        let (mut i, mut prod, mut inv_prod) = (0, M::N1, M::N1);
130        while i < rank - 2 {
131            rate3[i] = mod_mul(root[i + 3], prod, M::MOD, M::R);
132            inv_rate3[i] = mod_mul(inv_root[i + 3], inv_prod, M::MOD, M::R);
133            prod = mod_mul(prod, inv_root[i + 3], M::MOD, M::R);
134            inv_prod = mod_mul(inv_prod, root[i + 3], M::MOD, M::R);
135            i += 1;
136        }
137
138        NttInfo {
139            root,
140            inv_root,
141            rate2,
142            inv_rate2,
143            rate3,
144            inv_rate3,
145        }
146    }