1use super::{
2 Allocator, MemoryPool,
3 node::{Node, NodeRange, NodeRef, Root, SplaySeeker, SplaySpec, marker},
4};
5use std::{
6 borrow::Borrow,
7 cmp::Ordering,
8 fmt::{self, Debug},
9 iter::FusedIterator,
10 marker::PhantomData,
11 mem::{ManuallyDrop, replace},
12 ops::{Bound, DerefMut, RangeBounds},
13};
14
15struct SizedSplay<T> {
16 _marker: PhantomData<fn() -> T>,
17}
18impl<T> SplaySpec for SizedSplay<T> {
19 type T = (T, usize);
20 fn has_bottom_up() -> bool {
21 true
22 }
23 fn bottom_up(node: NodeRef<marker::DataMut<'_>, Self>) {
24 let l = node.left().map(|p| p.data().1).unwrap_or_default();
25 let r = node.right().map(|p| p.data().1).unwrap_or_default();
26 node.data_mut().1 = l + r + 1;
27 }
28}
29
30struct SeekByKey<'a, K, V, Q>
31where
32 Q: ?Sized,
33{
34 key: &'a Q,
35 _marker: PhantomData<fn() -> (K, V)>,
36}
37impl<'a, K, V, Q> SeekByKey<'a, K, V, Q>
38where
39 Q: ?Sized,
40{
41 fn new(key: &'a Q) -> Self {
42 Self {
43 key,
44 _marker: PhantomData,
45 }
46 }
47}
48impl<K, V, Q> SplaySeeker for SeekByKey<'_, K, V, Q>
49where
50 K: Borrow<Q>,
51 Q: Ord + ?Sized,
52{
53 type S = SizedSplay<(K, V)>;
54 fn splay_seek(&mut self, node: NodeRef<marker::Immut<'_>, Self::S>) -> Ordering {
55 self.key.cmp((node.data().0).0.borrow())
56 }
57}
58
59struct SeekBySize<K, V> {
60 index: usize,
61 _marker: PhantomData<fn() -> (K, V)>,
62}
63impl<K, V> SeekBySize<K, V> {
64 fn new(index: usize) -> Self {
65 Self {
66 index,
67 _marker: PhantomData,
68 }
69 }
70}
71impl<K, V> SplaySeeker for SeekBySize<K, V> {
72 type S = SizedSplay<(K, V)>;
73 fn splay_seek(&mut self, node: NodeRef<marker::Immut<'_>, Self::S>) -> Ordering {
74 let lsize = node.left().map(|l| l.data().1).unwrap_or_default();
75 let ord = self.index.cmp(&lsize);
76 if matches!(ord, Ordering::Greater) {
77 self.index -= lsize + 1;
78 }
79 ord
80 }
81}
82
83pub struct SplayMap<K, V, A = MemoryPool<Node<((K, V), usize)>>>
84where
85 A: Allocator<Node<((K, V), usize)>>,
86{
87 root: Root<SizedSplay<(K, V)>>,
88 length: usize,
89 alloc: ManuallyDrop<A>,
90}
91
92impl<K, V, A> Debug for SplayMap<K, V, A>
93where
94 K: Debug,
95 V: Debug,
96 A: Allocator<Node<((K, V), usize)>>,
97{
98 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
99 f.debug_struct("SplayMap")
100 .field("root", &self.root)
101 .field("length", &self.length)
102 .finish()
103 }
104}
105
/// Returns every node to `alloc`, then drops the allocator itself.
impl<K, V, A> Drop for SplayMap<K, V, A>
where
    A: Allocator<Node<((K, V), usize)>>,
{
    fn drop(&mut self) {
        // SAFETY (review): every node in `root` was allocated from `self.alloc`
        // (nodes are only created in `insert` via `NodeRef::from_data(.., self.alloc)`),
        // so each one is handed back to the same allocator. `alloc` is dropped
        // exactly once, only after the loop has emptied the tree, and is never
        // used again. Confirm against the `take_first` / `into_dying` /
        // `deallocate` contracts in `node`.
        unsafe {
            // Detach nodes one at a time until the tree is empty. The stored
            // entries are dropped as part of this (`test_drop` checks the count).
            while let Some(node) = self.root.take_first() {
                self.alloc.deallocate(node.into_dying().into_inner());
            }
            // `alloc` is `ManuallyDrop`, so it must be dropped explicitly, and
            // only now that no node refers to it any more.
            ManuallyDrop::drop(&mut self.alloc);
        }
    }
}
119
120impl<K, V, A> Default for SplayMap<K, V, A>
121where
122 A: Allocator<Node<((K, V), usize)>> + Default,
123{
124 fn default() -> Self {
125 Self {
126 root: Root::default(),
127 length: 0,
128 alloc: Default::default(),
129 }
130 }
131}
132
133impl<K, V> SplayMap<K, V> {
134 pub fn new() -> Self {
135 Default::default()
136 }
137 pub fn with_capacity(capacity: usize) -> Self {
138 Self {
139 root: Root::default(),
140 length: 0,
141 alloc: ManuallyDrop::new(MemoryPool::with_capacity(capacity)),
142 }
143 }
144}
145impl<K, V, A> SplayMap<K, V, A>
146where
147 A: Allocator<Node<((K, V), usize)>>,
148{
149 pub fn get<Q>(&mut self, key: &Q) -> Option<&V>
150 where
151 K: Borrow<Q>,
152 Q: Ord + ?Sized,
153 {
154 self.get_key_value(key).map(|(_, v)| v)
155 }
156 fn splay_by_key<Q>(&mut self, key: &Q) -> Option<Ordering>
157 where
158 K: Borrow<Q>,
159 Q: Ord + ?Sized,
160 {
161 self.root.splay_by(SeekByKey::new(key))
162 }
163 pub fn get_key_value<Q>(&mut self, key: &Q) -> Option<(&K, &V)>
164 where
165 K: Borrow<Q>,
166 Q: Ord + ?Sized,
167 {
168 if !matches!(self.splay_by_key(key)?, Ordering::Equal) {
169 return None;
170 }
171 self.root.root().map(|node| {
172 let ((k, v), _) = node.data();
173 (k, v)
174 })
175 }
176 fn splay_at(&mut self, index: usize) -> Option<Ordering> {
177 self.root.splay_by(SeekBySize::new(index))
178 }
179 pub fn get_key_value_at(&mut self, index: usize) -> Option<(&K, &V)> {
180 if index >= self.length {
181 return None;
182 }
183 self.splay_at(index);
184 self.root.root().map(|node| {
185 let ((k, v), _) = node.data();
186 (k, v)
187 })
188 }
189 pub fn insert(&mut self, key: K, value: V) -> Option<V>
190 where
191 K: Ord,
192 {
193 let ord = self.splay_by_key(&key);
194 self.length += (ord != Some(Ordering::Equal)) as usize;
195 match ord {
196 Some(Ordering::Equal) => {
197 return Some(replace(
198 &mut (self.root.root_data_mut().unwrap().data_mut().0).1,
199 value,
200 ));
201 }
202 Some(Ordering::Less) => unsafe {
203 self.root.insert_left(NodeRef::from_data(
204 ((key, value), 1),
205 self.alloc.deref_mut(),
206 ));
207 },
208 _ => unsafe {
209 self.root.insert_right(NodeRef::from_data(
210 ((key, value), 1),
211 self.alloc.deref_mut(),
212 ));
213 },
214 }
215 None
216 }
217 pub fn remove<Q>(&mut self, key: &Q) -> Option<V>
218 where
219 K: Borrow<Q>,
220 Q: Ord + ?Sized,
221 {
222 if !matches!(self.splay_by_key(key)?, Ordering::Equal) {
223 return None;
224 }
225 self.length -= 1;
226 let node = self.root.take_root().unwrap().into_dying();
227 unsafe { Some((node.into_data(self.alloc.deref_mut()).0).1) }
228 }
229 pub fn remove_at(&mut self, index: usize) -> Option<(K, V)> {
230 if index >= self.length {
231 return None;
232 }
233 self.splay_at(index);
234 self.length -= 1;
235 let node = self.root.take_root().unwrap().into_dying();
236 unsafe { Some(node.into_data(self.alloc.deref_mut()).0) }
237 }
238 pub fn len(&self) -> usize {
239 self.length
240 }
241 pub fn is_empty(&self) -> bool {
242 self.len() == 0
243 }
244 pub fn iter(&mut self) -> Iter<'_, K, V> {
245 Iter {
246 iter: NodeRange::new(&mut self.root),
247 }
248 }
249 pub fn range<Q, R>(&mut self, range: R) -> Iter<'_, K, V>
250 where
251 K: Borrow<Q>,
252 Q: Ord + ?Sized,
253 R: RangeBounds<Q>,
254 {
255 let start = match range.start_bound() {
256 Bound::Included(key) => Bound::Included(SeekByKey::new(key)),
257 Bound::Excluded(key) => Bound::Excluded(SeekByKey::new(key)),
258 Bound::Unbounded => Bound::Unbounded,
259 };
260 let end = match range.end_bound() {
261 Bound::Included(key) => Bound::Included(SeekByKey::new(key)),
262 Bound::Excluded(key) => Bound::Excluded(SeekByKey::new(key)),
263 Bound::Unbounded => Bound::Unbounded,
264 };
265 Iter {
266 iter: self.root.range(start, end),
267 }
268 }
269 pub fn range_at<R>(&mut self, range: R) -> Iter<'_, K, V>
270 where
271 R: RangeBounds<usize>,
272 {
273 let start = match range.start_bound() {
274 Bound::Included(&index) => Bound::Included(SeekBySize::new(index)),
275 Bound::Excluded(&index) => Bound::Excluded(SeekBySize::new(index)),
276 Bound::Unbounded => Bound::Unbounded,
277 };
278 let end = match range.end_bound() {
279 Bound::Included(&index) => Bound::Included(SeekBySize::new(index)),
280 Bound::Excluded(&index) => Bound::Excluded(SeekBySize::new(index)),
281 Bound::Unbounded => Bound::Unbounded,
282 };
283 Iter {
284 iter: self.root.range(start, end),
285 }
286 }
287}
288
289#[derive(Debug)]
290pub struct Iter<'a, K, V> {
291 iter: NodeRange<'a, SizedSplay<(K, V)>>,
292}
293impl<K, V> Iterator for Iter<'_, K, V>
294where
295 K: Clone,
296 V: Clone,
297{
298 type Item = (K, V);
299 fn next(&mut self) -> Option<Self::Item> {
300 self.iter.next_checked().map(|node| node.data().0.clone())
301 }
302 fn last(mut self) -> Option<Self::Item> {
303 self.next_back()
304 }
305 fn min(mut self) -> Option<Self::Item> {
306 self.next()
307 }
308 fn max(mut self) -> Option<Self::Item> {
309 self.next_back()
310 }
311}
312impl<K, V> FusedIterator for Iter<'_, K, V>
313where
314 K: Clone,
315 V: Clone,
316{
317}
318impl<K, V> DoubleEndedIterator for Iter<'_, K, V>
319where
320 K: Clone,
321 V: Clone,
322{
323 fn next_back(&mut self) -> Option<Self::Item> {
324 self.iter
325 .next_back_checked()
326 .map(|node| node.data().0.clone())
327 }
328}
329
// Randomized differential tests: every operation is mirrored on a
// `BTreeMap` and the results are compared. The exact sequence of RNG calls
// determines the workload, so statement order here is significant.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::tools::Xorshift;
    use std::{cell::RefCell, collections::BTreeMap, mem::swap};

    impl<K, V> SplayMap<K, V> {
        /// Collects every entry in the order visited by `traverse`, without
        /// splaying. The tests compare this with `BTreeMap` iteration, so
        /// `traverse` is expected to be an in-order walk.
        fn dump(&self) -> Vec<(K, V)>
        where
            K: Clone,
            V: Clone,
        {
            let mut arr = vec![];
            if let Some(root) = self.root.root() {
                root.traverse(&mut |node| {
                    arr.push(node.data().0.clone());
                })
            }
            arr
        }
        /// Returns `true` if every node's cached subtree size matches the size
        /// recomputed from its children and the root's size equals `len()`.
        /// An empty tree is trivially consistent.
        fn check_size(&self) -> bool {
            // Recomputes the subtree size bottom-up; `None` signals a mismatch
            // somewhere below `node`.
            fn dfs<T>(node: NodeRef<marker::Immut<'_>, SizedSplay<T>>) -> Option<usize> {
                let mut size = 1usize;
                if let Some(node) = node.left() {
                    size += dfs(node)?;
                }
                if let Some(node) = node.right() {
                    size += dfs(node)?;
                }
                if size == node.data().1 {
                    Some(size)
                } else {
                    None
                }
            }
            if let Some(root) = self.root.root() {
                dfs(root) == Some(self.len())
            } else {
                true
            }
        }
    }

    /// Mixes key-based remove / insert / lookup over a 500-key space (so hits
    /// and misses both occur often). Results, length and the size invariant
    /// are checked after every operation.
    #[test]
    fn test_insert_remove_get() {
        const Q: usize = 30_000;
        const A: u64 = 500;
        let mut stree = SplayMap::new();
        let mut btree = BTreeMap::new();
        let mut rng = Xorshift::default();
        for v in 1..=Q {
            let k = rng.rand(A);
            match rng.random(0..3) {
                0 => assert_eq!(btree.remove(&k), stree.remove(&k)),
                1 => assert_eq!(btree.insert(k, v), stree.insert(k, v)),
                _ => assert_eq!(btree.get_key_value(&k), stree.get_key_value(&k)),
            }
            assert_eq!(btree.len(), stree.len());
            assert!(stree.check_size());
        }
    }

    /// Exercises the rank-based API (`remove_at`, `get_key_value_at`) against
    /// `BTreeMap::iter().nth`. `i` may equal `len()`, which covers the
    /// out-of-range case.
    #[test]
    fn test_at() {
        const Q: usize = 30_000;
        const A: u64 = 500;
        let mut stree = SplayMap::new();
        let mut btree = BTreeMap::new();
        let mut rng = Xorshift::default();
        for v in 1..=Q {
            let k = rng.rand(A);
            let i = rng.random(0..=btree.len());
            match rng.random(0..3) {
                0 => {
                    // Only remove when rank `i` exists; the key at that rank is
                    // looked up in the reference map first.
                    if let Some((&k, _)) = btree.iter().nth(i) {
                        assert_eq!(btree.remove(&k).map(|v| (k, v)), stree.remove_at(i));
                    }
                }
                1 => assert_eq!(btree.insert(k, v), stree.insert(k, v)),
                _ => assert_eq!(btree.iter().nth(i), stree.get_key_value_at(i)),
            }
            assert_eq!(btree.len(), stree.len());
            assert!(stree.check_size());
        }
    }

    /// After each batch of 100 random inserts/removes, checks one of:
    /// full forward/backward iteration, a rank range (`range_at`), or a key
    /// range (`range`), plus the structural invariants.
    #[test]
    fn test_iter() {
        const Q: usize = 3_000;
        const A: u64 = 100;
        let mut stree = SplayMap::new();
        let mut btree = BTreeMap::new();
        let mut rng = Xorshift::default();
        for v in 1..=Q {
            for v in v * 100..(v + 1) * 100 {
                let k = rng.rand(A);
                match rng.random(0..2) {
                    0 => assert_eq!(btree.remove(&k), stree.remove(&k)),
                    _ => assert_eq!(btree.insert(k, v), stree.insert(k, v)),
                }
            }

            let b: Vec<_> = btree.iter().map(|(k, v)| (*k, *v)).collect();
            let a = stree.dump();
            assert_eq!(b, a);

            match rng.random(0..3) {
                0 => {
                    // Full iteration, either forwards or backwards (then re-reversed).
                    let a: Vec<_> = if rng.random(0..2) == 0 {
                        stree.iter().collect()
                    } else {
                        let mut a: Vec<_> = stree.iter().rev().collect();
                        a.reverse();
                        a
                    };
                    assert_eq!(b, a);
                }
                1 if !stree.is_empty() => {
                    // Rank range; the reference slice is built with take/skip.
                    let (mut l, mut r) = (rng.random(0..=stree.len()), rng.random(0..=stree.len()));
                    if l > r {
                        swap(&mut l, &mut r);
                    }
                    let l = match rng.random(0..3) {
                        0 => Bound::Included(l),
                        1 => Bound::Excluded(l),
                        _ => Bound::Unbounded,
                    };
                    let r = match rng.random(0..3) {
                        0 => Bound::Included(r),
                        1 => Bound::Excluded(r),
                        _ => Bound::Unbounded,
                    };
                    let a: Vec<_> = stree.range_at((l, r)).collect();
                    // Convert the bounds into a half-open [lc, rc) rank window.
                    let lc = match l {
                        Bound::Included(l) => l,
                        Bound::Excluded(l) => l + 1,
                        Bound::Unbounded => 0,
                    };
                    let rc = match r {
                        Bound::Included(r) => r + 1,
                        Bound::Excluded(r) => r,
                        // `!0` is `usize::MAX`, i.e. "take everything".
                        Bound::Unbounded => !0,
                    };
                    let b: Vec<_> = btree
                        .iter()
                        .take(rc)
                        .skip(lc)
                        .map(|(k, v)| (*k, *v))
                        .collect();
                    assert_eq!(b, a);
                }
                _ => {
                    // Key range, compared directly with `BTreeMap::range`.
                    let (mut l, mut r) = (rng.random(0..=A), rng.random(0..=A));
                    if l > r {
                        swap(&mut l, &mut r);
                    }
                    let l = match rng.random(0..3) {
                        0 => Bound::Included(l),
                        1 => Bound::Excluded(l),
                        _ => Bound::Unbounded,
                    };
                    let r = match rng.random(0..3) {
                        0 => Bound::Included(r),
                        // `BTreeMap::range` panics when both bounds are `Excluded` at
                        // the same key; widen the end so the range is merely empty.
                        1 if l == Bound::Excluded(r) => Bound::Excluded(r + 1),
                        1 => Bound::Excluded(r),
                        _ => Bound::Unbounded,
                    };
                    let b: Vec<_> = btree.range((l, r)).map(|(k, v)| (*k, *v)).collect();
                    let a: Vec<_> = stree.range((l, r)).collect();
                    assert_eq!(b, a);
                }
            }
            assert_eq!(btree.len(), stree.len());
            assert!(stree.check_size());

            // Iteration splays the tree; make sure the contents are unchanged.
            let b: Vec<_> = btree.iter().map(|(k, v)| (*k, *v)).collect();
            let a = stree.dump();
            assert_eq!(b, a);
        }
    }

    /// Verifies that every value is dropped exactly once: immediately when it
    /// is removed or replaced, and when the map itself goes out of scope.
    #[test]
    fn test_drop() {
        // Increments the thread-local counter `CNT` each time one is dropped.
        #[derive(Debug)]
        struct CheckDrop<T>(T);
        thread_local! {
            // Thread-local so each test thread keeps its own count.
            static CNT: RefCell<usize> = const { RefCell::new(0) };
        }
        impl<T> Drop for CheckDrop<T> {
            fn drop(&mut self) {
                CNT.with(|cnt| *cnt.borrow_mut() += 1);
            }
        }
        const Q: usize = 3_000;
        const A: u64 = 500;
        // Expected number of drops so far.
        let mut cnt = 0usize;
        let mut rng = Xorshift::default();
        for _ in 0..10 {
            {
                let mut stree = SplayMap::new();
                for v in 0..Q {
                    {
                        // A removed value, or the old value returned by an
                        // overwriting insert, is dropped at the end of its statement.
                        let k = rng.rand(A);
                        cnt += stree.remove(&k).is_some() as usize;
                        let k = rng.rand(A);
                        cnt += stree.insert(k, CheckDrop(v)).is_some() as usize;
                    }
                    assert_eq!(cnt, CNT.with(|cnt| *cnt.borrow()));
                }
                // Everything still in the map is dropped when `stree` leaves scope.
                cnt += stree.len();
            }
            assert_eq!(cnt, CNT.with(|cnt| *cnt.borrow()));
        }
    }
}