const fn mod_pow(x: u32, y: u32, p: u32, r: u32, z: u32) -> u32
Examples found in repository:
crates/competitive/src/math/number_theoretic_transform.rs (lines 62-68)
51 const PRIMITIVE_ROOT: u32 = {
52 let mut g = 3u32;
53 loop {
54 let mut ok = true;
55 let mut d = 1u32;
56 while d * d < Self::MOD {
57 if (Self::MOD - 1) % d == 0 {
58 let ds = [d, (Self::MOD - 1) / d];
59 let mut i = 0;
60 while i < 2 {
61 ok &= ds[i] == Self::MOD - 1
62 || mod_pow(
63 reduce(g as u64 * Self::N2 as u64, Self::MOD, Self::R),
64 ds[i],
65 Self::MOD,
66 Self::R,
67 Self::N1,
68 ) != Self::N1;
69 i += 1;
70 }
71 }
72 d += 1;
73 }
74 if ok {
75 break;
76 }
77 g += 2;
78 }
79 g
80 };
81 const RANK: u32 = (Self::MOD - 1).trailing_zeros();
82 const INFO: NttInfo = NttInfo::new::<Self>();
83}
84
/// Precomputed root-of-unity tables for a number-theoretic transform over a
/// single modulus, built by `NttInfo::new`.
///
/// All entries come from the `reduce`/`mod_mul`/`mod_pow` helpers, so they are
/// presumably kept in Montgomery representation (TODO confirm). Slots past the
/// modulus' `RANK` are left as 0.
#[derive(PartialEq, Debug)]
pub struct NttInfo {
    /// `root[i] = g^((MOD - 1) / 2^i)` for `i <= RANK`, where `g` is the
    /// converted primitive root.
    root: [u32; 32],
    /// `inv_root[i]` is the inverse of `root[i]` (assuming `MOD` is prime).
    inv_root: [u32; 32],
    /// `rate2[i] = root[i + 2] * (product of inv_root[2..i + 2])`.
    rate2: [u32; 32],
    /// `inv_rate2[i] = inv_root[i + 2] * (product of root[2..i + 2])`.
    inv_rate2: [u32; 32],
    /// `rate3[i] = root[i + 3] * (product of inv_root[3..i + 3])`.
    rate3: [u32; 32],
    /// `inv_rate3[i] = inv_root[i + 3] * (product of root[3..i + 3])`.
    inv_rate3: [u32; 32],
}
94impl NttInfo {
95 const fn new<M>() -> Self
96 where
97 M: Montgomery32NttModulus,
98 {
99 let mut root = [0; 32];
100 let mut inv_root = [0; 32];
101 let mut rate2 = [0; 32];
102 let mut inv_rate2 = [0; 32];
103 let mut rate3 = [0; 32];
104 let mut inv_rate3 = [0; 32];
105 let rank = M::RANK as usize;
106
107 let g = reduce(M::PRIMITIVE_ROOT as u64 * M::N2 as u64, M::MOD, M::R);
108 root[rank] = mod_pow(g, (M::MOD - 1) >> rank, M::MOD, M::R, M::N1);
109 inv_root[rank] = mod_pow(root[rank], M::MOD - 2, M::MOD, M::R, M::N1);
110 let mut i = rank - 1;
111 loop {
112 root[i] = mod_mul(root[i + 1], root[i + 1], M::MOD, M::R);
113 inv_root[i] = mod_mul(inv_root[i + 1], inv_root[i + 1], M::MOD, M::R);
114 if i == 0 {
115 break;
116 }
117 i -= 1;
118 }
119
120 let (mut i, mut prod, mut inv_prod) = (0, M::N1, M::N1);
121 while i < rank - 1 {
122 rate2[i] = mod_mul(root[i + 2], prod, M::MOD, M::R);
123 inv_rate2[i] = mod_mul(inv_root[i + 2], inv_prod, M::MOD, M::R);
124 prod = mod_mul(prod, inv_root[i + 2], M::MOD, M::R);
125 inv_prod = mod_mul(inv_prod, root[i + 2], M::MOD, M::R);
126 i += 1;
127 }
128
129 let (mut i, mut prod, mut inv_prod) = (0, M::N1, M::N1);
130 while i < rank - 2 {
131 rate3[i] = mod_mul(root[i + 3], prod, M::MOD, M::R);
132 inv_rate3[i] = mod_mul(inv_root[i + 3], inv_prod, M::MOD, M::R);
133 prod = mod_mul(prod, inv_root[i + 3], M::MOD, M::R);
134 inv_prod = mod_mul(inv_prod, root[i + 3], M::MOD, M::R);
135 i += 1;
136 }
137
138 NttInfo {
139 root,
140 inv_root,
141 rate2,
142 inv_rate2,
143 rate3,
144 inv_rate3,
145 }
146 }