competitive/data_structure/splay_tree/sized_map.rs

1use super::{
2    Allocator, MemoryPool,
3    node::{Node, NodeRange, NodeRef, Root, SplaySeeker, SplaySpec, marker},
4};
5use std::{
6    borrow::Borrow,
7    cmp::Ordering,
8    fmt::{self, Debug},
9    iter::FusedIterator,
10    marker::PhantomData,
11    mem::{ManuallyDrop, replace},
12    ops::{Bound, DerefMut, RangeBounds},
13};
14
15struct SizedSplay<T> {
16    _marker: PhantomData<fn() -> T>,
17}
18impl<T> SplaySpec for SizedSplay<T> {
19    type T = (T, usize);
20    fn has_bottom_up() -> bool {
21        true
22    }
23    fn bottom_up(node: NodeRef<marker::DataMut<'_>, Self>) {
24        let l = node.left().map(|p| p.data().1).unwrap_or_default();
25        let r = node.right().map(|p| p.data().1).unwrap_or_default();
26        node.data_mut().1 = l + r + 1;
27    }
28}
29
30struct SeekByKey<'a, K, V, Q>
31where
32    Q: ?Sized,
33{
34    key: &'a Q,
35    _marker: PhantomData<fn() -> (K, V)>,
36}
37impl<'a, K, V, Q> SeekByKey<'a, K, V, Q>
38where
39    Q: ?Sized,
40{
41    fn new(key: &'a Q) -> Self {
42        Self {
43            key,
44            _marker: PhantomData,
45        }
46    }
47}
48impl<K, V, Q> SplaySeeker for SeekByKey<'_, K, V, Q>
49where
50    K: Borrow<Q>,
51    Q: Ord + ?Sized,
52{
53    type S = SizedSplay<(K, V)>;
54    fn splay_seek(&mut self, node: NodeRef<marker::Immut<'_>, Self::S>) -> Ordering {
55        self.key.cmp((node.data().0).0.borrow())
56    }
57}
58
59struct SeekBySize<K, V> {
60    index: usize,
61    _marker: PhantomData<fn() -> (K, V)>,
62}
63impl<K, V> SeekBySize<K, V> {
64    fn new(index: usize) -> Self {
65        Self {
66            index,
67            _marker: PhantomData,
68        }
69    }
70}
71impl<K, V> SplaySeeker for SeekBySize<K, V> {
72    type S = SizedSplay<(K, V)>;
73    fn splay_seek(&mut self, node: NodeRef<marker::Immut<'_>, Self::S>) -> Ordering {
74        let lsize = node.left().map(|l| l.data().1).unwrap_or_default();
75        let ord = self.index.cmp(&lsize);
76        if matches!(ord, Ordering::Greater) {
77            self.index -= lsize + 1;
78        }
79        ord
80    }
81}
82
/// Ordered map backed by a splay tree whose nodes also track subtree sizes, so
/// entries can be looked up by key or by their position in key order.
///
/// Lookups take `&mut self` because they splay the tree.
pub struct SplayMap<K, V, A = MemoryPool<Node<((K, V), usize)>>>
where
    A: Allocator<Node<((K, V), usize)>>,
{
    // Tree root; each node stores `((key, value), subtree_size)`.
    root: Root<SizedSplay<(K, V)>>,
    // Number of entries, kept in sync by `insert`, `remove` and `remove_at`.
    length: usize,
    // Node allocator. `ManuallyDrop` lets `Drop` return every node to it before
    // the allocator itself is dropped.
    alloc: ManuallyDrop<A>,
}
91
92impl<K, V, A> Debug for SplayMap<K, V, A>
93where
94    K: Debug,
95    V: Debug,
96    A: Allocator<Node<((K, V), usize)>>,
97{
98    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
99        f.debug_struct("SplayMap")
100            .field("root", &self.root)
101            .field("length", &self.length)
102            .finish()
103    }
104}
105
impl<K, V, A> Drop for SplayMap<K, V, A>
where
    A: Allocator<Node<((K, V), usize)>>,
{
    /// Returns every node to the allocator, then drops the allocator itself.
    fn drop(&mut self) {
        // SAFETY (review note): nodes are created from `self.alloc` in `insert`
        // (`NodeRef::from_data(.., self.alloc.deref_mut())`). `take_first`
        // presumably unlinks each node before it is deallocated, so each node is
        // freed once. `self.alloc` is not touched after `ManuallyDrop::drop`.
        unsafe {
            // Detach nodes one at a time until the tree is empty.
            // NOTE(review): the `((K, V), usize)` payload is presumably dropped by
            // `into_inner`/`deallocate`; `test_drop` checks every value is dropped.
            while let Some(node) = self.root.take_first() {
                self.alloc.deallocate(node.into_dying().into_inner());
            }
            // `alloc` is `ManuallyDrop`, so it is released only here, after every
            // node has been handed back to it.
            ManuallyDrop::drop(&mut self.alloc);
        }
    }
}
119
120impl<K, V, A> Default for SplayMap<K, V, A>
121where
122    A: Allocator<Node<((K, V), usize)>> + Default,
123{
124    fn default() -> Self {
125        Self {
126            root: Root::default(),
127            length: 0,
128            alloc: Default::default(),
129        }
130    }
131}
132
133impl<K, V> SplayMap<K, V> {
134    pub fn new() -> Self {
135        Default::default()
136    }
137    pub fn with_capacity(capacity: usize) -> Self {
138        Self {
139            root: Root::default(),
140            length: 0,
141            alloc: ManuallyDrop::new(MemoryPool::with_capacity(capacity)),
142        }
143    }
144}
145impl<K, V, A> SplayMap<K, V, A>
146where
147    A: Allocator<Node<((K, V), usize)>>,
148{
149    pub fn get<Q>(&mut self, key: &Q) -> Option<&V>
150    where
151        K: Borrow<Q>,
152        Q: Ord + ?Sized,
153    {
154        self.get_key_value(key).map(|(_, v)| v)
155    }
156    fn splay_by_key<Q>(&mut self, key: &Q) -> Option<Ordering>
157    where
158        K: Borrow<Q>,
159        Q: Ord + ?Sized,
160    {
161        self.root.splay_by(SeekByKey::new(key))
162    }
163    pub fn get_key_value<Q>(&mut self, key: &Q) -> Option<(&K, &V)>
164    where
165        K: Borrow<Q>,
166        Q: Ord + ?Sized,
167    {
168        if !matches!(self.splay_by_key(key)?, Ordering::Equal) {
169            return None;
170        }
171        self.root.root().map(|node| {
172            let ((k, v), _) = node.data();
173            (k, v)
174        })
175    }
176    fn splay_at(&mut self, index: usize) -> Option<Ordering> {
177        self.root.splay_by(SeekBySize::new(index))
178    }
179    pub fn get_key_value_at(&mut self, index: usize) -> Option<(&K, &V)> {
180        if index >= self.length {
181            return None;
182        }
183        self.splay_at(index);
184        self.root.root().map(|node| {
185            let ((k, v), _) = node.data();
186            (k, v)
187        })
188    }
189    pub fn insert(&mut self, key: K, value: V) -> Option<V>
190    where
191        K: Ord,
192    {
193        let ord = self.splay_by_key(&key);
194        self.length += (ord != Some(Ordering::Equal)) as usize;
195        match ord {
196            Some(Ordering::Equal) => {
197                return Some(replace(
198                    &mut (self.root.root_data_mut().unwrap().data_mut().0).1,
199                    value,
200                ));
201            }
202            Some(Ordering::Less) => unsafe {
203                self.root.insert_left(NodeRef::from_data(
204                    ((key, value), 1),
205                    self.alloc.deref_mut(),
206                ));
207            },
208            _ => unsafe {
209                self.root.insert_right(NodeRef::from_data(
210                    ((key, value), 1),
211                    self.alloc.deref_mut(),
212                ));
213            },
214        }
215        None
216    }
217    pub fn remove<Q>(&mut self, key: &Q) -> Option<V>
218    where
219        K: Borrow<Q>,
220        Q: Ord + ?Sized,
221    {
222        if !matches!(self.splay_by_key(key)?, Ordering::Equal) {
223            return None;
224        }
225        self.length -= 1;
226        let node = self.root.take_root().unwrap().into_dying();
227        unsafe { Some((node.into_data(self.alloc.deref_mut()).0).1) }
228    }
229    pub fn remove_at(&mut self, index: usize) -> Option<(K, V)> {
230        if index >= self.length {
231            return None;
232        }
233        self.splay_at(index);
234        self.length -= 1;
235        let node = self.root.take_root().unwrap().into_dying();
236        unsafe { Some(node.into_data(self.alloc.deref_mut()).0) }
237    }
238    pub fn len(&self) -> usize {
239        self.length
240    }
241    pub fn is_empty(&self) -> bool {
242        self.len() == 0
243    }
244    pub fn iter(&mut self) -> Iter<'_, K, V> {
245        Iter {
246            iter: NodeRange::new(&mut self.root),
247        }
248    }
249    pub fn range<Q, R>(&mut self, range: R) -> Iter<'_, K, V>
250    where
251        K: Borrow<Q>,
252        Q: Ord + ?Sized,
253        R: RangeBounds<Q>,
254    {
255        let start = match range.start_bound() {
256            Bound::Included(key) => Bound::Included(SeekByKey::new(key)),
257            Bound::Excluded(key) => Bound::Excluded(SeekByKey::new(key)),
258            Bound::Unbounded => Bound::Unbounded,
259        };
260        let end = match range.end_bound() {
261            Bound::Included(key) => Bound::Included(SeekByKey::new(key)),
262            Bound::Excluded(key) => Bound::Excluded(SeekByKey::new(key)),
263            Bound::Unbounded => Bound::Unbounded,
264        };
265        Iter {
266            iter: self.root.range(start, end),
267        }
268    }
269    pub fn range_at<R>(&mut self, range: R) -> Iter<'_, K, V>
270    where
271        R: RangeBounds<usize>,
272    {
273        let start = match range.start_bound() {
274            Bound::Included(&index) => Bound::Included(SeekBySize::new(index)),
275            Bound::Excluded(&index) => Bound::Excluded(SeekBySize::new(index)),
276            Bound::Unbounded => Bound::Unbounded,
277        };
278        let end = match range.end_bound() {
279            Bound::Included(&index) => Bound::Included(SeekBySize::new(index)),
280            Bound::Excluded(&index) => Bound::Excluded(SeekBySize::new(index)),
281            Bound::Unbounded => Bound::Unbounded,
282        };
283        Iter {
284            iter: self.root.range(start, end),
285        }
286    }
287}
288
289#[derive(Debug)]
290pub struct Iter<'a, K, V> {
291    iter: NodeRange<'a, SizedSplay<(K, V)>>,
292}
293impl<K, V> Iterator for Iter<'_, K, V>
294where
295    K: Clone,
296    V: Clone,
297{
298    type Item = (K, V);
299    fn next(&mut self) -> Option<Self::Item> {
300        self.iter.next_checked().map(|node| node.data().0.clone())
301    }
302    fn last(mut self) -> Option<Self::Item> {
303        self.next_back()
304    }
305    fn min(mut self) -> Option<Self::Item> {
306        self.next()
307    }
308    fn max(mut self) -> Option<Self::Item> {
309        self.next_back()
310    }
311}
312impl<K, V> FusedIterator for Iter<'_, K, V>
313where
314    K: Clone,
315    V: Clone,
316{
317}
318impl<K, V> DoubleEndedIterator for Iter<'_, K, V>
319where
320    K: Clone,
321    V: Clone,
322{
323    fn next_back(&mut self) -> Option<Self::Item> {
324        self.iter
325            .next_back_checked()
326            .map(|node| node.data().0.clone())
327    }
328}
329
#[cfg(test)]
mod tests {
    // Randomized tests that compare `SplayMap` with `std::collections::BTreeMap`
    // as a reference, plus subtree-size and drop-count invariant checks.
    use super::*;
    use crate::tools::Xorshift;
    use std::{cell::RefCell, collections::BTreeMap, mem::swap};

    impl<K, V> SplayMap<K, V> {
        /// Collects all entries via `traverse` without splaying (takes `&self`).
        /// The tests expect the result in ascending key order.
        fn dump(&self) -> Vec<(K, V)>
        where
            K: Clone,
            V: Clone,
        {
            let mut arr = vec![];
            if let Some(root) = self.root.root() {
                root.traverse(&mut |node| {
                    arr.push(node.data().0.clone());
                })
            }
            arr
        }
        /// Checks that every node's cached size equals the real size of its
        /// subtree, and that the root's size equals `len()`.
        fn check_size(&self) -> bool {
            // Returns the subtree size, or `None` as soon as a mismatch is found.
            fn dfs<T>(node: NodeRef<marker::Immut<'_>, SizedSplay<T>>) -> Option<usize> {
                let mut size = 1usize;
                if let Some(node) = node.left() {
                    size += dfs(node)?;
                }
                if let Some(node) = node.right() {
                    size += dfs(node)?;
                }
                if size == node.data().1 {
                    Some(size)
                } else {
                    None
                }
            }
            if let Some(root) = self.root.root() {
                dfs(root) == Some(self.len())
            } else {
                true
            }
        }
    }

    /// Random mix of `remove` / `insert` / `get_key_value` on a small key space,
    /// compared against `BTreeMap` after every operation.
    #[test]
    fn test_insert_remove_get() {
        const Q: usize = 30_000;
        const A: u64 = 500;
        let mut stree = SplayMap::new();
        let mut btree = BTreeMap::new();
        let mut rng = Xorshift::default();
        for v in 1..=Q {
            let k = rng.rand(A);
            match rng.random(0..3) {
                0 => assert_eq!(btree.remove(&k), stree.remove(&k)),
                1 => assert_eq!(btree.insert(k, v), stree.insert(k, v)),
                _ => assert_eq!(btree.get_key_value(&k), stree.get_key_value(&k)),
            }
            assert_eq!(btree.len(), stree.len());
            assert!(stree.check_size());
        }
    }

    /// Position-based access (`remove_at`, `get_key_value_at`) checked against
    /// `BTreeMap::iter().nth`.
    #[test]
    fn test_at() {
        const Q: usize = 30_000;
        const A: u64 = 500;
        let mut stree = SplayMap::new();
        let mut btree = BTreeMap::new();
        let mut rng = Xorshift::default();
        for v in 1..=Q {
            let k = rng.rand(A);
            // `i` may equal `len()`, which exercises the out-of-range `None` path.
            let i = rng.random(0..=btree.len());
            match rng.random(0..3) {
                0 => {
                    if let Some((&k, _)) = btree.iter().nth(i) {
                        assert_eq!(btree.remove(&k).map(|v| (k, v)), stree.remove_at(i));
                    }
                }
                1 => assert_eq!(btree.insert(k, v), stree.insert(k, v)),
                _ => assert_eq!(btree.iter().nth(i), stree.get_key_value_at(i)),
            }
            assert_eq!(btree.len(), stree.len());
            assert!(stree.check_size());
        }
    }

    /// Applies batches of random updates, then compares full iteration (forward
    /// or reversed), `range_at` and `range` with random bounds against `BTreeMap`.
    /// It also re-checks the tree contents after iterating.
    #[test]
    fn test_iter() {
        const Q: usize = 3_000;
        const A: u64 = 100;
        let mut stree = SplayMap::new();
        let mut btree = BTreeMap::new();
        let mut rng = Xorshift::default();
        for v in 1..=Q {
            for v in v * 100..(v + 1) * 100 {
                let k = rng.rand(A);
                match rng.random(0..2) {
                    0 => assert_eq!(btree.remove(&k), stree.remove(&k)),
                    _ => assert_eq!(btree.insert(k, v), stree.insert(k, v)),
                }
            }

            let b: Vec<_> = btree.iter().map(|(k, v)| (*k, *v)).collect();
            let a = stree.dump();
            assert_eq!(b, a);

            match rng.random(0..3) {
                0 => {
                    let a: Vec<_> = if rng.random(0..2) == 0 {
                        stree.iter().collect()
                    } else {
                        let mut a: Vec<_> = stree.iter().rev().collect();
                        a.reverse();
                        a
                    };
                    assert_eq!(b, a);
                }
                1 if !stree.is_empty() => {
                    let (mut l, mut r) = (rng.random(0..=stree.len()), rng.random(0..=stree.len()));
                    if l > r {
                        swap(&mut l, &mut r);
                    }
                    let l = match rng.random(0..3) {
                        0 => Bound::Included(l),
                        1 => Bound::Excluded(l),
                        _ => Bound::Unbounded,
                    };
                    let r = match rng.random(0..3) {
                        0 => Bound::Included(r),
                        1 => Bound::Excluded(r),
                        _ => Bound::Unbounded,
                    };
                    let a: Vec<_> = stree.range_at((l, r)).collect();
                    // Convert the position bounds into a half-open [lc, rc) window
                    // for `take`/`skip`; `!0` (usize::MAX) means "no upper limit".
                    let lc = match l {
                        Bound::Included(l) => l,
                        Bound::Excluded(l) => l + 1,
                        Bound::Unbounded => 0,
                    };
                    let rc = match r {
                        Bound::Included(r) => r + 1,
                        Bound::Excluded(r) => r,
                        Bound::Unbounded => !0,
                    };
                    let b: Vec<_> = btree
                        .iter()
                        .take(rc)
                        .skip(lc)
                        .map(|(k, v)| (*k, *v))
                        .collect();
                    assert_eq!(b, a);
                }
                _ => {
                    let (mut l, mut r) = (rng.random(0..=A), rng.random(0..=A));
                    if l > r {
                        swap(&mut l, &mut r);
                    }
                    let l = match rng.random(0..3) {
                        0 => Bound::Included(l),
                        1 => Bound::Excluded(l),
                        _ => Bound::Unbounded,
                    };
                    // `BTreeMap::range` panics when both bounds are `Excluded` at
                    // the same key, so the end is bumped by one in that case (the
                    // integer range is still empty).
                    let r = match rng.random(0..3) {
                        0 => Bound::Included(r),
                        1 if l == Bound::Excluded(r) => Bound::Excluded(r + 1),
                        1 => Bound::Excluded(r),
                        _ => Bound::Unbounded,
                    };
                    let b: Vec<_> = btree.range((l, r)).map(|(k, v)| (*k, *v)).collect();
                    let a: Vec<_> = stree.range((l, r)).collect();
                    assert_eq!(b, a);
                }
            }
            assert_eq!(btree.len(), stree.len());
            assert!(stree.check_size());

            let b: Vec<_> = btree.iter().map(|(k, v)| (*k, *v)).collect();
            let a = stree.dump();
            assert_eq!(b, a);
        }
    }

    /// Counts drops of `CheckDrop` values. A value returned by `remove`, or
    /// replaced by `insert`, is dropped right away. The values left in the map
    /// must be dropped exactly once when the map goes out of scope.
    #[test]
    fn test_drop() {
        #[derive(Debug)]
        struct CheckDrop<T>(T);
        thread_local! {
            static CNT: RefCell<usize> = const { RefCell::new(0) };
        }
        impl<T> Drop for CheckDrop<T> {
            fn drop(&mut self) {
                CNT.with(|cnt| *cnt.borrow_mut() += 1);
            }
        }
        const Q: usize = 3_000;
        const A: u64 = 500;
        // Expected number of drops so far, compared with the thread-local counter.
        let mut cnt = 0usize;
        let mut rng = Xorshift::default();
        for _ in 0..10 {
            {
                let mut stree = SplayMap::new();
                for v in 0..Q {
                    {
                        let k = rng.rand(A);
                        cnt += stree.remove(&k).is_some() as usize;
                        let k = rng.rand(A);
                        cnt += stree.insert(k, CheckDrop(v)).is_some() as usize;
                    }
                    assert_eq!(cnt, CNT.with(|cnt| *cnt.borrow()));
                }
                // Every remaining entry should be dropped when `stree` goes out of scope.
                cnt += stree.len();
            }
            assert_eq!(cnt, CNT.with(|cnt| *cnt.borrow()));
        }
    }
}