competitive/algorithm/sort.rs

1use std::{cmp::Ordering, ptr::copy_nonoverlapping};
2
/// In-place sorting algorithms exposed as extension methods.
///
/// Every algorithm comes in two flavours: one that orders elements with the
/// element type's [`Ord`] implementation, and a `_by` variant that takes a
/// comparator returning an [`Ordering`].
pub trait SliceSortExt<T> {
    /// Sorts `self` with bubble sort, using `T`'s natural order.
    fn bubble_sort(&mut self)
    where
        T: Ord;

    /// Sorts `self` with bubble sort, ordering elements by `compare`.
    fn bubble_sort_by<F: FnMut(&T, &T) -> Ordering>(&mut self, compare: F);

    /// Sorts `self` with merge sort, using `T`'s natural order.
    fn merge_sort(&mut self)
    where
        T: Ord;

    /// Sorts `self` with merge sort, ordering elements by `compare`.
    fn merge_sort_by<F: FnMut(&T, &T) -> Ordering>(&mut self, compare: F);

    /// Sorts `self` with insertion sort, using `T`'s natural order.
    fn insertion_sort(&mut self)
    where
        T: Ord;

    /// Sorts `self` with insertion sort, ordering elements by `compare`.
    fn insertion_sort_by<F: FnMut(&T, &T) -> Ordering>(&mut self, compare: F);
}
23impl<T> SliceSortExt<T> for [T] {
24    fn bubble_sort(&mut self)
25    where
26        T: Ord,
27    {
28        bubble_sort(self, |a, b| a.lt(b));
29    }
30    fn bubble_sort_by<F>(&mut self, mut compare: F)
31    where
32        F: FnMut(&T, &T) -> Ordering,
33    {
34        bubble_sort(self, |a, b| compare(a, b) == Ordering::Less);
35    }
36    fn merge_sort(&mut self)
37    where
38        T: Ord,
39    {
40        merge_sort(self, |a, b| a.lt(b));
41    }
42    fn merge_sort_by<F>(&mut self, mut compare: F)
43    where
44        F: FnMut(&T, &T) -> Ordering,
45    {
46        merge_sort(self, |a, b| compare(a, b) == Ordering::Less);
47    }
48    fn insertion_sort(&mut self)
49    where
50        T: Ord,
51    {
52        insertion_sort(self, |a, b| a.lt(b));
53    }
54    fn insertion_sort_by<F>(&mut self, mut compare: F)
55    where
56        F: FnMut(&T, &T) -> Ordering,
57    {
58        insertion_sort(self, |a, b| compare(a, b) == Ordering::Less);
59    }
60}
61
/// Stable, in-place bubble sort.
///
/// Each pass walks the possibly-unsorted prefix and swaps adjacent elements
/// that are out of order, carrying the largest element of that prefix to its
/// end. Everything from the last swap of a pass onward is then in its final
/// position, so the next pass stops there. A pass with no swap ends the sort,
/// which makes already-sorted input a single O(n) pass. The worst case is O(n²)
/// comparisons and swaps.
///
/// `is_less(a, b)` must return `true` iff `a` should be ordered strictly
/// before `b`. Elements are swapped only when strictly less, so equal elements
/// keep their relative order.
fn bubble_sort<T, F>(v: &mut [T], mut is_less: F)
where
    F: FnMut(&T, &T) -> bool,
{
    // Length of the prefix that may still be out of order; every element at
    // or after this index is already where it belongs.
    let mut unsorted = v.len();
    while unsorted > 1 {
        let mut last_swap = 0;
        for j in 1..unsorted {
            if is_less(&v[j], &v[j - 1]) {
                v.swap(j - 1, j);
                last_swap = j;
            }
        }
        // Nothing after `last_swap` moved, so it is already sorted and final.
        // A pass without swaps leaves `last_swap == 0` and ends the loop.
        unsorted = last_swap;
    }
}
80
/// Merges the adjacent sorted runs `v[..mid]` and `v[mid..]` in place,
/// using `buf` as scratch space for the left run.
///
/// Ties are resolved in favour of the left run, so the merge is stable.
///
/// Panic safety: if `is_less` panics, a drop guard moves every not-yet-merged
/// element back from `buf` into the gap in `v` before unwinding continues. `v`
/// therefore always ends up holding each original element exactly once (as
/// some permutation), and no element is dropped twice.
///
/// # Safety
///
/// * `mid <= v.len()`.
/// * `buf` must be valid for writes of `mid` properly aligned `T`s and must not
///   overlap `v`.
/// * When this returns or unwinds, the first `mid` slots of `buf` hold stale
///   bitwise copies that the caller must neither read as live values nor drop.
unsafe fn merge<T, F>(v: &mut [T], mid: usize, buf: *mut T, is_less: &mut F)
where
    F: FnMut(&T, &T) -> bool,
{
    /// Owns the unmerged tail of the left run (held in `buf`) and the gap in `v`
    /// that it must fill. Dropping it moves those elements into the gap.
    struct MergeHole<T> {
        /// Next unmerged left-run element, inside the scratch buffer.
        src: *mut T,
        /// Number of left-run elements still waiting in the scratch buffer.
        remaining: usize,
        /// Next slot of `v` to be written.
        dest: *mut T,
    }

    impl<T> Drop for MergeHole<T> {
        fn drop(&mut self) {
            // SAFETY: `src..src + remaining` are initialized, unmerged elements in
            // the scratch buffer, and `dest..dest + remaining` is exactly the gap in
            // `v` they belong in (the gap always ends where the unmerged right run
            // begins). The buffer and `v` do not overlap.
            unsafe { copy_nonoverlapping(self.src, self.dest, self.remaining) };
        }
    }

    let len = v.len();
    debug_assert!(mid <= len);
    let v = v.as_mut_ptr();

    // SAFETY: the caller guarantees `mid <= len` and that `buf` can hold `mid`
    // elements without overlapping `v`; every pointer below stays within those
    // bounds.
    unsafe {
        let v_end = v.add(len);

        // Move the left run out; from here on `v[..mid]` is a logical gap owned by `hole`.
        copy_nonoverlapping(v, buf, mid);
        let mut hole = MergeHole {
            src: buf,
            remaining: mid,
            dest: v,
        };
        let mut right = v.add(mid);

        while hole.remaining > 0 && right < v_end {
            // Take from the right only when strictly smaller, so ties keep the
            // left-run element first (stability).
            let taken = if is_less(&*right, &*hole.src) {
                let p = right;
                right = right.add(1);
                p
            } else {
                let p = hole.src;
                hole.src = hole.src.add(1);
                hole.remaining -= 1;
                p
            };
            // `taken` is either in `buf` or strictly after `hole.dest` in `v`, so the
            // two never overlap.
            copy_nonoverlapping(taken, hole.dest, 1);
            hole.dest = hole.dest.add(1);
        }
        // Dropping `hole` here moves any left-run leftovers into place. If the
        // right run ran out first, they fill the tail; otherwise nothing remains.
    }
}
117
118fn merge_sort<T, F>(v: &mut [T], mut is_less: F)
119where
120    F: FnMut(&T, &T) -> bool,
121{
122    let len = v.len();
123    if len <= 1 {
124        return;
125    }
126    let mut buf = Vec::with_capacity(len / 2);
127    let mut runs: Vec<Run> = vec![];
128    let mut end = len;
129    while end > 0 {
130        let start = end - 1;
131        let mut left = Run {
132            start,
133            len: end - start,
134        };
135        end = start;
136
137        while let Some(&right) = runs.last() {
138            if left.start > 0 && right.len > left.len {
139                break;
140            }
141            runs.pop().unwrap();
142            unsafe {
143                merge(
144                    &mut v[left.start..right.start + right.len],
145                    left.len,
146                    buf.as_mut_ptr(),
147                    &mut is_less,
148                );
149            }
150            left = Run {
151                start: left.start,
152                len: left.len + right.len,
153            };
154        }
155        runs.push(left);
156    }
157
158    debug_assert!(runs.len() == 1 && runs[0].start == 0 && runs[0].len == len);
159
160    #[derive(Clone, Copy)]
161    struct Run {
162        start: usize,
163        len: usize,
164    }
165}
166
/// Stable, in-place binary insertion sort.
///
/// For each element `v[i]`, binary-searches the already-sorted prefix `v[..i]`
/// for the first element strictly greater than `v[i]`, then rotates `v[i]` into
/// that position. Inserting *after* elements that compare equal keeps them in
/// their original relative order, matching `bubble_sort` and `merge_sort`.
///
/// This costs O(n log n) comparisons but O(n²) element moves in the worst case.
///
/// `is_less(a, b)` must return `true` iff `a` should be ordered strictly
/// before `b`.
fn insertion_sort<T, F>(v: &mut [T], mut is_less: F)
where
    F: FnMut(&T, &T) -> bool,
{
    for i in 1..v.len() {
        let x = &v[i];
        // `!is_less(x, y)` means `y <= x`. This holds for a prefix of the sorted
        // `v[..i]`, so `p` is the index of the first element strictly greater
        // than `x`.
        let p = v[..i].partition_point(|y| !is_less(x, y));
        v[p..=i].rotate_right(1);
    }
}
177
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{algorithm::SliceCombinationsExt, tools::Xorshift};

    /// Checks a `SliceSortExt` method against the standard library's sort.
    ///
    /// * `@small method, by_method`: for each `n <= 8`, runs over the
    ///   permutations produced by `for_each_permutations` (presumably all of
    ///   `0..n`). It checks the `Ord`-based method, then the comparator-based
    ///   `_by` method with a reversed comparator, so the comparator is actually
    ///   exercised.
    /// * `@large method, n_ub`: ten `Xorshift`-generated vectors of length below
    ///   `n_ub`, with values bounded by a random power of two `ub`.
    macro_rules! test_sort {
        (@small $sort_method:ident, $sort_by_method:ident) => {
            for n in 0..=8 {
                let a: Vec<_> = (0..n).collect();
                a.for_each_permutations(n, |a| {
                    let mut x = a.to_vec();
                    let mut y = a.to_vec();
                    x.sort();
                    y.$sort_method();
                    assert_eq!(x, y);

                    // Descending order through the comparator-based variant.
                    let mut x = a.to_vec();
                    let mut y = a.to_vec();
                    x.sort_by(|p, q| q.cmp(p));
                    y.$sort_by_method(|p, q| q.cmp(p));
                    assert_eq!(x, y);
                });
            }
        };
        (@large $sort_method:ident, $n_ub:expr) => {{
            let mut rng = Xorshift::default();
            for _ in 0..10 {
                let n = rng.random(..$n_ub);
                let ub = 1 << rng.random(0..20);
                let a: Vec<_> = rng.random_iter(0..ub).take(n).collect();
                let mut x = a.to_vec();
                let mut y = a.to_vec();
                x.sort();
                y.$sort_method();
                assert_eq!(x, y);
            }
        }};
    }

    #[test]
    fn test_bubble_sort_small() {
        test_sort!(@small bubble_sort, bubble_sort_by);
    }

    #[test]
    fn test_bubble_sort_large() {
        test_sort!(@large bubble_sort, 3000);
    }

    #[test]
    fn test_merge_sort_small() {
        test_sort!(@small merge_sort, merge_sort_by);
    }

    #[test]
    fn test_merge_sort_large() {
        test_sort!(@large merge_sort, 100_000);
    }

    #[test]
    fn test_insertion_sort_small() {
        test_sort!(@small insertion_sort, insertion_sort_by);
    }

    #[test]
    fn test_insertion_sort_large() {
        test_sort!(@large insertion_sort, 100_000);
    }
}