competitive/algorithm/
sort.rs1use std::{cmp::Ordering, ptr::copy_nonoverlapping};
2
/// Extension methods that sort a slice in place with classic textbook algorithms.
///
/// Each algorithm comes in two flavours: a plain method that orders elements by their
/// [`Ord`] implementation, and a `*_by` method that orders them with a caller-supplied
/// comparator returning an [`Ordering`].
pub trait SliceSortExt<T> {
    /// Sorts the slice in ascending order with bubble sort.
    fn bubble_sort(&mut self)
    where
        T: Ord;

    /// Sorts the slice with bubble sort, ordering elements by `compare`.
    fn bubble_sort_by<F: FnMut(&T, &T) -> Ordering>(&mut self, compare: F);

    /// Sorts the slice in ascending order with merge sort.
    fn merge_sort(&mut self)
    where
        T: Ord;

    /// Sorts the slice with merge sort, ordering elements by `compare`.
    fn merge_sort_by<F: FnMut(&T, &T) -> Ordering>(&mut self, compare: F);

    /// Sorts the slice in ascending order with insertion sort.
    fn insertion_sort(&mut self)
    where
        T: Ord;

    /// Sorts the slice with insertion sort, ordering elements by `compare`.
    fn insertion_sort_by<F: FnMut(&T, &T) -> Ordering>(&mut self, compare: F);
}
23impl<T> SliceSortExt<T> for [T] {
24 fn bubble_sort(&mut self)
25 where
26 T: Ord,
27 {
28 bubble_sort(self, |a, b| a.lt(b));
29 }
30 fn bubble_sort_by<F>(&mut self, mut compare: F)
31 where
32 F: FnMut(&T, &T) -> Ordering,
33 {
34 bubble_sort(self, |a, b| compare(a, b) == Ordering::Less);
35 }
36 fn merge_sort(&mut self)
37 where
38 T: Ord,
39 {
40 merge_sort(self, |a, b| a.lt(b));
41 }
42 fn merge_sort_by<F>(&mut self, mut compare: F)
43 where
44 F: FnMut(&T, &T) -> Ordering,
45 {
46 merge_sort(self, |a, b| compare(a, b) == Ordering::Less);
47 }
48 fn insertion_sort(&mut self)
49 where
50 T: Ord,
51 {
52 insertion_sort(self, |a, b| a.lt(b));
53 }
54 fn insertion_sort_by<F>(&mut self, mut compare: F)
55 where
56 F: FnMut(&T, &T) -> Ordering,
57 {
58 insertion_sort(self, |a, b| compare(a, b) == Ordering::Less);
59 }
60}
61
/// Sorts `v` in place with bubble sort, using `is_less` as the strict "less than" predicate.
///
/// Adjacent elements are swapped only when the right one is strictly less than the left one,
/// so elements that compare equal keep their relative order (the sort is stable).
///
/// After each pass the largest remaining element has reached the end of the unsorted prefix,
/// so the next pass stops one position earlier. If a pass performs no swap the slice is
/// already sorted and the function returns early: sorted input costs `len - 1` comparisons,
/// while the worst case stays O(n^2).
fn bubble_sort<T, F>(v: &mut [T], mut is_less: F)
where
    F: FnMut(&T, &T) -> bool,
{
    let mut unsorted_len = v.len();
    while unsorted_len > 1 {
        // Re-slicing once per pass lets the compiler drop the per-access bounds checks,
        // so no `unsafe` unchecked indexing is needed.
        let unsorted = &mut v[..unsorted_len];
        let mut swapped = false;
        for j in 1..unsorted.len() {
            if is_less(&unsorted[j], &unsorted[j - 1]) {
                unsorted.swap(j - 1, j);
                swapped = true;
            }
        }
        if !swapped {
            break;
        }
        unsorted_len -= 1;
    }
}
80
/// Merges the adjacent sorted runs `v[..mid]` and `v[mid..]` into one sorted run, in place.
///
/// The left run is moved into `buf` and then merged back into `v` from the front. When the
/// two heads compare equal the left element is taken first, so the merge is stable.
///
/// If `is_less` panics, the part of the left run still sitting in `buf` is copied back into
/// the gap in `v` while unwinding (see `MergeHole`). `v` therefore still holds every original
/// element exactly once (possibly reordered), and nothing is dropped twice or leaked.
///
/// # Safety
///
/// - `mid <= v.len()`.
/// - `buf` must be valid for writes of `mid` elements and must not overlap `v`.
unsafe fn merge<T, F>(v: &mut [T], mid: usize, buf: *mut T, is_less: &mut F)
where
    F: FnMut(&T, &T) -> bool,
{
    /// Tracks the unmerged tail of the left run while it lives in the scratch buffer.
    ///
    /// Invariant between comparisons: `dest` is the start of a gap in `v` that is exactly
    /// `remaining` slots long, and `src` points at `remaining` initialized elements in the
    /// buffer. Dropping the hole fills the gap, which is what makes the merge panic-safe.
    struct MergeHole<T> {
        src: *mut T,
        remaining: usize,
        dest: *mut T,
    }

    impl<T> Drop for MergeHole<T> {
        fn drop(&mut self) {
            // SAFETY: by the invariant above, `src` has `remaining` initialized elements and
            // `dest` has exactly `remaining` free slots; the buffer and `v` do not overlap.
            unsafe { copy_nonoverlapping(self.src, self.dest, self.remaining) };
        }
    }

    /// Returns `*ptr` and advances it by one element.
    unsafe fn get_and_increment<T>(ptr: &mut *mut T) -> *mut T {
        let old = *ptr;
        *ptr = unsafe { ptr.add(1) };
        old
    }

    let len = v.len();
    debug_assert!(mid <= len);
    let v = v.as_mut_ptr();

    // SAFETY: the caller guarantees `mid <= len` and that `buf` can hold `mid` elements
    // without overlapping `v`, so every pointer below stays inside its allocation.
    unsafe {
        let v_end = v.add(len);
        let mut right = v.add(mid);

        copy_nonoverlapping(v, buf, mid);
        // From here on `v[..mid]` is a gap; the guard owns the left run.
        let mut hole = MergeHole {
            src: buf,
            remaining: mid,
            dest: v,
        };

        while hole.remaining > 0 && right < v_end {
            let to_copy = if is_less(&*right, &*hole.src) {
                get_and_increment(&mut right)
            } else {
                hole.remaining -= 1;
                get_and_increment(&mut hole.src)
            };
            copy_nonoverlapping(to_copy, get_and_increment(&mut hole.dest), 1);
        }
        // Dropping `hole` here moves whatever is left of the left run into place; any
        // leftover right-run elements are already in their final positions.
    }
}
117
118fn merge_sort<T, F>(v: &mut [T], mut is_less: F)
119where
120 F: FnMut(&T, &T) -> bool,
121{
122 let len = v.len();
123 if len <= 1 {
124 return;
125 }
126 let mut buf = Vec::with_capacity(len / 2);
127 let mut runs: Vec<Run> = vec![];
128 let mut end = len;
129 while end > 0 {
130 let start = end - 1;
131 let mut left = Run {
132 start,
133 len: end - start,
134 };
135 end = start;
136
137 while let Some(&right) = runs.last() {
138 if left.start > 0 && right.len > left.len {
139 break;
140 }
141 runs.pop().unwrap();
142 unsafe {
143 merge(
144 &mut v[left.start..right.start + right.len],
145 left.len,
146 buf.as_mut_ptr(),
147 &mut is_less,
148 );
149 }
150 left = Run {
151 start: left.start,
152 len: left.len + right.len,
153 };
154 }
155 runs.push(left);
156 }
157
158 debug_assert!(runs.len() == 1 && runs[0].start == 0 && runs[0].len == len);
159
160 #[derive(Clone, Copy)]
161 struct Run {
162 start: usize,
163 len: usize,
164 }
165}
166
/// Sorts `v` in place with binary insertion sort, using `is_less` as the strict
/// "less than" predicate.
///
/// Each element is inserted into the already-sorted prefix `v[..i]` at a position found by
/// binary search. The sort therefore makes O(n log n) comparisons but may move O(n^2) elements.
///
/// The sort is stable: an element is inserted after every prefix element it is not strictly
/// less than, so elements that compare equal keep their original relative order.
fn insertion_sort<T, F>(v: &mut [T], mut is_less: F)
where
    F: FnMut(&T, &T) -> bool,
{
    for i in 1..v.len() {
        let x = &v[i];
        // The prefix is sorted, so `!is_less(x, y)` (i.e. `y <= x`) holds for a leading run of
        // it and fails for the rest. Inserting after that run places `x` behind any equal
        // elements. Using `is_less(y, x)` instead would put it in front of them and break
        // stability.
        let p = v[..i].partition_point(|y| !is_less(x, y));
        v[p..=i].rotate_right(1);
    }
}
177
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{algorithm::SliceCombinationsExt, tools::Xorshift};

    /// Checks `$method` against `slice::sort` on every permutation of `0..n` for `n <= 8`.
    macro_rules! check_all_small_permutations {
        ($method:ident) => {
            for n in 0..=8 {
                let base: Vec<_> = (0..n).collect();
                base.for_each_permutations(n, |perm| {
                    let mut expected = perm.to_vec();
                    let mut actual = perm.to_vec();
                    expected.sort();
                    actual.$method();
                    assert_eq!(expected, actual);
                });
            }
        };
    }

    /// Checks `$method` against `slice::sort` on ten random inputs shorter than `$n_ub`.
    macro_rules! check_random_inputs {
        ($method:ident, $n_ub:expr) => {{
            let mut rng = Xorshift::default();
            for _ in 0..10 {
                let n = rng.random(..$n_ub);
                let ub = 1 << rng.random(0..20);
                let input: Vec<_> = rng.random_iter(0..ub).take(n).collect();
                let mut expected = input.to_vec();
                let mut actual = input.to_vec();
                expected.sort();
                actual.$method();
                assert_eq!(expected, actual);
            }
        }};
    }

    /// Emits the exhaustive-small and random-large test functions for one sorting method.
    macro_rules! sort_tests {
        ($small:ident, $large:ident, $method:ident, $n_ub:expr) => {
            #[test]
            fn $small() {
                check_all_small_permutations!($method);
            }

            #[test]
            fn $large() {
                check_random_inputs!($method, $n_ub);
            }
        };
    }

    sort_tests!(test_bubble_sort_small, test_bubble_sort_large, bubble_sort, 3000);
    sort_tests!(test_merge_sort_small, test_merge_sort_large, merge_sort, 100_000);
    sort_tests!(
        test_insertion_sort_small,
        test_insertion_sort_large,
        insertion_sort,
        100_000
    );
}