competitive/data_structure/submask_range_query.rs

1use super::{BitDpExt, Group, Xorshift};
2use std::fmt::{self, Debug};
3
/// Bit-splitting scheme that answers "sum over all submasks of `m`" with point
/// updates, using a table of `1 << bit_width` entries owned by the caller.
///
/// Each bit position below `bit_width` belongs to exactly one of three classes,
/// recorded as disjoint bit sets in `mask`:
///
/// * `mask[0]`: an update at `m` touches every superset of `m` on these bits,
///   while a get keeps `m`'s bits unchanged.
/// * `mask[1]`: an update keeps `m`'s bits unchanged, while a get enumerates
///   every subset of `m` on these bits.
/// * `mask[2]`: an update touches every subset of `m` on these bits, while a
///   get combines every subset of the complement of `m` on these bits with
///   alternating signs (inclusion–exclusion).
///
/// `get_query` / `update_query` only produce table indices; an offline driver
/// is provided by `SubmaskRangeQueryBuilder::solve`.
#[derive(Clone, Copy, Debug)]
pub struct SubmaskRangeQuery {
    /// Number of low bits covered by the scheme.
    bit_width: u32,
    /// Disjoint bit sets for the three classes described above.
    mask: [u32; 3],
}
9
/// Kind of a query, used to plan the bit split for an offline batch.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum QueryKind {
    /// Read the sum over all submasks of a mask.
    Get,
    /// Add a value at a single mask.
    Update,
}
15
16impl SubmaskRangeQuery {
17    pub fn new(bit_width: u32) -> Self {
18        let mut rng = Xorshift::new();
19        let mut mask = [0; 3];
20        let mut rem: Vec<_> = (0..bit_width).map(|w| w % 3).collect();
21        rng.shuffle(&mut rem);
22        for (k, r) in rem.into_iter().enumerate() {
23            mask[r as usize] |= 1 << k;
24        }
25        Self { bit_width, mask }
26    }
27
28    pub fn new_with_queries(
29        queries: impl IntoIterator<Item = (QueryKind, u32)> + ExactSizeIterator + Clone,
30    ) -> Self {
31        let bit_width = queries
32            .clone()
33            .into_iter()
34            .map(|(_, m)| 32 - m.leading_zeros())
35            .max()
36            .unwrap_or(0);
37        let mut mask = [0; 3];
38        let mut cost = vec![1u32; queries.len()];
39        for k in 0..bit_width {
40            let mut sum = [0u64; 3];
41            for ((kind, m), &c) in queries.clone().into_iter().zip(&cost) {
42                match kind {
43                    QueryKind::Get => {
44                        let b = m >> k & 1 == 0;
45                        sum[if b { 2 } else { 1 }] += c as u64;
46                    }
47                    QueryKind::Update => {
48                        let b = m >> k & 1 == 0;
49                        sum[if b { 0 } else { 2 }] += c as u64;
50                    }
51                }
52            }
53            let t = (0..3).min_by_key(|&i| sum[i]).unwrap();
54            mask[t] |= 1 << k;
55            for ((kind, m), c) in queries.clone().into_iter().zip(&mut cost) {
56                match kind {
57                    QueryKind::Get => {
58                        let b = m >> k & 1 == 0;
59                        if t == if b { 2 } else { 1 } {
60                            *c <<= 1;
61                        }
62                    }
63                    QueryKind::Update => {
64                        let b = m >> k & 1 == 0;
65                        if t == if b { 0 } else { 2 } {
66                            *c <<= 1;
67                        }
68                    }
69                }
70            }
71        }
72        Self { bit_width, mask }
73    }
74
75    pub fn builder<G>() -> SubmaskRangeQueryBuilder<G>
76    where
77        G: Group,
78    {
79        SubmaskRangeQueryBuilder::new()
80    }
81
82    pub fn get_query(&self, m: u32) -> impl Iterator<Item = (u32, bool)> {
83        let fix = m & self.mask[0];
84        let sub = m & self.mask[1];
85        let sup = (!m) & self.mask[2];
86        sup.subsets().flat_map(move |s| {
87            let inv = s.count_ones() & 1 == 1;
88            sub.subsets().map(move |t| (fix | s | t, inv))
89        })
90    }
91
92    pub fn update_query(&self, m: u32) -> impl Iterator<Item = u32> {
93        let fix = m & self.mask[0] | m & self.mask[1];
94        let sup = (!m) & self.mask[0];
95        let sub = m & self.mask[2];
96        sub.subsets()
97            .flat_map(move |s| sup.subsets().map(move |t| fix | s | t))
98    }
99}
100
/// One recorded query of an offline batch, carrying its payload.
#[derive(Clone, Debug)]
enum Query<T> {
    /// Read the sum over all submasks of `m`.
    Get { m: u32 },
    /// Add `x` at mask `m`.
    Update { m: u32, x: T },
}
106
107pub struct SubmaskRangeQueryBuilder<G>
108where
109    G: Group,
110{
111    query: Vec<Query<G::T>>,
112}
113
114impl<G> Debug for SubmaskRangeQueryBuilder<G>
115where
116    G: Group,
117    G::T: Debug,
118{
119    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
120        f.debug_struct("SubmaskRangeQueryBuilder")
121            .field("query", &self.query)
122            .finish()
123    }
124}
125
126impl<G> Default for SubmaskRangeQueryBuilder<G>
127where
128    G: Group,
129{
130    fn default() -> Self {
131        Self {
132            query: Default::default(),
133        }
134    }
135}
136
137impl<G> SubmaskRangeQueryBuilder<G>
138where
139    G: Group,
140{
141    pub fn new() -> Self {
142        Default::default()
143    }
144
145    pub fn push_get(&mut self, m: u32) {
146        self.query.push(Query::Get { m });
147    }
148
149    pub fn push_update(&mut self, m: u32, x: G::T) {
150        self.query.push(Query::Update { m, x });
151    }
152
153    pub fn solve(self) -> Vec<G::T> {
154        let s = SubmaskRangeQuery::new_with_queries(self.query.iter().map(|q| match q {
155            Query::Get { m } => (QueryKind::Get, *m),
156            Query::Update { m, .. } => (QueryKind::Update, *m),
157        }));
158        let out_size = self
159            .query
160            .iter()
161            .filter(|q| matches!(q, Query::Get { .. }))
162            .count();
163        let mut out = Vec::with_capacity(out_size);
164        let mut data = vec![G::unit(); 1 << s.bit_width];
165        for q in self.query {
166            match q {
167                Query::Get { m } => {
168                    let mut f = G::unit();
169                    let mut g = G::unit();
170                    for (k, inv) in s.get_query(m) {
171                        if inv {
172                            G::operate_assign(&mut g, &data[k as usize]);
173                        } else {
174                            G::operate_assign(&mut f, &data[k as usize]);
175                        }
176                    }
177                    out.push(G::rinv_operate(&f, &g));
178                }
179                Query::Update { m, x } => {
180                    for k in s.update_query(m) {
181                        G::operate_assign(&mut data[k as usize], &x);
182                    }
183                }
184            }
185        }
186        out
187    }
188}
189
#[cfg(test)]
mod tests {
    use super::*;
    use crate::algebra::AdditiveOperation;

    /// Reference answer: sum of `a[j]` over every submask `j` of `m`.
    ///
    /// Walks the submasks directly (`j = (j - 1) & m`), which costs
    /// `2^popcount(m)` steps instead of scanning all `a.len()` indices.
    fn naive_submask_sum(a: &[i32], m: u32) -> i32 {
        let mut sum = 0;
        let mut j = m;
        loop {
            sum += a[j as usize];
            if j == 0 {
                break sum;
            }
            j = (j - 1) & m;
        }
    }

    #[test]
    fn test_submask_range_query() {
        const W: u32 = 16;
        let mut rng = Xorshift::default();
        let mut q = SubmaskRangeQuery::builder::<AdditiveOperation<i32>>();
        let mut a = vec![0; 1 << W];
        let mut exp = vec![];
        for _ in 0..2000 {
            if rng.gen_bool(0.5) {
                let i = rng.rand((1 << W) as _) as u32;
                let x = rng.rand(100) as i32;
                q.push_update(i, x);
                a[i as usize] += x;
            } else {
                let i = rng.rand((1 << W) as _) as u32;
                q.push_get(i);
                exp.push(naive_submask_sum(&a, i));
            }
        }
        let ans = q.solve();
        assert_eq!(ans, exp);
    }

    #[test]
    fn test_submask_range_query_edge_cases() {
        // No queries at all: bit_width is 0 and there is nothing to answer.
        let q = SubmaskRangeQuery::builder::<AdditiveOperation<i32>>();
        assert!(q.solve().is_empty());

        // Gets with no prior updates see only the group unit.
        let mut q = SubmaskRangeQuery::builder::<AdditiveOperation<i32>>();
        q.push_get(5);
        q.push_get(0);
        assert_eq!(q.solve(), vec![0, 0]);
    }

    #[test]
    fn test_submask_range_query_online() {
        const W: u32 = 16;
        let mut rng = Xorshift::default();
        let q = SubmaskRangeQuery::new(W);
        let mut a = vec![0; 1 << W];
        let mut b = vec![0; 1 << W];
        for _ in 0..2000 {
            if rng.gen_bool(0.5) {
                let i = rng.rand((1 << W) as _) as u32;
                let x = rng.rand(100) as i32;
                a[i as usize] += x;
                for j in q.update_query(i) {
                    b[j as usize] += x;
                }
            } else {
                let i = rng.rand((1 << W) as _) as u32;
                let x = naive_submask_sum(&a, i);
                let mut y = 0;
                for (j, inv) in q.get_query(i) {
                    if inv {
                        y -= b[j as usize];
                    } else {
                        y += b[j as usize];
                    }
                }
                assert_eq!(x, y);
            }
        }
    }
}
259}