competitive/data_structure/partially_retroactive_priority_queue.rs

1use super::{Associative, Bounded, Magma, MaxOperation, MinOperation, SegmentTree, Unital};
2use std::{cmp::Reverse, iter::Flatten, slice::Iter};
3
4#[derive(Debug, Clone)]
5/// max-heap
6pub struct PartiallyRetroactivePriorityQueue<T>
7where
8    T: Clone + Ord + Bounded,
9{
10    n: usize,
11    in_edges: SegmentTree<MaxOperation<(T, Reverse<usize>)>>,
12    out_edges: SegmentTree<MinOperation<(T, usize)>>,
13    flow: SegmentTree<SumMinimum>,
14}
15
/// Monoid element summarising a contiguous run of `i32` flow differences.
#[derive(Debug, Default, Clone, Copy)]
struct SumMinimum {
    /// Sum of every value in the run.
    sum: i32,
    /// Smallest prefix sum, counting the empty prefix (so it is never positive).
    prefix_min: i32,
    /// Despite its name, the *largest* suffix sum, counting the empty suffix (so it is never
    /// negative).
    suffix_min: i32,
}

impl SumMinimum {
    /// Summary of the run made of the single value `x`.
    fn singleton(x: i32) -> Self {
        // Either `x` lowers the best prefix or raises the best suffix; the empty
        // prefix/suffix (0) wins the other extreme.
        let (prefix_min, suffix_min) = if x < 0 { (x, 0) } else { (0, x) };
        Self {
            sum: x,
            prefix_min,
            suffix_min,
        }
    }
}
32
33impl Magma for SumMinimum {
34    type T = Self;
35    fn operate(x: &Self::T, y: &Self::T) -> Self::T {
36        Self {
37            sum: x.sum + y.sum,
38            prefix_min: x.prefix_min.min(x.sum + y.prefix_min),
39            suffix_min: y.suffix_min.max(x.suffix_min + y.sum),
40        }
41    }
42}
43
44impl Associative for SumMinimum {}
45
46impl Unital for SumMinimum {
47    fn unit() -> Self::T {
48        Self::default()
49    }
50}
51
/// How the current (end-of-timeline) contents of the queue changed after one slot update.
///
/// Index 0 of each array holds the effect of clearing the slot's previous operation and
/// index 1 the effect of the newly written operation, so at most two elements move each way.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Changed<T> {
    /// Elements that entered the current queue.
    pub inserted: [Option<T>; 2],
    /// Elements that left the current queue.
    pub removed: [Option<T>; 2],
}

impl<T> Changed<T> {
    /// Iterates over the elements that entered the current queue, skipping empty entries.
    pub fn inserted(&self) -> Flatten<Iter<'_, Option<T>>> {
        self.inserted.iter().flatten()
    }
    /// Iterates over the elements that left the current queue, skipping empty entries.
    pub fn removed(&self) -> Flatten<Iter<'_, Option<T>>> {
        self.removed.iter().flatten()
    }
}

// Written by hand: `#[derive(Default)]` would add an unnecessary `T: Default` bound.
impl<T> Default for Changed<T> {
    /// An empty change set: nothing inserted, nothing removed.
    fn default() -> Self {
        Self {
            inserted: [None, None],
            removed: [None, None],
        }
    }
}
75
76impl<T> PartiallyRetroactivePriorityQueue<T>
77where
78    T: Clone + Ord + Bounded,
79{
80    pub fn new(n: usize) -> Self {
81        let in_edges = SegmentTree::new(n);
82        let out_edges = SegmentTree::new(n);
83        let flow = SegmentTree::new(n);
84        Self {
85            n,
86            in_edges,
87            out_edges,
88            flow,
89        }
90    }
91    fn update_flow(&mut self, l: usize, r: usize, x: i32) {
92        let s = self.flow.get(l).sum + x;
93        self.flow.set(l, SumMinimum::singleton(s));
94        let s = self.flow.get(r).sum - x;
95        self.flow.set(r, SumMinimum::singleton(s));
96    }
97    pub unsafe fn set_push_unchecked(&mut self, i: usize, x: T) -> Option<T> {
98        assert!(i < self.n);
99        let p = self.flow.fold(i..self.n).sum;
100        let j = if p < 0 {
101            self.flow
102                .rposition_acc(0..i, |s| s.suffix_min + p >= 0)
103                .unwrap_or(0)
104        } else {
105            i
106        };
107        let (min, k) = self.out_edges.fold(j..self.n);
108        if x <= min {
109            self.in_edges.set(i, (x.clone(), Reverse(i)));
110            return Some(x);
111        }
112        if i <= k {
113            self.update_flow(i, k, 1);
114        } else {
115            self.update_flow(k, i, -1);
116        }
117        self.out_edges.set(i, (x.clone(), i));
118        self.out_edges.clear(k);
119        self.in_edges.set(k, (min.clone(), Reverse(k)));
120        if min == T::minimum() { None } else { Some(min) }
121    }
122    pub unsafe fn unset_pop_unchecked(&mut self, i: usize) -> Option<T> {
123        assert!(i < self.n);
124        if self.out_edges.get(i) == (T::minimum(), i) {
125            self.out_edges.clear(i);
126            return None;
127        }
128        let p = self.flow.fold(i..self.n).sum;
129        let j = if p < 0 {
130            self.flow
131                .rposition_acc(0..i, |s| s.suffix_min + p >= 0)
132                .unwrap_or(0)
133        } else {
134            i
135        };
136        let (min, k) = self.out_edges.fold(j..self.n);
137        assert_ne!(k, !0);
138        if i <= k {
139            self.update_flow(i, k, 1);
140        } else {
141            self.update_flow(k, i, -1);
142        }
143        self.in_edges.clear(i);
144        self.out_edges.clear(k);
145        self.in_edges.set(k, (min.clone(), Reverse(k)));
146        if min == T::minimum() { None } else { Some(min) }
147    }
148    pub unsafe fn set_pop_unchecked(&mut self, i: usize) -> Option<T> {
149        assert!(i < self.n);
150        let p = self.flow.fold(0..=i).sum;
151        let j = if p > 0 {
152            self.flow
153                .position_acc(i + 1..self.n - 1, |s| p + s.prefix_min <= 0)
154                .unwrap_or(self.n - 1)
155        } else {
156            i
157        };
158        let (max, Reverse(k)) = self.in_edges.fold(0..=j);
159        if max == T::minimum() {
160            self.out_edges.set(i, (T::minimum(), i));
161            return None;
162        }
163        if k <= i {
164            self.update_flow(k, i, 1);
165        } else {
166            self.update_flow(i, k, -1);
167        }
168        self.in_edges.set(i, (T::minimum(), Reverse(i)));
169        self.in_edges.clear(k);
170        self.out_edges.set(k, (max.clone(), k));
171        Some(max)
172    }
173    pub unsafe fn unset_push_unchecked(&mut self, i: usize) -> Option<T> {
174        assert!(i < self.n);
175        let (max, Reverse(k)) = self.in_edges.get(i);
176        if k == i && max != T::minimum() {
177            self.in_edges.clear(i);
178            return Some(max);
179        }
180        let p = self.flow.fold(0..=i).sum;
181        let j = if p > 0 {
182            self.flow
183                .position_acc(i + 1..self.n - 1, |s| p + s.prefix_min <= 0)
184                .unwrap_or(self.n - 1)
185        } else {
186            i
187        };
188        let (max, Reverse(k)) = self.in_edges.fold(0..=j);
189        if k <= i {
190            self.update_flow(k, i, 1);
191        } else {
192            self.update_flow(i, k, -1);
193        }
194        self.out_edges.clear(i);
195        self.in_edges.clear(k);
196        self.out_edges.set(k, (max.clone(), k));
197        if max == T::minimum() { None } else { Some(max) }
198    }
199    pub fn set_no_op(&mut self, i: usize) -> Changed<T> {
200        assert!(i < self.n);
201        let mut changed = Changed::default();
202        let (max, Reverse(k)) = self.in_edges.get(i);
203        let (min, kk) = self.out_edges.get(i);
204        if k != i && kk != i {
205            return changed;
206        }
207        if i == k && max == T::minimum() || i == kk && min == T::minimum() {
208            changed.inserted[0] = unsafe { self.unset_pop_unchecked(i) };
209        } else {
210            changed.removed[0] = unsafe { self.unset_push_unchecked(i) };
211        }
212        changed
213    }
214    pub fn set_push(&mut self, i: usize, x: T) -> Changed<T> {
215        assert!(i < self.n);
216        let mut changed = self.set_no_op(i);
217        changed.inserted[1] = unsafe { self.set_push_unchecked(i, x) };
218        changed
219    }
220    pub fn set_pop(&mut self, i: usize) -> Changed<T> {
221        assert!(i < self.n);
222        let mut changed = self.set_no_op(i);
223        changed.removed[1] = unsafe { self.set_pop_unchecked(i) };
224        changed
225    }
226    #[allow(dead_code)]
227    fn check(&self) -> Vec<T> {
228        let mut pq = vec![];
229        for i in 0..self.n {
230            let (max, Reverse(k)) = self.in_edges.get(i);
231            let (min, kk) = self.out_edges.get(i);
232            if k == i {
233                if max == T::minimum() {
234                    // pop 1 element
235                } else {
236                    // push (not pop)
237                    pq.push(max);
238                }
239            } else if kk == i {
240                if min == T::minimum() {
241                    // pop 0 element
242                } else {
243                    // push (poped)
244                }
245            } else {
246                // nop
247            }
248        }
249        pq.sort_unstable();
250        pq
251    }
252}
253
#[cfg(test)]
mod tests {
    use super::*;
    use crate::tools::Xorshift;
    use std::collections::BinaryHeap;

    #[derive(Debug, Clone, Copy)]
    enum Query {
        Push(i64),
        Pop,
    }

    /// Replays the timeline with a plain max-heap and returns what remains, ascending.
    fn brute_force(timeline: &[Option<Query>]) -> Vec<i64> {
        let mut heap = BinaryHeap::new();
        for query in timeline.iter().flatten() {
            match *query {
                Query::Push(x) => heap.push(x),
                Query::Pop => {
                    heap.pop();
                }
            }
        }
        heap.into_sorted_vec()
    }

    #[test]
    fn test_partially_retroactive_priority_queue() {
        let mut rng = Xorshift::default();
        for t in 0..100 {
            let n = rng.random(1..=100);
            let mut timeline = vec![None; n];
            let mut prpq = PartiallyRetroactivePriorityQueue::<i64>::new(n);
            // Current queue, maintained solely from the reported `Changed` values.
            let mut current: Vec<i64> = Vec::new();
            for _ in 0..1000 {
                let slot = rng.random(0..n);
                // Push probability grows with `t`: all pops at t = 0, all pushes at t = 99.
                let query = if rng.gen_bool(t as f64 / 99.) {
                    Query::Push(rng.random(-3..=3))
                } else {
                    Query::Pop
                };
                timeline[slot] = Some(query);
                let changed = match query {
                    Query::Push(x) => prpq.set_push(slot, x),
                    Query::Pop => prpq.set_pop(slot),
                };
                current.extend(changed.inserted().copied());
                for gone in changed.removed() {
                    if let Some(pos) = current.iter().position(|y| y == gone) {
                        current.remove(pos);
                    }
                }
                current.sort_unstable();
                let expected = brute_force(&timeline);
                let encoded = prpq.check();
                assert_eq!(current, expected);
                assert_eq!(expected, encoded);
            }
        }
    }
}